} catch (const std::exception& e) { \
ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
}
-#define ASSERT_ANY_THROW(statement) \
- bool threw = false; \
- try { \
- (void)statement; \
- } catch (const std::exception& e) { \
- threw = true; \
- } \
- ASSERT_TRUE(threw); \
+#define ASSERT_ANY_THROW(statement)   \
+  {                                   \
+    bool threw = false;               \
+    try {                             \
+      (void)statement;                \
+    } catch (const std::exception&) { \
+      threw = true;                   \
+    }                                 \
+    ASSERT_TRUE(threw);               \
+  }
#endif // defined(USE_GTEST)
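// A usage sketch for the ASSERT_ANY_THROW fallback above (test body
// hypothetical): the wrapped statement is expected to throw, and the
// assertion fails if it does not. Note that the fallback only catches
// std::exception, so other exception types will propagate out of the test.
//
//   void testBadSchemaThrows() {
//     ASSERT_ANY_THROW(parseSchema("at::what(")); // malformed schema
//   }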
+#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/argument_spec.h"
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
-#include "torch/csrc/jit/symbolic_variable.h"
#include "torch/csrc/jit/symbolic_script.h"
+#include "torch/csrc/jit/symbolic_variable.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/utils/hash.h"
-#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/variable.h"
return r;
}
-std::vector<at::Tensor> run(InterpreterState & interp, const std::vector<at::Tensor> & inputs) {
+std::vector<at::Tensor> run(
+ InterpreterState& interp,
+ const std::vector<at::Tensor>& inputs) {
std::vector<IValue> stack(inputs.begin(), inputs.end());
interp.run(stack);
return fmap(stack, [](const IValue& i) { return i.toTensor(); });
df_interpreter.run(df_stack);
// Outputs of f need to be sliced
- f_stack.erase(
- f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
+ f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
}
// make inputs
at::Tensor input = torch::randn(input_size);
- at::Tensor weight = torch::randn({out_channels, input_size[1], kernel_size[0], kernel_size[1]});
+ at::Tensor weight = torch::randn(
+ {out_channels, input_size[1], kernel_size[0], kernel_size[1]});
at::Tensor bias = torch::randn({out_channels});
// run forward eagerly
at::Tensor output, finput, fgradinput;
- std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(input, weight, kernel_size,
- bias, stride, padding);
+ std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(
+ input, weight, kernel_size, bias, stride, padding);
// make grad_outputs
at::Tensor grad_output = torch::randn_like(output);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
- std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(grad_output, input, weight,
- kernel_size, stride, padding,
- finput, fgradinput, {true, true, true});
+ std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(
+ grad_output,
+ input,
+ weight,
+ kernel_size,
+ stride,
+ padding,
+ finput,
+ fgradinput,
+ {true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
- Value* conv = graph->insert(aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
+ Value* conv = graph->insert(
+ aten::thnn_conv2d_forward,
+ {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
auto outputs = conv->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
- runGradient(grad_spec, tensors_in, tensor_grads_in);
+ runGradient(grad_spec, tensors_in, tensor_grads_in);
// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
}
void testATenNativeBatchNorm() {
- // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+ // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
+ // running_mean, Tensor running_var, bool training, float momentum, float eps)
+ // -> (Tensor, Tensor, Tensor)
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
bool training = true;
float momentum = 0.9;
// run forward eagerly
at::Tensor output, savemean, saveinvstd;
- std::tie(output, savemean, saveinvstd) = at::native_batch_norm(input, weight, bias, running_mean_eager, running_var_eager, training, momentum, eps);
+ std::tie(output, savemean, saveinvstd) = at::native_batch_norm(
+ input,
+ weight,
+ bias,
+ running_mean_eager,
+ running_var_eager,
+ training,
+ momentum,
+ eps);
// make grad_outputs
at::Tensor grad_output = torch::randn_like(output);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
- // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
- std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(grad_output, input, weight,
- running_mean_eager, running_var_eager,
- savemean, saveinvstd, training, eps, {true, true, true});
+ // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
+ // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
+ // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
+ // Tensor, Tensor)
+ std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(
+ grad_output,
+ input,
+ weight,
+ running_mean_eager,
+ running_var_eager,
+ savemean,
+ saveinvstd,
+ training,
+ eps,
+ {true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
auto running_meang = graph->addInput("running_mean");
auto running_varg = graph->addInput("running_var");
- Value* bn = graph->insert(aten::native_batch_norm, {inputg, weightg, biasg, running_meang, running_varg, training_val, momentum_val, eps_val});
+ Value* bn = graph->insert(
+ aten::native_batch_norm,
+ {inputg,
+ weightg,
+ biasg,
+ running_meang,
+ running_varg,
+ training_val,
+ momentum_val,
+ eps_val});
auto outputs = bn->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
- runGradient(grad_spec, tensors_in, tensor_grads_in);
+ runGradient(grad_spec, tensors_in, tensor_grads_in);
// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
[](const VL& v) -> VL { return {v[0].tanh()}; }},
{"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }},
{"view",
- unary_pointwise_2d,
- [](const VL& v) -> VL { return {v[0].view({3, 2})}; }},
+ unary_pointwise_2d,
+ [](const VL& v) -> VL {
+ return {v[0].view({3, 2})};
+ }},
{"expand",
- {{2, 1}},
- [](const VL& v) -> VL { return {v[0].expand({2, 3})}; }},
+ {{2, 1}},
+ [](const VL& v) -> VL {
+ return {v[0].expand({2, 3})};
+ }},
{"mm",
{{10, 12}, {12, 15}},
[](const VL& v) -> VL { return {v[0].mm(v[1])}; }},
- // TODO: enable once we'll be able to capture lists across forward-backward
+  // TODO: enable once we are able to capture lists across
+  // forward-backward
//{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
// fmap<Variable>(v[0].chunk(4, 1)); }},
//{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
graph->registerOutput(d.value());
graph->registerOutput(e.value());
- auto a_var = autograd::make_variable(at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
- auto b_var = autograd::make_variable(at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);
+ auto a_var = autograd::make_variable(
+ at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
+ auto b_var = autograd::make_variable(
+ at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);
setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2));
PropagateInputShapes(graph);
PropagateRequiresGrad(graph);
auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
const auto& nodes = graph->nodes();
- auto maybe_fusion_group = std::find_if(
- nodes.begin(), nodes.end(),
- [](const Node* node) { return node->kind() == prim::FusionGroup; });
+ auto maybe_fusion_group =
+ std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
+ return node->kind() == prim::FusionGroup;
+ });
JIT_ASSERTM(
maybe_fusion_group != nodes.end(),
"testRegisterFusionCachesKernel: could not create FusionGroup");
ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
- ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType);
+ ASSERT_EQ(
+ op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType);
ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::DynamicType);
ASSERT_EQ(op->schema().arguments()[0].name(), "a");
ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
ASSERT_EQ(op->schema().arguments()[1].name(), "b");
- ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType);
+ ASSERT_EQ(
+ op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType);
ASSERT_EQ(op->schema().returns().size(), 1);
ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::DynamicType);
// nested arrays
auto s = parseSchema("at::what(int[][4] foo) -> ()");
ASSERT_TRUE(s.arguments().at(0).N() == 4);
- ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments().at(0)
- .type()->expect<ListType>()
+ ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments()
+ .at(0)
+ .type()
+ ->expect<ListType>()
->getElementType()
->expect<ListType>()
->getElementType()));
auto s2 = parseSchema("at::what(int[][] foo) -> ()");
- ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments().at(0)
- .type()->expect<ListType>()
- ->getElementType()
- ->expect<ListType>()
- ->getElementType()));
+ ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments()
+ .at(0)
+ .type()
+ ->expect<ListType>()
+ ->getElementType()
+ ->expect<ListType>()
+ ->getElementType()));
// named returns
parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
- auto s3 = parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
+ auto s3 =
+ parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
// futures
auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
- ASSERT_TRUE(IntType::get()->isSubtypeOf(s4.arguments().at(0)
- .type()->expect<FutureType>()
- ->getElementType()));
+ ASSERT_TRUE(IntType::get()->isSubtypeOf(
+ s4.arguments().at(0).type()->expect<FutureType>()->getElementType()));
// test tensor with annotated alias sets
parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
}
}
-
std::unique_ptr<detail::DynamicDAG<std::string>> newDynamicDAG() {
- return std::unique_ptr<detail::DynamicDAG<std::string>>(new detail::DynamicDAG<std::string>());
+ return std::unique_ptr<detail::DynamicDAG<std::string>>(
+ new detail::DynamicDAG<std::string>());
}
void testNewVertex() {
bool moveBeforeTopologicallyValid(
const std::string& toInsert,
const std::string& insertPoint) {
- std::function<bool(Node*, Node*)> func = [this](Node* toInsert,
- Node* insertPoint) {
- return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb);
- };
+ std::function<bool(Node*, Node*)> func =
+ [this](Node* toInsert, Node* insertPoint) {
+ return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb);
+ };
return moveWithChecks(toInsert, insertPoint, func);
}
bool moveAfterTopologicallyValid(
const std::string& toInsert,
const std::string& insertPoint) {
- std::function<bool(Node*, Node*)> func = [this](Node* toInsert,
- Node* insertPoint) {
- return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb);
- };
+ std::function<bool(Node*, Node*)> func =
+ [this](Node* toInsert, Node* insertPoint) {
+ return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb);
+ };
return moveWithChecks(toInsert, insertPoint, func);
}
Only files that are in CLANG_FORMAT_WHITELIST are checked.
"""
import subprocess
-import glob
-import itertools
import os
import argparse
+import fnmatch
import difflib
import sys
+import re
-# Whitelist of files to check. Takes a glob syntax. Does not support
-# recursive globs ("**") because I am lazy and don't want to make that
-# work with Python 2.
-CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/passes/alias_analysis*"]
+# Whitelist of directories to check. All files in those directories
+# (recursively) will be checked.
+CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/", "test/cpp/jit/"]
+
+CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp)$")
def parse_args():
"Otherwise, just print the changes and exit"
),
)
+ parser.add_argument(
+ "--check-all",
+ action="store_true",
+ default=False,
+ help="If true, check all whitelisted files instead of just working copy changes",
+ )
parser.add_argument("--verbose", "-v", action="store_true", default=False)
return parser.parse_args()
def get_whitelisted_files():
"""
- Parse CLANG_FORMAT_WHITELIST and resolve all globs.
- Returns the set of all whitelisted filenames.
+ Parse CLANG_FORMAT_WHITELIST and resolve all directories.
+    Returns the set of whitelisted cpp source files.
"""
- paths = [glob.glob(entry) for entry in CLANG_FORMAT_WHITELIST]
- # flatten the files list
- paths = itertools.chain(*paths)
- # filter out directories
- filenames = filter(lambda path: os.path.isfile(path), paths)
- return set(filenames)
+ matches = []
+    for directory in CLANG_FORMAT_WHITELIST:
+        for root, dirnames, filenames in os.walk(directory):
+ for filename in filenames:
+ if CPP_FILE_REGEX.fullmatch(filename):
+ matches.append(os.path.join(root, filename))
+ return set(matches)
def get_changed_files(rev):
def main():
args = parse_args()
- changed_files = get_changed_files(args.diff)
whitelisted_files = get_whitelisted_files()
- files_to_check = changed_files & whitelisted_files
+ if args.check_all:
+ files_to_check = whitelisted_files
+ else:
+ changed_files = get_changed_files(args.diff)
+ files_to_check = changed_files & whitelisted_files
if args.verbose:
print("Running clang-format on whitelisted files: ")
args = ["clang-format", "-i"]
args.extend(name_to_diffs.keys())
subprocess.check_output(args)
+
+ # add the changes so they will be committed
+ args = ["git", "add"]
+ args.extend(name_to_diffs.keys())
+ subprocess.check_output(args)
else:
print("ERROR: Running clang-format created changes: ")
for name, diff in name_to_diffs.items():
#include <ATen/core/alias_info.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
using ::c10::AliasInfo;
#pragma once
-#include <iostream>
-#include <vector>
#include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/utils/hash.h>
+#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/stack.h>
#include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/variable_tensor_list.h>
+#include <torch/csrc/utils/hash.h>
+#include <iostream>
+#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-// GraphExecutor creates specializations of Graphs for different dimensionalitities
-// and types of inputs.
+// GraphExecutor creates specializations of Graphs for different
+// dimensionalities and types of inputs.
-inline static at::Device ConvertIntToCPUOrCUDA(int device){
+inline static at::Device ConvertIntToCPUOrCUDA(int device) {
return device < 0 ? at::kCPU : at::Device(at::DeviceType::CUDA, device);
}
struct ArgumentInfo {
int device() const {
return device_;
}
- // XXX: It is guaranteed that this will return false when called on non-tensor arguments
+ // XXX: It is guaranteed that this will return false when called on non-tensor
+ // arguments
bool requires_grad() const {
return requires_grad_;
}
return TensorType::create(type(), ConvertIntToCPUOrCUDA(device()), dim());
}
-private:
+ private:
unsigned is_tensor_ : 1;
unsigned defined_ : 1;
unsigned requires_grad_ : 1;
unsigned : 5;
unsigned dim_ : 8;
- int device_ : 8; // NOTE: this needs to be signed because we use -1 to represent CPU
+ int device_ : 8; // NOTE: this needs to be signed because we use -1 to
+ // represent CPU
unsigned type_ : 8;
};
-static_assert(std::is_pod<ArgumentInfo>::value,
- "ArgumentInfo is to be a POD struct");
-static_assert(sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
- "ArgumentInfo is expected to be a 32-bit struct");
+static_assert(
+ std::is_pod<ArgumentInfo>::value,
+ "ArgumentInfo is to be a POD struct");
+static_assert(
+ sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
+ "ArgumentInfo is expected to be a 32-bit struct");
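// A minimal standalone sketch of the packing idea above (names hypothetical):
// flags and small integers share one 32-bit word, so an entire spec can be
// hashed and compared as raw bytes. std::memcpy is the well-defined way to
// read the raw representation; casting the struct pointer would break strict
// aliasing.
struct PackedInfoSketch {
  unsigned is_tensor : 1;
  unsigned defined : 1;
  unsigned requires_grad : 1;
  unsigned : 5; // padding bits, zeroed up front so hashing is deterministic
  unsigned dim : 8;
  int device : 8; // signed so that -1 can encode CPU
  unsigned type : 8;
};
static_assert(
    sizeof(PackedInfoSketch) == sizeof(uint32_t),
    "sketch mirrors the 32-bit layout above");
inline uint32_t rawBits(const PackedInfoSketch& info) {
  uint32_t raw;
  std::memcpy(&raw, &info, sizeof(raw));
  return raw;
}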
struct ArgumentSpec {
- ArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs, size_t num_flat_inputs) {
+ ArgumentSpec(
+ bool with_grad,
+ at::ArrayRef<IValue> inputs,
+ size_t num_flat_inputs) {
hash_code = num_flat_inputs;
args.resize(num_flat_inputs);
size_t offset = 0;
}
void addInput(const IValue& input, size_t& offset, bool with_grad) {
- auto & arg = args.at(offset);
+ auto& arg = args.at(offset);
// Initialize all fields to 0. This is convenient, because e.g.
    // requires_grad() can be checked even on non-tensor inputs AND will make
// padding bits all 0s.
combineHash(arg);
offset++;
} else if (input.isTuple()) {
- for (const IValue & elem : input.toTuple()->elements()) {
+ for (const IValue& elem : input.toTuple()->elements()) {
addInput(elem, offset, with_grad);
}
} else {
- // NB: no need to set is_tensor to false, because we memset the struct to 0 above
+ // NB: no need to set is_tensor to false, because we memset the struct to
+ // 0 above
combineHash(arg);
offset++;
}
}
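  // Flattening sketch (inputs hypothetical): for inputs (t0, (t1, 42), 7),
  // addInput recurses into the tuple and visits t0, t1, 42, 7 in order, so
  // num_flat_inputs is 4 and the entries fill offsets 0..3. Non-tensor
  // entries keep the zeroed ArgumentInfo, which is why requires_grad() is
  // guaranteed to be false for them.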
- void combineHash(const ArgumentInfo &arg) {
+ void combineHash(const ArgumentInfo& arg) {
ArgumentInfo::plain_data_type arg_data;
std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo));
hash_code = hash_combine(hash_code, arg_data);
// equality is fast: check ninputs, and then check the raw array data;
// there are no size/stride indirections
- bool operator==(const ArgumentSpec & spec) const {
- if (args.size() != spec.args.size()) return false;
- // NB: we need to break out early when there are no elements, because passing a
- // nullptr to memcmp is UB.
- if (args.size() == 0) return true;
- return std::memcmp(args.data(), spec.args.data(), args.size() * sizeof(ArgumentInfo)) == 0;
- }
- bool operator!=(const ArgumentSpec & spec) const {
+ bool operator==(const ArgumentSpec& spec) const {
+ if (args.size() != spec.args.size())
+ return false;
+ // NB: we need to break out early when there are no elements, because
+ // passing a nullptr to memcmp is UB.
+ if (args.size() == 0)
+ return true;
+ return std::memcmp(
+ args.data(),
+ spec.args.data(),
+ args.size() * sizeof(ArgumentInfo)) == 0;
+ }
+ bool operator!=(const ArgumentSpec& spec) const {
return !(*this == spec);
}
size_t size() const {
// inferred for it based on this ArgumentSpec.
std::vector<TypePtr> getTypes(Graph& graph) const {
size_t offset = 0;
- return fmap(graph.inputs(),
- [&](Value *v) { return fillType(v->type(), offset); });
+ return fmap(
+ graph.inputs(), [&](Value* v) { return fillType(v->type(), offset); });
}
-private:
+ private:
TypePtr fillType(TypePtr original, size_t& offset) const {
if (original->isSubtypeOf(DynamicType::get())) {
- auto & arg = args.at(offset++);
+ auto& arg = args.at(offset++);
if (!arg.defined())
return UndefinedTensorType::get();
- return TensorType::create(arg.type(), ConvertIntToCPUOrCUDA(arg.device()), arg.dim(), arg.requires_grad());
+ return TensorType::create(
+ arg.type(),
+ ConvertIntToCPUOrCUDA(arg.device()),
+ arg.dim(),
+ arg.requires_grad());
} else if (auto tuple_type = original->cast<TupleType>()) {
- return TupleType::create(fmap(tuple_type->elements(), [&](const TypePtr& subtype) {
- return fillType(subtype, offset);
- }));
+ return TupleType::create(fmap(
+ tuple_type->elements(),
+ [&](const TypePtr& subtype) { return fillType(subtype, offset); }));
} else {
offset++;
return original;
unsigned defined : 1;
unsigned requires_grad : 1;
signed device : 14;
- uint32_t total_dims; // all TensorInfoPODs are in CompleteArgumentSpec's tensor_info() array.
- // total_dims is the total number of dimensions seen so far
- // in all previous members of tensor_info(), including this tensor
- // 2*total_dims becomes the offset into the sizes_strides list
- // for the _next_ tensor in the tensor_info array
- // for tensor 0, the offset is always 0
+  uint32_t total_dims; // all TensorInfoPODs are in CompleteArgumentSpec's
+                       // tensor_info() array. total_dims is the total number
+                       // of dimensions seen so far in all previous members of
+                       // tensor_info(), including this tensor. 2*total_dims
+                       // becomes the offset into the sizes_strides list for
+                       // the _next_ tensor in the tensor_info array. For
+                       // tensor 0, the offset is always 0.
};
-static_assert(sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t),
- "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work");
+static_assert(
+ sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t),
+ "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work");
struct CompleteArgumentInfo;
struct CompleteArgumentSpec {
CompleteArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs)
- : hash_code(0), ninputs(inputs.size()) {
+ : hash_code(0), ninputs(inputs.size()) {
int32_t all_dims = 0;
const int32_t num_inputs = inputs.size();
for (int32_t i = 0; i < num_inputs; i++) {
- if (!inputs[i].isTensor()) continue;
+ if (!inputs[i].isTensor())
+ continue;
auto tensor = inputs[i].toTensor();
all_dims += tensor.defined() ? tensor.ndimension() : 0;
}
// allocate enough room for all TensorPODs and dimensions
- data.resize(ninputs + all_dims*2);
+ data.resize(ninputs + all_dims * 2);
// and reinterpret our data array as these structs
auto* pods = reinterpret_cast<CompleteArgumentInfoPOD*>(data.data());
- int64_t * next_dim = sizes_strides();
+ int64_t* next_dim = sizes_strides();
int32_t total_dims = 0;
- for(int32_t i = 0; i < num_inputs; i++) {
- auto & pod = pods[i];
+ for (int32_t i = 0; i < num_inputs; i++) {
+ auto& pod = pods[i];
pod.is_tensor = static_cast<uint32_t>(inputs[i].isTensor());
if (pod.is_tensor) {
at::Tensor t = inputs[i].toTensor();
if (pod.defined) {
pod.type = static_cast<int>(t.type().scalarType());
pod.device = (!t.is_cuda()) ? -1 : t.get_device();
- pod.requires_grad = with_grad && autograd::as_variable_ref(t).requires_grad();
+ pod.requires_grad =
+ with_grad && autograd::as_variable_ref(t).requires_grad();
total_dims += t.ndimension();
auto sizes = t.sizes();
- std::copy(sizes.begin(),sizes.end(), next_dim);
+ std::copy(sizes.begin(), sizes.end(), next_dim);
next_dim += sizes.size();
auto strides = t.strides();
std::copy(strides.begin(), strides.end(), next_dim);
// we precompute the hash_code to minimize the time inside of hash
// table operations where we may need to hold a compiler cache lock.
hash_code = hash_combine(0, ninputs);
- for(auto d : data) {
+ for (auto d : data) {
hash_code = hash_combine(hash_code, d);
}
}
// equality is fast: check ninputs, and then check the raw array data;
// there are no size/stride indirections
- bool operator==(const CompleteArgumentSpec & spec) const {
+ bool operator==(const CompleteArgumentSpec& spec) const {
return ninputs == spec.ninputs && data == spec.data;
}
- bool operator!=(const CompleteArgumentSpec & spec) const {
+ bool operator!=(const CompleteArgumentSpec& spec) const {
return !(*this == spec);
}
friend struct CompleteArgumentInfo;
return hash_code;
}
-private:
+ private:
ArrayRef<CompleteArgumentInfoPOD> tensor_info() const {
return ArrayRef<CompleteArgumentInfoPOD>(
- reinterpret_cast<const CompleteArgumentInfoPOD*>(data.data()), ninputs);
+ reinterpret_cast<const CompleteArgumentInfoPOD*>(data.data()), ninputs);
}
- // the start of the sizes_strides information, which comes after the CompleteArgumentInfoPOD list.
+ // the start of the sizes_strides information, which comes after the
+ // CompleteArgumentInfoPOD list.
const int64_t* sizes_strides() const {
return data.data() + ninputs;
}
}
size_t hash_code; // precomputed on construction
int32_t ninputs;
- // layout is ninputs of TensorPOD (each 64-bit) followed by their size and stride info
- // for 3 tensors: [t0POD][t1POD][t2POD][t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides]
+ // layout is ninputs of TensorPOD (each 64-bit) followed by their size and
+ // stride info for 3 tensors:
+ // [t0POD][t1POD][t2POD]...
+ // [t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides]
std::vector<int64_t> data;
};
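// A worked example of the layout above (dimensions hypothetical): for three
// tensors with 2, 3 and 0 dimensions, the running sums are
//   pod[0].total_dims = 2, pod[1].total_dims = 5, pod[2].total_dims = 5
// and the sizes/strides of tensor j start at int64_t index
//   offset(0) = 0;  offset(j) = 2 * pod[j - 1].total_dims
// so t0 occupies entries [0, 4), t1's sizes begin at sizes_strides()[4], and
// the zero-dimensional t2 contributes no size/stride entries at all.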
// public view of compressed CompleteArgumentSpec
struct CompleteArgumentInfo {
- CompleteArgumentInfo(const CompleteArgumentSpec & spec, const int i)
- : spec(spec), i(i) {}
+ CompleteArgumentInfo(const CompleteArgumentSpec& spec, const int i)
+ : spec(spec), i(i) {}
bool isTensor() const {
return pod(i).is_tensor;
}
}
int ndimension() const {
// See [valid range], it is always valid to ask for offset for (i + 1)
- return (sizes_strides_offset(i + 1) - sizes_strides_offset(i))/2;
+ return (sizes_strides_offset(i + 1) - sizes_strides_offset(i)) / 2;
}
at::IntList sizes() const {
- return at::IntList(spec.sizes_strides() + sizes_strides_offset(i), ndimension());
+ return at::IntList(
+ spec.sizes_strides() + sizes_strides_offset(i), ndimension());
}
at::IntList strides() const {
int ndim = ndimension();
- return at::IntList(spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim);
+ return at::IntList(
+ spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim);
}
operator TypePtr() const {
- if(!defined())
+ if (!defined())
return DynamicType::get();
- return CompleteTensorType::create(type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides());
+ return CompleteTensorType::create(
+ type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides());
}
-private:
+
+ private:
// offset into sizes_strides() array where the sizes start for tensor j
// [valid range] valid range is [0, ninputs]
- // (i.e. you can ask for the offset at ninputs, which would be the offset of the next tensor if it existed)
+ // (i.e. you can ask for the offset at ninputs, which would be the offset of
+ // the next tensor if it existed)
int sizes_strides_offset(int j) const {
- if(j == 0) return 0;
- return 2*pod(j - 1).total_dims;
+ if (j == 0)
+ return 0;
+ return 2 * pod(j - 1).total_dims;
}
- const CompleteArgumentInfoPOD & pod(int j) const {
+ const CompleteArgumentInfoPOD& pod(int j) const {
return spec.tensor_info().at(j);
}
- const CompleteArgumentSpec & spec;
+ const CompleteArgumentSpec& spec;
const int i;
};
-inline std::ostream & operator<<(std::ostream & out, const ArgumentInfo & info) {
- if(!info.defined()) {
+inline std::ostream& operator<<(std::ostream& out, const ArgumentInfo& info) {
+ if (!info.defined()) {
return out << "<undefined>";
}
- out << "Tensor(device=" << info.device()
- << ", type=" << toString(info.type())
- << ", requires_grad=" << info.requires_grad()
- << ", dims=" << info.dim() << ")";
+ out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
+ << ", requires_grad=" << info.requires_grad() << ", dims=" << info.dim()
+ << ")";
return out;
}
-inline std::ostream& operator<<(std::ostream & out, const ArgumentSpec & spec) {
+inline std::ostream& operator<<(std::ostream& out, const ArgumentSpec& spec) {
out << "{";
- for(size_t i = 0; i < spec.size(); ++i) {
+ for (size_t i = 0; i < spec.size(); ++i) {
if (i > 0)
out << ", ";
out << spec.at(i);
return out;
}
-inline std::ostream & operator<<(std::ostream & out, const CompleteArgumentInfo & info) {
- if(!info.defined()) {
+inline std::ostream& operator<<(
+ std::ostream& out,
+ const CompleteArgumentInfo& info) {
+ if (!info.defined()) {
return out << "<undefined>";
}
- out << "Tensor(device=" << info.device()
- << ", type=" << toString(info.type())
- << ", requires_grad=" << info.requires_grad()
- << ", sizes=" << info.sizes()
- << ", strides=" << info.strides() << ")";
+ out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
+ << ", requires_grad=" << info.requires_grad()
+ << ", sizes=" << info.sizes() << ", strides=" << info.strides() << ")";
return out;
}
-inline std::ostream& operator<<(std::ostream & out, const CompleteArgumentSpec & spec) {
+inline std::ostream& operator<<(
+ std::ostream& out,
+ const CompleteArgumentSpec& spec) {
out << "{";
- for(size_t i = 0; i < spec.size(); ++i) {
+ for (size_t i = 0; i < spec.size(); ++i) {
if (i > 0)
out << ", ";
out << spec.at(i);
}
}
-}}
+} // namespace jit
+} // namespace torch
namespace std {
- template<>
- struct hash<torch::jit::ArgumentSpec> {
- size_t operator()(const torch::jit::ArgumentSpec & spec) const {
- return spec.hashCode();
- }
- };
- template<>
- struct hash<torch::jit::CompleteArgumentSpec> {
- size_t operator()(const torch::jit::CompleteArgumentSpec & spec) const {
- return spec.hashCode();
- }
- };
-}
+template <>
+struct hash<torch::jit::ArgumentSpec> {
+ size_t operator()(const torch::jit::ArgumentSpec& spec) const {
+ return spec.hashCode();
+ }
+};
+template <>
+struct hash<torch::jit::CompleteArgumentSpec> {
+ size_t operator()(const torch::jit::CompleteArgumentSpec& spec) const {
+ return spec.hashCode();
+ }
+};
+} // namespace std
#pragma once
-#include <vector>
+#include <ATen/ATen.h>
+#include <ATen/Utils.h>
#include <cstdint>
-#include <string>
#include <memory>
+#include <string>
#include <vector>
-#include <ATen/ATen.h>
-#include <ATen/Utils.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/interned_strings.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
constexpr int max_tensor_display_size = 10;
-enum class AttributeKind {
- f,fs,i,is,s,ss,t,ts,g,gs
-};
-static inline const char * toString(AttributeKind kind) {
- static const char* names[] = {"f","fs","i","is","s","ss","t","ts","g","gs"};
- JIT_ASSERT(size_t(kind) < sizeof(names)/sizeof(AttributeKind));
+enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs };
+static inline const char* toString(AttributeKind kind) {
+ static const char* names[] = {
+ "f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs"};
+ JIT_ASSERT(size_t(kind) < sizeof(names) / sizeof(AttributeKind));
return names[int(kind)];
}
struct AttributeValue {
- AttributeValue(Symbol name)
- : name(name) {}
+ AttributeValue(Symbol name) : name(name) {}
using Ptr = std::unique_ptr<AttributeValue>;
Symbol name;
virtual AttributeKind kind() const = 0;
virtual ~AttributeValue() = default;
};
-template<typename T, AttributeKind Kind>
+template <typename T, AttributeKind Kind>
struct ScalarAttributeValue : public AttributeValue {
using ConstructorType = T;
using ValueType = T;
ScalarAttributeValue(Symbol name, ConstructorType value_)
- : AttributeValue(name), value_(std::move(value_)) {}
- ValueType & value() {
+ : AttributeValue(name), value_(std::move(value_)) {}
+ ValueType& value() {
return value_;
}
Ptr clone() const override {
return Ptr(new ScalarAttributeValue(name, value_));
}
- AttributeKind kind() const override { return Kind; }
-private:
+ AttributeKind kind() const override {
+ return Kind;
+ }
+
+ private:
ValueType value_;
};
-template<typename T, AttributeKind Kind>
+template <typename T, AttributeKind Kind>
struct VectorAttributeValue : public AttributeValue {
using ConstructorType = std::vector<T>;
using ValueType = std::vector<T>;
VectorAttributeValue(Symbol name, ConstructorType value_)
- : AttributeValue(name), value_(std::move(value_)) {}
- ValueType & value() {
+ : AttributeValue(name), value_(std::move(value_)) {}
+ ValueType& value() {
return value_;
}
- AttributeKind kind() const override { return Kind; }
+ AttributeKind kind() const override {
+ return Kind;
+ }
std::unique_ptr<AttributeValue> clone() const override {
auto copy = value_;
return Ptr(new VectorAttributeValue(name, std::move(copy)));
}
-private:
+
+ private:
ValueType value_;
};
-using FloatAttr = ScalarAttributeValue<double,AttributeKind::f>;
-using FloatsAttr = VectorAttributeValue<double,AttributeKind::fs>;
-using IntAttr = ScalarAttributeValue<int64_t,AttributeKind::i>;
-using IntsAttr = VectorAttributeValue<int64_t,AttributeKind::is>;
-using StringAttr = ScalarAttributeValue<std::string,AttributeKind::s>;
-using StringsAttr = VectorAttributeValue<std::string,AttributeKind::ss>;
-using TensorAttr = ScalarAttributeValue<at::Tensor,AttributeKind::t>;
-using TensorsAttr = VectorAttributeValue<at::Tensor,AttributeKind::ts>;
+using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
+using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
+using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
+using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
+using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
+using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
+using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
+using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
struct Graph;
-using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>,AttributeKind::g>;
-using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>,AttributeKind::gs>;
+using GraphAttr =
+ ScalarAttributeValue<std::shared_ptr<Graph>, AttributeKind::g>;
+using GraphsAttr =
+ VectorAttributeValue<std::shared_ptr<Graph>, AttributeKind::gs>;
struct AttributeError : public std::exception {
AttributeError(Symbol name, bool defined) {
std::stringstream ss;
- if(!defined) {
- ss << "required keyword attribute '" << name.toUnqualString() << "' is undefined.";
+ if (!defined) {
+ ss << "required keyword attribute '" << name.toUnqualString()
+ << "' is undefined.";
} else {
- ss << "required keyword attribute '" << name.toUnqualString() << "' has the wrong type";
+ ss << "required keyword attribute '" << name.toUnqualString()
+ << "' has the wrong type";
}
msg = ss.str();
}
- const char* what() const noexcept override {
+ const char* what() const noexcept override {
return msg.c_str();
}
-private:
+
+ private:
std::string msg;
};
// method chaining, e.g.:
// Node * n = g->create(kSelect)->i_(kOffset,3)->f_(kValue,3.5);
// we return Derived* pointers because Nodes are normally held as pointers.
-template<typename Derived>
+template <typename Derived>
struct Attributes {
Attributes() = default;
- void copyAttributes(const Attributes & rhs) {
+ void copyAttributes(const Attributes& rhs) {
values_.clear();
- for(auto & i : rhs.values_) {
+ for (auto& i : rhs.values_) {
values_.push_back(i->clone());
}
}
bool hasAttribute(Symbol name) const {
JIT_ASSERT(name.is_attr());
- return find(name,false) != values_.end();
+ return find(name, false) != values_.end();
}
// We want direct string accessors, as it is nicer to use than
// hasAttribute(Symbol::attr("blah"))
}
AttributeKind kindOf(Symbol name) const {
JIT_ASSERT(name.is_attr());
- return (*find(name,true))->kind();
+ return (*find(name, true))->kind();
}
AttributeKind kindOfS(const std::string& name) const {
return kindOf(Symbol::attr(name));
}
Derived* removeAttribute(Symbol name) {
JIT_ASSERT(name.is_attr());
- values_.erase(find(name,true));
+ values_.erase(find(name, true));
return This();
}
Derived* removeAttributeS(const std::string& name) {
// The names are returned in order, since name actually is the index.
std::vector<Symbol> attributeNames() const {
std::vector<Symbol> names;
- for(auto & a : values_)
+ for (auto& a : values_)
names.push_back(a->name);
return names;
}
std::vector<const char*> attributeNamesS() const {
std::vector<const char*> names;
- for(auto & a : values_)
+ for (auto& a : values_)
names.push_back(a->name.toUnqualString());
return names;
}
- #define CREATE_ACCESSOR(Kind, method) \
+#define CREATE_ACCESSOR(Kind, method) \
Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
- return set<Kind##Attr>(name,std::forward<Kind##Attr::ConstructorType>(v)); \
- } \
- const Kind##Attr::ValueType& method(Symbol name) const { \
- return get<Kind##Attr>(name); \
+ return set<Kind##Attr>( \
+ name, std::forward<Kind##Attr::ConstructorType>(v)); \
+ } \
+ const Kind##Attr::ValueType& method(Symbol name) const { \
+ return get<Kind##Attr>(name); \
}
- CREATE_ACCESSOR(Float,f)
- CREATE_ACCESSOR(Floats,fs)
- CREATE_ACCESSOR(String,s)
- CREATE_ACCESSOR(Strings,ss)
- CREATE_ACCESSOR(Int,i)
- CREATE_ACCESSOR(Ints,is)
- CREATE_ACCESSOR(Graph,g)
- CREATE_ACCESSOR(Graphs,gs)
+ CREATE_ACCESSOR(Float, f)
+ CREATE_ACCESSOR(Floats, fs)
+ CREATE_ACCESSOR(String, s)
+ CREATE_ACCESSOR(Strings, ss)
+ CREATE_ACCESSOR(Int, i)
+ CREATE_ACCESSOR(Ints, is)
+ CREATE_ACCESSOR(Graph, g)
+ CREATE_ACCESSOR(Graphs, gs)
- #undef CREATE_ACCESSOR
+#undef CREATE_ACCESSOR
// Our Graphs are not very const-correct, so we need to allow returning
// non-const references too
// does not use CREATE_ACCESSOR because we need additional asserts
Derived* t_(Symbol name, TensorAttr::ConstructorType v) {
JIT_ASSERT(!v.defined() || !v.is_variable());
- return set<TensorAttr>(name,std::forward<TensorAttr::ConstructorType>(v));
+ return set<TensorAttr>(name, std::forward<TensorAttr::ConstructorType>(v));
}
const TensorAttr::ValueType& t(Symbol name) const {
return get<TensorAttr>(name);
}
Derived* ts_(Symbol name, TensorsAttr::ConstructorType v) {
- for(auto & t : v) {
+ for (auto& t : v) {
JIT_ASSERT(!t.defined() || !t.is_variable());
}
- return set<TensorsAttr>(name,std::forward<TensorsAttr::ConstructorType>(v));
+ return set<TensorsAttr>(
+ name, std::forward<TensorsAttr::ConstructorType>(v));
}
const TensorsAttr::ValueType& ts(Symbol name) const {
return get<TensorsAttr>(name);
}
- template<typename T>
- static void printPrimList(std::ostream & out, const std::vector<T> & items) {
+ template <typename T>
+ static void printPrimList(std::ostream& out, const std::vector<T>& items) {
out << "[";
int i = 0;
- for(auto & item : items) {
- if(i++ > 0)
+ for (auto& item : items) {
+ if (i++ > 0)
out << ", ";
out << item;
}
std::vector<std::string> replace = {"\\n", "\\t", "\\v"};
for (size_t i = 0; i < search.size(); i++) {
size_t pos = s.find(search[i]);
- while(pos != std::string::npos) {
+ while (pos != std::string::npos) {
s.replace(pos, 1, replace[i]);
pos = s.find(search[i], pos + 1);
}
return s;
}
- void printValue(std::ostream & out, const Symbol & name) const {
- switch(kindOf(name)) {
+ void printValue(std::ostream& out, const Symbol& name) const {
+ switch (kindOf(name)) {
case AttributeKind::f:
out << f(name);
break;
out << "\"" << escapeString(s(name)) << "\"";
break;
case AttributeKind::ss:
- printPrimList(out,ss(name));
+ printPrimList(out, ss(name));
break;
- case AttributeKind::t:
- {
- at::Tensor tensor = t(name);
- // 1-elem tensors are usually boxed scalars, so print them like it
- if (tensor.numel() == 1) {
- auto scalar_tensor = tensor.view({}).item();
- out << "{";
- if (scalar_tensor.isFloatingPoint()) {
- out << scalar_tensor.toDouble();
- } else {
- out << scalar_tensor.toLong();
- }
- out << "}";
- } else if (tensor.numel() <= max_tensor_display_size) {
- // TODO: This is awful code. Also it doesn't work on Windows.
- std::ostringstream tensor_ss;
- tensor_ss << tensor;
- std::string tensor_s{tensor_ss.str()};
- // Remove newlines
- std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
- out << tensor_s;
+ case AttributeKind::t: {
+ at::Tensor tensor = t(name);
+        // 1-elem tensors are usually boxed scalars, so print them that way
+ if (tensor.numel() == 1) {
+ auto scalar_tensor = tensor.view({}).item();
+ out << "{";
+ if (scalar_tensor.isFloatingPoint()) {
+ out << scalar_tensor.toDouble();
} else {
- out << "<Tensor>";
+ out << scalar_tensor.toLong();
}
- break;
+ out << "}";
+ } else if (tensor.numel() <= max_tensor_display_size) {
+ // TODO: This is awful code. Also it doesn't work on Windows.
+ std::ostringstream tensor_ss;
+ tensor_ss << tensor;
+ std::string tensor_s{tensor_ss.str()};
+ // Remove newlines
+ std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
+ out << tensor_s;
+ } else {
+ out << "<Tensor>";
}
+ break;
+ }
case AttributeKind::ts:
out << "[<Tensors>]";
break;
}
}
-private:
+ private:
// UBSAN error: https://github.com/pytorch/pytorch/issues/9055
Derived* This() __ubsan_ignore_vptr__ {
return static_cast<Derived*>(this);
}
- template<typename T>
+ template <typename T>
Derived* set(Symbol name, typename T::ConstructorType v) {
JIT_ASSERT(name.is_attr());
auto it = find(name, false);
auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
- if(it == values_.end()) {
+ if (it == values_.end()) {
values_.push_back(std::move(nv));
} else {
*it = std::move(nv);
}
return This();
}
- template<typename T>
- typename T::ValueType & get(Symbol name) const {
+ template <typename T>
+ typename T::ValueType& get(Symbol name) const {
JIT_ASSERT(name.is_attr());
auto it = find(name, true);
auto* child = dynamic_cast<T*>(it->get());
- if(child == nullptr) {
+ if (child == nullptr) {
throw AttributeError(name, true);
}
return child->value();
using iterator = std::vector<AVPtr>::iterator;
iterator find(Symbol name, bool required) {
JIT_ASSERT(name.is_attr());
- auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
+ auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
return v->name == name;
});
- if(required && it == values_.end()) {
+ if (required && it == values_.end()) {
throw AttributeError(name, false);
}
JIT_ASSERT(!required || it != values_.end());
using const_iterator = std::vector<AVPtr>::const_iterator;
const_iterator find(Symbol name, bool required) const {
JIT_ASSERT(name.is_attr());
- auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
+ auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
return v->name == name;
});
- if(required && it == values_.end()) {
+ if (required && it == values_.end()) {
throw AttributeError(name, false);
}
JIT_ASSERT(!required || it != values_.end());
}
};
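// A minimal sketch of the chaining pattern above (CRTP, names hypothetical):
// setters return Derived* so attribute writes compose left to right, just as
// in the create(kSelect)->i_(kOffset, 3)->f_(kValue, 3.5) example.
template <typename D>
struct ChainSketch {
  D* name_(std::string v) {
    name = std::move(v);
    return static_cast<D*>(this); // CRTP downcast preserves the derived type
  }
  std::string name;
};
struct WidgetSketch : ChainSketch<WidgetSketch> {
  WidgetSketch* size_(int v) {
    size = v;
    return this;
  }
  int size = 0;
};
// usage: WidgetSketch w; w.name_("w")->size_(3);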
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/autodiff.h>
-#include "torch/csrc/jit/passes/lower_tuples.h"
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
-#include "torch/csrc/jit/symbolic_script.h"
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/operator.h>
#include <torch/csrc/utils/functional.h>
+#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/script/compiler.h"
+#include "torch/csrc/jit/symbolic_script.h"
#include <torch/csrc/jit/assertions.h>
#include <algorithm>
#include <memory>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
using value_map = std::unordered_map<Value*, Value*>;
using value_set = std::unordered_set<Value*>;
-void wrapDim(int64_t & dim, const std::vector<int64_t> & sizes) {
+void wrapDim(int64_t& dim, const std::vector<int64_t>& sizes) {
if (dim < 0) {
dim += sizes.size();
}
}
-bool isDifferentiable(Node * n) {
+bool isDifferentiable(Node* n) {
// TODO: scalar-tensor ops should be canonicalized
static OperatorSet differentiable_ops = {
- "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
- "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
- "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
- "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
- "aten::mul(Tensor self, Tensor other) -> Tensor",
- "aten::mul(Tensor self, Scalar other) -> Tensor",
- "aten::div(Tensor self, Tensor other) -> Tensor",
- "aten::div(Tensor self, Scalar other) -> Tensor",
- "aten::max(Tensor self, Tensor other) -> Tensor",
- "aten::min(Tensor self, Tensor other) -> Tensor",
- "aten::sigmoid(Tensor self) -> Tensor",
- "aten::tanh(Tensor self) -> Tensor",
- "aten::relu(Tensor self) -> Tensor",
- "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
- "aten::erf(Tensor self) -> Tensor",
- "aten::erfc(Tensor self) -> Tensor",
- "aten::exp(Tensor self) -> Tensor",
- "aten::t(Tensor self) -> Tensor",
- "aten::neg(Tensor self) -> Tensor",
- "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
- "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
- "aten::type_as(Tensor self, Tensor other) -> Tensor",
- "aten::unsqueeze(Tensor self, int dim) -> Tensor",
- "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
- "aten::mm(Tensor self, Tensor mat2) -> Tensor",
- "aten::lt(Tensor self, Tensor other) -> Tensor",
- "aten::le(Tensor self, Tensor other) -> Tensor",
- "aten::gt(Tensor self, Tensor other) -> Tensor",
- "aten::ge(Tensor self, Tensor other) -> Tensor",
- "aten::eq(Tensor self, Tensor other) -> Tensor",
- "aten::ne(Tensor self, Tensor other) -> Tensor",
- "aten::lt(Tensor self, Scalar other) -> Tensor",
- "aten::le(Tensor self, Scalar other) -> Tensor",
- "aten::gt(Tensor self, Scalar other) -> Tensor",
- "aten::ge(Tensor self, Scalar other) -> Tensor",
- "aten::eq(Tensor self, Scalar other) -> Tensor",
- "aten::ne(Tensor self, Scalar other) -> Tensor",
- "aten::abs(Tensor self) -> Tensor",
- "aten::acos(Tensor self) -> Tensor",
- "aten::asin(Tensor self) -> Tensor",
- "aten::atan(Tensor self) -> Tensor",
- "aten::ceil(Tensor self) -> Tensor",
- "aten::cos(Tensor self) -> Tensor",
- "aten::cosh(Tensor self) -> Tensor",
- "aten::exp(Tensor self) -> Tensor",
- "aten::expm1(Tensor self) -> Tensor",
- "aten::floor(Tensor self) -> Tensor",
- "aten::fmod(Tensor self, Scalar other) -> Tensor",
- "aten::frac(Tensor self) -> Tensor",
- "aten::log(Tensor self) -> Tensor",
- "aten::log10(Tensor self) -> Tensor",
- "aten::log1p(Tensor self) -> Tensor",
- "aten::log2(Tensor self) -> Tensor",
- "aten::reciprocal(Tensor self) -> Tensor",
- "aten::remainder(Tensor self, Scalar other) -> Tensor",
- "aten::round(Tensor self) -> Tensor",
- "aten::rsqrt(Tensor self) -> Tensor",
- "aten::sin(Tensor self) -> Tensor",
- "aten::sinh(Tensor self) -> Tensor",
- "aten::tan(Tensor self) -> Tensor",
- "aten::trunc(Tensor self) -> Tensor",
- "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)",
- "aten::log_softmax(Tensor self, int dim) -> Tensor",
- "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
- "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
- "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
- "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
+ "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+ "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+ "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+ "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+ "aten::mul(Tensor self, Tensor other) -> Tensor",
+ "aten::mul(Tensor self, Scalar other) -> Tensor",
+ "aten::div(Tensor self, Tensor other) -> Tensor",
+ "aten::div(Tensor self, Scalar other) -> Tensor",
+ "aten::max(Tensor self, Tensor other) -> Tensor",
+ "aten::min(Tensor self, Tensor other) -> Tensor",
+ "aten::sigmoid(Tensor self) -> Tensor",
+ "aten::tanh(Tensor self) -> Tensor",
+ "aten::relu(Tensor self) -> Tensor",
+ "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
+ "aten::erf(Tensor self) -> Tensor",
+ "aten::erfc(Tensor self) -> Tensor",
+ "aten::exp(Tensor self) -> Tensor",
+ "aten::t(Tensor self) -> Tensor",
+ "aten::neg(Tensor self) -> Tensor",
+ "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
+ "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
+ "aten::type_as(Tensor self, Tensor other) -> Tensor",
+ "aten::unsqueeze(Tensor self, int dim) -> Tensor",
+ "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
+ "aten::mm(Tensor self, Tensor mat2) -> Tensor",
+ "aten::lt(Tensor self, Tensor other) -> Tensor",
+ "aten::le(Tensor self, Tensor other) -> Tensor",
+ "aten::gt(Tensor self, Tensor other) -> Tensor",
+ "aten::ge(Tensor self, Tensor other) -> Tensor",
+ "aten::eq(Tensor self, Tensor other) -> Tensor",
+ "aten::ne(Tensor self, Tensor other) -> Tensor",
+ "aten::lt(Tensor self, Scalar other) -> Tensor",
+ "aten::le(Tensor self, Scalar other) -> Tensor",
+ "aten::gt(Tensor self, Scalar other) -> Tensor",
+ "aten::ge(Tensor self, Scalar other) -> Tensor",
+ "aten::eq(Tensor self, Scalar other) -> Tensor",
+ "aten::ne(Tensor self, Scalar other) -> Tensor",
+ "aten::abs(Tensor self) -> Tensor",
+ "aten::acos(Tensor self) -> Tensor",
+ "aten::asin(Tensor self) -> Tensor",
+ "aten::atan(Tensor self) -> Tensor",
+ "aten::ceil(Tensor self) -> Tensor",
+ "aten::cos(Tensor self) -> Tensor",
+ "aten::cosh(Tensor self) -> Tensor",
+ "aten::exp(Tensor self) -> Tensor",
+ "aten::expm1(Tensor self) -> Tensor",
+ "aten::floor(Tensor self) -> Tensor",
+ "aten::fmod(Tensor self, Scalar other) -> Tensor",
+ "aten::frac(Tensor self) -> Tensor",
+ "aten::log(Tensor self) -> Tensor",
+ "aten::log10(Tensor self) -> Tensor",
+ "aten::log1p(Tensor self) -> Tensor",
+ "aten::log2(Tensor self) -> Tensor",
+ "aten::reciprocal(Tensor self) -> Tensor",
+ "aten::remainder(Tensor self, Scalar other) -> Tensor",
+ "aten::round(Tensor self) -> Tensor",
+ "aten::rsqrt(Tensor self) -> Tensor",
+ "aten::sin(Tensor self) -> Tensor",
+ "aten::sinh(Tensor self) -> Tensor",
+ "aten::tan(Tensor self) -> Tensor",
+ "aten::trunc(Tensor self) -> Tensor",
+ "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)",
+ "aten::log_softmax(Tensor self, int dim) -> Tensor",
+ "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+ "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
+ "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
+ "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
};
// TODO: add support for the following fusible operators.
- // They're a little tricky to implement; max/min require mutability for best perf
- // "aten::atan2(Tensor self) -> Tensor",
- // "aten::max(Tensor self) -> Tensor",
- // "aten::min(Tensor self) -> Tensor"
-
- if (n->kind() == prim::Constant ||
- n->kind() == prim::Undefined ||
- n->kind() == prim::AutogradAdd ||
- n->kind() == prim::ConstantChunk ||
+  // They're a little tricky to implement; max/min require mutability for best
+  // perf.
+  //   "aten::atan2(Tensor self) -> Tensor",
+  //   "aten::max(Tensor self) -> Tensor",
+  //   "aten::min(Tensor self) -> Tensor"
+
+ if (n->kind() == prim::Constant || n->kind() == prim::Undefined ||
+ n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk ||
n->kind() == prim::None)
return true;
if (differentiable_ops.find(n))
return true;
}
- if (n->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
- return n->get<std::vector<int64_t>>(attr::size) && n->is_constant(attr::implicit) &&
- n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
+ if (n->matches(
+ "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
+ return n->get<std::vector<int64_t>>(attr::size) &&
+ n->is_constant(attr::implicit) &&
+ n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
}
if (n->matches("aten::view(Tensor self, int[] size) -> Tensor")) {
return n->get<std::vector<int64_t>>(attr::size) &&
- n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
+ n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
}
- if (n->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+ if (n->matches(
+ "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
// TODO(asuhan): support weight
return n->namedInput(attr::weight)->node()->kind() == prim::Undefined;
}
return false;
}
-
-bool isDifferentiable(Graph & g) {
- return std::all_of(g.nodes().begin(), g.nodes().end(),
- static_cast<bool(*)(Node*)>(isDifferentiable));
+bool isDifferentiable(Graph& g) {
+ return std::all_of(
+ g.nodes().begin(),
+ g.nodes().end(),
+ static_cast<bool (*)(Node*)>(isDifferentiable));
}
// NB: Write gradient using torchscript
//
// Here ctx is a tuple that carries all input/intermediate results needed in
// backward from forward pass.
-// This python code is compiled into a GradientPair which includes a forward graph
-// and a backward graph. Forward graph will be used to replace the node in grad_desc.f,
-// and backward graph will be used to construct GradOf(node) in reverse_block.
-// Grad_values(a.k.a gradOutputs) propagated through node->owningGraph() in
-// **reversed** order, thus GradientPair.forward ahould be inserted **after**
-// the node being replaced, so that we don't traverse the graph infinite times.
+//
+// This python code is compiled into a GradientPair which includes a forward
+// graph and a backward graph. The forward graph is used to replace the node in
+// grad_desc.f, and the backward graph is used to construct GradOf(node) in
+// reverse_block. Grad_values (a.k.a. gradOutputs) are propagated through
+// node->owningGraph() in **reversed** order, so GradientPair.forward should
+// be inserted **after** the node being replaced, to avoid traversing the
+// graph an infinite number of times.
+//
// The output of compiled forward graph is [real_outputs, ctx]
// The input of compiled backward graph is [ctx, grad_values]
-// We run LowerSimpleTuples afterwards to elmininate all tuples generated in this process.
-// The original node and TupleConstruct nodes in forward graph will be cleaned up
-// later using EliminateDeadCode(block).
-// TupleUnPack node in backward graph will be removed in eliminateDeadcode(ReverseDetails)
-// defined in this file.
+// We run LowerSimpleTuples afterwards to eliminate all tuples generated in
+// this process. The original node and the TupleConstruct nodes in the forward
+// graph will be cleaned up later using EliminateDeadCode(block). The
+// TupleUnpack node in the backward graph will be removed in
+// eliminateDeadcode(ReverseDetails), defined in this file.
static c10::optional<std::vector<Value*>> build_script_grad(
- Node* node,
- const ArrayRef<Value*>& grads) {
+ Node* node,
+ const ArrayRef<Value*>& grads) {
auto graph = node->owningGraph();
auto compiled_graphs = gradientInfoForSchema(node->schema());
{
WithInsertPoint guard(node->next());
auto fw_graph = compiled_graphs->forward;
- new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
+ new_outputs = inlineCallTo(
+ *graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
for (size_t i = 0; i < node->outputs().size(); ++i) {
new_outputs.at(i)->setType(node->outputs()[i]->type());
new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
auto it = grad_vec.begin();
grad_vec.insert(it, new_outputs.back());
ArrayRef<Value*> grad(grad_vec);
- auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
+ auto grad_inputs =
+ inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
return grad_inputs;
};
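// The contract above in miniature (plain C++ rather than torchscript, names
// hypothetical): forward returns (real_outputs, ctx) and backward consumes
// (ctx, grad_outputs), so the only state shared between the two graphs is
// whatever forward packs into ctx.
//
//   std::pair<double, double> tanh_forward(double x) {
//     double y = std::tanh(x);
//     return {y, /*ctx=*/y}; // ctx: all that backward needs is y itself
//   }
//   double tanh_backward(double ctx, double grad_output) {
//     return grad_output * (1.0 - ctx * ctx); // d/dx tanh(x) = 1 - tanh(x)^2
//   }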
-static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_values) {
- static const OperatorSet comparison_ops = {
- "aten::lt(Tensor self, Tensor other) -> Tensor",
- "aten::le(Tensor self, Tensor other) -> Tensor",
- "aten::gt(Tensor self, Tensor other) -> Tensor",
- "aten::ge(Tensor self, Tensor other) -> Tensor",
- "aten::eq(Tensor self, Tensor other) -> Tensor",
- "aten::ne(Tensor self, Tensor other) -> Tensor",
- "aten::lt(Tensor self, Scalar other) -> Tensor",
- "aten::le(Tensor self, Scalar other) -> Tensor",
- "aten::gt(Tensor self, Scalar other) -> Tensor",
- "aten::ge(Tensor self, Scalar other) -> Tensor",
- "aten::eq(Tensor self, Scalar other) -> Tensor",
- "aten::ne(Tensor self, Scalar other) -> Tensor",
- };
- const auto sumToSizeOf = [node](SymbolicVariable v, Symbol input_name) -> SymbolicVariable {
- Value * size;
+namespace {
+class GradientHelper {
+ public:
+ GradientHelper(Node* n) : node(n) {}
+
+ std::vector<Value*> gradient(ArrayRef<Value*> grad_values) {
+ if (!isDifferentiable(node)) {
+ throw std::runtime_error(
+ std::string("differentiation of ") + node->kind().toDisplayString() +
+ " is not supported, or it is missing necessary type information");
+ }
+ // If AD is defined using torchscript, use it instead of the symbolic
+ // gradient below.
+ auto script_grads = build_script_grad(node, grad_values);
+ if (script_grads)
+ return *script_grads;
+ // Definition not found in torchscript; fall back to buildSymbolicGradient.
+ // TODO: migrate all to using torchscript
+ auto sym_grads = buildSymbolicGradient(fmap<SymbolicVariable>(grad_values));
+ return fmap(sym_grads, [](const SymbolicVariable& v) { return v.value(); });
+ }
+
+ private:
+ Node* node;
+
+ SymbolicVariable sumToSizeOf(SymbolicVariable v, Symbol input_name) {
+ Value* size;
{
- WithInsertPoint insert_guard {node};
+ WithInsertPoint insert_guard{node};
size = SymbolicVariable(node->namedInput(input_name)).size();
}
return v.sumToSize(size);
};
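// Why this reduction is needed (illustrative shapes only): if `other` of
// shape {3} was broadcast against `self` of shape {2, 3} in the forward,
// then grad_output has shape {2, 3} and the gradient w.r.t. `other` must be
// summed over the broadcast dimension:
//   at::Tensor go = torch::ones({2, 3});   // grad_output
//   at::Tensor grad_other = go.sum(0);     // shape {3}, matches `other`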
- const auto build_sym_grad = [node, &sumToSizeOf](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
+
+ std::vector<SymbolicVariable> buildSymbolicGradient(
+ const std::vector<SymbolicVariable>& grads) {
+ static const OperatorSet comparison_ops = {
+ "aten::lt(Tensor self, Tensor other) -> Tensor",
+ "aten::le(Tensor self, Tensor other) -> Tensor",
+ "aten::gt(Tensor self, Tensor other) -> Tensor",
+ "aten::ge(Tensor self, Tensor other) -> Tensor",
+ "aten::eq(Tensor self, Tensor other) -> Tensor",
+ "aten::ne(Tensor self, Tensor other) -> Tensor",
+ "aten::lt(Tensor self, Scalar other) -> Tensor",
+ "aten::le(Tensor self, Scalar other) -> Tensor",
+ "aten::gt(Tensor self, Scalar other) -> Tensor",
+ "aten::ge(Tensor self, Scalar other) -> Tensor",
+ "aten::eq(Tensor self, Scalar other) -> Tensor",
+ "aten::ne(Tensor self, Scalar other) -> Tensor",
+ };
auto inputs = fmap<SymbolicVariable>(node->inputs());
auto outputs = fmap<SymbolicVariable>(node->outputs());
- if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
- return {sumToSizeOf(grads.at(0), attr::self),
- sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other),
- nullptr};
+ if (node->matches(
+ "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+ return {
+ sumToSizeOf(grads.at(0), attr::self),
+ sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other),
+ nullptr};
- } else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
+ } else if (
+ node->matches(
+ "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
return {grads.at(0), nullptr, nullptr};
} else if (node->kind() == prim::AutogradAdd) {
// NB: AutogradAdds don't broadcast
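// (both operands are gradients accumulated for the same value, so they
// already share its shape; Undef operands, which autograd uses to represent
// zeros, are handled specially)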
return {grads.at(0), grads.at(0)};
- } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+ } else if (
+ node->matches(
+ "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
return {sumToSizeOf(grads.at(0), attr::self),
- sumToSizeOf(-grads.at(0) * node->namedInput(attr::alpha), attr::other),
+ sumToSizeOf(
+ -grads.at(0) * node->namedInput(attr::alpha), attr::other),
nullptr};
- } else if (node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
+ } else if (
+ node->matches(
+ "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
return {grads.at(0), nullptr, nullptr};
- } else if (node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::mul(Tensor self, Tensor other) -> Tensor")) {
return {sumToSizeOf(grads.at(0) * inputs.at(1), attr::self),
sumToSizeOf(grads.at(0) * inputs.at(0), attr::other)};
- } else if (node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::mul(Tensor self, Scalar other) -> Tensor")) {
return {grads.at(0) * inputs.at(1), nullptr};
- } else if (node->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::div(Tensor self, Tensor other) -> Tensor")) {
return {sumToSizeOf(grads.at(0) / inputs.at(1), attr::self),
- sumToSizeOf(-grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)), attr::other)};
+ sumToSizeOf(
+ -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)),
+ attr::other)};
- } else if (node->matches("aten::div(Tensor self, Scalar other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::div(Tensor self, Scalar other) -> Tensor")) {
return {grads.at(0) / inputs.at(1), nullptr};
- } else if (node->matches("aten::max(Tensor self, Tensor other) -> Tensor")) {
- return {sumToSizeOf(grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)), attr::self),
- sumToSizeOf(grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)), attr::other)};
-
- } else if (node->matches("aten::min(Tensor self, Tensor other) -> Tensor")) {
- return {sumToSizeOf(grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)), attr::self),
- sumToSizeOf(grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)), attr::other)};
-
- } else if (node->matches("aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::max(Tensor self, Tensor other) -> Tensor")) {
+ return {
+ sumToSizeOf(
+ grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)),
+ attr::self),
+ sumToSizeOf(
+ grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)),
+ attr::other)};
+
+ } else if (node->matches(
+ "aten::min(Tensor self, Tensor other) -> Tensor")) {
+ return {
+ sumToSizeOf(
+ grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)),
+ attr::self),
+ sumToSizeOf(
+ grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)),
+ attr::other)};
+
+ } else if (
+ node->matches(
+ "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
return {nullptr,
- sumToSizeOf(grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
- sumToSizeOf(grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)), attr::other)};
+ sumToSizeOf(
+ grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
+ sumToSizeOf(
+ grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)),
+ attr::other)};
} else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) {
// TODO: The order of operations matters in this case. This
return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))};
} else if (node->matches("aten::relu(Tensor self) -> Tensor")) {
- return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))};
+ return {grads.at(0) *
+ (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))};
- } else if (node->matches("aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
+ } else if (
+ node->matches(
+ "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
// handle the case that min/max is None
Value* min = inputs.at(1);
bool min_must_be_none = min->node()->kind() == prim::None;
Value* max = inputs.at(2);
bool max_must_be_none = max->node()->kind() == prim::None;
// XXX - this formula is wrong when min or max are not strictly prim::None
- // but may be None dynamically. In this case an internal compiler error will
- // get thrown when trying to generate expressions involving the values of min/max
+ // but may be None dynamically. In this case an internal compiler error
+ // will get thrown when trying to generate expressions involving the
+ // values of min/max
if (!min_must_be_none && !max_must_be_none) {
- return {grads.at(0)
- * (1-(inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0)))
- * (1-(inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))), nullptr, nullptr};
+ return {grads.at(0) *
+ (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))) *
+ (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
+ nullptr,
+ nullptr};
} else if (max_must_be_none) {
- return {grads.at(0)
- * (1-(inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))), nullptr, nullptr};
+ return {grads.at(0) *
+ (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))),
+ nullptr,
+ nullptr};
} else if (min_must_be_none) {
- return {grads.at(0)
- * (1-(inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))), nullptr, nullptr};
+ return {grads.at(0) *
+ (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
+ nullptr,
+ nullptr};
} else {
return {grads.at(0), nullptr, nullptr};
}
- } else if (node->matches("aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) {
+ } else if (
+ node->matches(
+ "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) {
auto threshold = node->get<at::Scalar>(attr::threshold).value();
- return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)), nullptr, nullptr};
+ return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)),
+ nullptr,
+ nullptr};
} else if (node->matches("aten::erf(Tensor self) -> Tensor")) {
- return {grads.at(0) * 1.12837916709551 * (-inputs.at(0) * inputs.at(0)).exp()};
+ return {grads.at(0) * 1.12837916709551 *
+ (-inputs.at(0) * inputs.at(0)).exp()};
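// (1.12837916709551 is 2/sqrt(pi): d/dx erf(x) = 2/sqrt(pi) * exp(-x^2))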
} else if (node->matches("aten::erfc(Tensor self) -> Tensor")) {
- return {-grads.at(0) * 1.12837916709551 * (-inputs.at(0) * inputs.at(0)).exp()};
+ return {-grads.at(0) * 1.12837916709551 *
+ (-inputs.at(0) * inputs.at(0)).exp()};
} else if (node->matches("aten::exp(Tensor self) -> Tensor")) {
return {grads.at(0) * (outputs.at(0))};
return {grads.at(0) * inputs.at(0).sign()};
} else if (node->matches("aten::acos(Tensor self) -> Tensor")) {
- return {grads.at(0) * -((-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt())};
+ return {grads.at(0) *
+ -((-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt())};
} else if (node->matches("aten::asin(Tensor self) -> Tensor")) {
- return {grads.at(0) * (-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt()};
+ return {grads.at(0) *
+ (-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt()};
} else if (node->matches("aten::atan(Tensor self) -> Tensor")) {
return {grads.at(0) / (inputs.at(0) * inputs.at(0) + at::Scalar(1))};
- } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
- Value * self_size;
+ } else if (
+ node->matches(
+ "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+ Value* self_size;
{
- WithInsertPoint insert_guard { node };
+ WithInsertPoint insert_guard{node};
self_size = inputs.at(0).size();
}
return {grads.at(0).expand(self_size), nullptr};
} else if (node->matches("aten::floor(Tensor self) -> Tensor")) {
return {SymbolicVariable::zeros_like(grads.at(0))};
- } else if (node->matches("aten::fmod(Tensor self, Scalar other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::fmod(Tensor self, Scalar other) -> Tensor")) {
return {grads.at(0), nullptr};
} else if (node->matches("aten::frac(Tensor self) -> Tensor")) {
} else if (node->matches("aten::reciprocal(Tensor self) -> Tensor")) {
return {-grads.at(0) * outputs.at(0) * outputs.at(0)};
- } else if (node->matches("aten::remainder(Tensor self, Scalar other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::remainder(Tensor self, Scalar other) -> Tensor")) {
return {grads.at(0), nullptr};
} else if (node->matches("aten::round(Tensor self) -> Tensor")) {
} else if (node->kind() == prim::ConstantChunk) {
return {SymbolicVariable::cat(grads, node->i(attr::dim))};
- } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
- node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) {
- // TODO: if sizes are not available statically, add an operator that reutrns them as a tuple
- auto sizes = node->namedInput(attr::self)->type()->expect<CompleteTensorType>()->sizes();
+ } else if (
+ node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
+ node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) {
+ // TODO: if sizes are not available statically, add an operator that
+ // reutrns them as a tuple
+ auto sizes = node->namedInput(attr::self)
+ ->type()
+ ->expect<CompleteTensorType>()
+ ->sizes();
return {grads.at(0).reshape(sizes), nullptr};
- } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
return {grads.at(0).type_as(inputs.at(0)), nullptr};
- } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
+ } else if (node->matches(
+ "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
- } else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
- return {sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self),
- grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
- inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
- nullptr, nullptr};
+ } else if (
+ node->matches(
+ "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
+ return {
+ sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self),
+ grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
+ inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
+ nullptr,
+ nullptr};
} else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
- return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))};
+ return {grads.at(0).mm(inputs.at(1).t()),
+ inputs.at(0).t().mm(grads.at(0))};
- } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
+ } else if (
+ node->matches(
+ "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
const auto& input_sizes = inputs.at(0).sizes();
if (input_sizes.size() == 0)
return {grads.at(0).sum(), nullptr, nullptr};
const auto& sizes = inputs.at(0).sizes();
std::vector<size_t> squeezed_dims;
for (size_t i = 0; i < sizes.size(); ++i) {
- if (sizes[i] != 1) continue;
+ if (sizes[i] != 1)
+ continue;
squeezed_dims.push_back(i);
}
SymbolicVariable returned_grad = grads.at(0);
}
return {returned_grad};
- } else if (node->matches("aten::squeeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) {
+ } else if (node->matches(
+ "aten::squeeze(Tensor self, int dim) -> Tensor",
+ /*const_inputs=*/attr::dim)) {
int64_t dim = *node->get<int64_t>(attr::dim);
const auto& sizes = inputs.at(0).sizes();
wrapDim(dim, sizes);
- if (sizes.size() == 0) {
+ if (sizes.size() == 0) {
return {grads.at(0), nullptr};
}
- return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim), nullptr};
+ return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim),
+ nullptr};
- } else if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor", /*const_inputs=*/attr::dim)) {
+ } else if (node->matches(
+ "aten::cat(Tensor[] tensors, int dim) -> Tensor",
+ /*const_inputs=*/attr::dim)) {
int dim = *node->get<int64_t>(attr::dim);
- auto tensor_inputs = inputs; tensor_inputs.pop_back();
+ auto tensor_inputs = inputs;
+ tensor_inputs.pop_back();
const auto& first_sizes = tensor_inputs.at(0).sizes();
const auto has_first_sizes = [&first_sizes](SymbolicVariable var) {
return var.sizes() == first_sizes;
// NB: this is a specialization for the common case where all inputs are
// of equal sizes. We can use a single split operation to handle that.
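// For example (hypothetical shapes): cat of three {2, 3} tensors along
// dim 0 yields a {6, 3} output, so chunk(grad, 3, dim) hands each input
// back its own {2, 3} slice of the gradient.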
- if (std::all_of(tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) {
+ if (std::all_of(
+ tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) {
auto tensor_grads = grads.at(0).chunk(tensor_inputs.size(), dim);
tensor_grads.emplace_back(nullptr); // for attr::dim
return tensor_grads;
} else if (comparison_ops.find(node)) {
return {nullptr, nullptr};
- } else if (node->matches("aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) {
+ } else if (
+ node->matches(
+ "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) {
JIT_ASSERT(grads.size() == 1);
auto graph = node->owningGraph();
- auto backward_value = graph->insert(aten::avg_pool2d_backward, {
- grads.at(0).value(),
- node->namedInput(attr::self),
- node->namedInput(attr::kernel_size),
- node->namedInput(attr::stride),
- node->namedInput(attr::padding),
- node->namedInput(attr::ceil_mode),
- node->namedInput(attr::count_include_pad)});
- return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr};
-
- } else if (node->matches("aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
+ auto backward_value = graph->insert(
+ aten::avg_pool2d_backward,
+ {grads.at(0).value(),
+ node->namedInput(attr::self),
+ node->namedInput(attr::kernel_size),
+ node->namedInput(attr::stride),
+ node->namedInput(attr::padding),
+ node->namedInput(attr::ceil_mode),
+ node->namedInput(attr::count_include_pad)});
+ return {backward_value->node()->output(0),
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr};
+
+ } else if (
+ node->matches(
+ "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
JIT_ASSERT(grads.size() == 2);
auto graph = node->owningGraph();
- auto backward_value = graph->insert(aten::max_pool2d_with_indices_backward, {
- grads.at(0).value(),
- node->namedInput(attr::self),
- node->namedInput(attr::kernel_size),
- node->namedInput(attr::stride),
- node->namedInput(attr::padding),
- node->namedInput(attr::dilation),
- node->namedInput(attr::ceil_mode),
- outputs.at(1).value()
- });
- return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr};
-
- } else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
+ auto backward_value = graph->insert(
+ aten::max_pool2d_with_indices_backward,
+ {grads.at(0).value(),
+ node->namedInput(attr::self),
+ node->namedInput(attr::kernel_size),
+ node->namedInput(attr::stride),
+ node->namedInput(attr::padding),
+ node->namedInput(attr::dilation),
+ node->namedInput(attr::ceil_mode),
+ outputs.at(1).value()});
+ return {backward_value->node()->output(0),
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr};
+
+ } else if (
+ node->matches(
+ "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
auto graph = node->owningGraph();
- auto backward_value = graph->insert(aten::thnn_conv2d_backward, {
- grads.at(0).value(),
- inputs.at(0).value(),
- inputs.at(1).value(),
- node->namedInput(attr::kernel_size),
- node->namedInput(attr::stride),
- node->namedInput(attr::padding),
- outputs.at(1).value(),
- outputs.at(2).value(),
- graph->insertConstant(std::vector<bool>{true, true, true})
- });
- // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again.
- Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
+ auto backward_value = graph->insert(
+ aten::thnn_conv2d_backward,
+ {grads.at(0).value(),
+ inputs.at(0).value(),
+ inputs.at(1).value(),
+ node->namedInput(attr::kernel_size),
+ node->namedInput(attr::stride),
+ node->namedInput(attr::padding),
+ outputs.at(1).value(),
+ outputs.at(2).value(),
+ graph->insertConstant(std::vector<bool>{true, true, true})});
+ // graph->insert returns a tuple automatically if multiple outputs are
+ // returned, so unpack them again.
+ Node* tuple_unpack_node =
+ graph->insertNode(graph->createTupleUnpack(backward_value));
auto tuple_outputs = tuple_unpack_node->outputs();
JIT_ASSERT(tuple_outputs.size() == size_t(3));
- return {tuple_outputs[0], tuple_outputs[1], nullptr, tuple_outputs[2], nullptr, nullptr};
+ return {tuple_outputs[0],
+ tuple_outputs[1],
+ nullptr,
+ tuple_outputs[2],
+ nullptr,
+ nullptr};
- } else if (node->matches("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
+ } else if (
+ node->matches(
+ "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
auto graph = node->owningGraph();
- auto backward_value = graph->insert(aten::native_batch_norm_backward, {
- grads.at(0).value(),
- inputs.at(0).value(),
- inputs.at(1).value(),
- inputs.at(3).value(),
- inputs.at(4).value(),
- outputs.at(1).value(),
- outputs.at(2).value(),
- inputs.at(5).value(),
- inputs.at(7).value(),
- graph->insertConstant(std::vector<bool>{true, true, true})
- });
- // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again.
- Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
+ auto backward_value = graph->insert(
+ aten::native_batch_norm_backward,
+ {grads.at(0).value(),
+ inputs.at(0).value(),
+ inputs.at(1).value(),
+ inputs.at(3).value(),
+ inputs.at(4).value(),
+ outputs.at(1).value(),
+ outputs.at(2).value(),
+ inputs.at(5).value(),
+ inputs.at(7).value(),
+ graph->insertConstant(std::vector<bool>{true, true, true})});
+ // graph->insert returns a tuple automatically if multiple outputs are
+ // returned, so unpack them again.
+ Node* tuple_unpack_node =
+ graph->insertNode(graph->createTupleUnpack(backward_value));
auto tuple_outputs = tuple_unpack_node->outputs();
JIT_ASSERT(tuple_outputs.size() == size_t(3));
- return {tuple_outputs[0], tuple_outputs[1], tuple_outputs[2], nullptr, nullptr, nullptr, nullptr, nullptr};
+ return {tuple_outputs[0],
+ tuple_outputs[1],
+ tuple_outputs[2],
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr};
- } else if (node->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+ } else if (
+ node->matches(
+ "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
auto graph = node->owningGraph();
auto total_weight = graph->insertNode(graph->createUndefined());
auto weight = graph->insertNode(graph->createUndefined());
- auto backward_value = graph->insert(aten::nll_loss_backward, {
- grads.at(0).value(),
- inputs.at(0).value(),
- inputs.at(1).value(),
- weight->output(),
- inputs.at(3).value(),
- inputs.at(4).value(),
- total_weight->output()
- });
- return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr};
-
- } else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
+ auto backward_value = graph->insert(
+ aten::nll_loss_backward,
+ {grads.at(0).value(),
+ inputs.at(0).value(),
+ inputs.at(1).value(),
+ weight->output(),
+ inputs.at(3).value(),
+ inputs.at(4).value(),
+ total_weight->output()});
+ return {backward_value->node()->output(0),
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr};
+
+ } else if (node->matches(
+ "aten::log_softmax(Tensor self, int dim) -> Tensor")) {
JIT_ASSERT(grads.size() == 1);
auto graph = node->owningGraph();
- auto backward_value = graph->insert(aten::_log_softmax_backward_data, {
- grads.at(0).value(),
- outputs.at(0).value(),
- node->namedInput(attr::dim),
- node->namedInput(attr::self)
- });
+ auto backward_value = graph->insert(
+ aten::_log_softmax_backward_data,
+ {grads.at(0).value(),
+ outputs.at(0).value(),
+ node->namedInput(attr::dim),
+ node->namedInput(attr::self)});
return {backward_value->node()->output(0), nullptr};
- } else if (node->kind() == prim::Constant || node->kind() == prim::Undefined || node->kind() == prim::None) {
+ } else if (
+ node->kind() == prim::Constant || node->kind() == prim::Undefined ||
+ node->kind() == prim::None) {
return {};
}
- throw std::runtime_error(std::string("failed to differentiate `") + node->kind().toDisplayString() + "`");
- };
- if (!isDifferentiable(node)) {
- throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " "
- "is not supported, or it is missing necessary type information");
+ throw std::runtime_error(
+ std::string("failed to differentiate `") +
+ node->kind().toDisplayString() + "`");
}
- // If AD is defined using torchscript, use it instead of symbolic
- auto script_grads = build_script_grad(node, grad_values);
- if (script_grads)
- return *script_grads;
- // Definition not found in torchscript, look up in the build_sym_grad
- // TODO: migrate all to using torchscript
- auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
- return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
-}
-
-// If we have a function y = f(x) with jacobian J, the backwards of f is dx = J^t dy.
-// Note that because the backwards always implements this matrix multiply,
-// we know that it maps an input vector of zeros to an output vector of zero
-// regardless of what operations it choses to do inside to actually implement
-// the matrix multiply (most use some optimized form and never generate J^t).
-// More generally, we know that all of the backward computations are linear and
-// can use this property to do more aggressive optimizations later.
-// It is ok to replace any backward function with known-zero inputs with something
-// that produces known-zero outputs. This function encloses each know-linear
-// backward function in a 'GradOf' sub-block so that we can perform optimizations
-// using this information. In particular, specializeUndef will observe if
-// all the inputs to the linear block are Undef, which the autograd uses to represent
-// zeros, and then propagate the undefs to the outputs of the block.
-static std::vector<Value*> linearGradientForNode(Node* node, ArrayRef<Value*> grad_values) {
- auto & graph = *node->owningGraph();
+};
+} // namespace
+
+// If we have a function y = f(x) with jacobian J, the backwards of f is dx =
+// J^t dy. Note that because the backwards always implements this matrix
+// multiply, we know that it maps an input vector of zeros to an output vector
+// of zeros regardless of what operations it chooses to do inside to actually
+// implement the matrix multiply (most use some optimized form and never
+// generate J^t). More generally, we know that all of the backward computations
+// are linear and can use this property to do more aggressive optimizations
+// later. It is ok to replace any backward function with known-zero inputs with
+// something that produces known-zero outputs. This function encloses each
+// known-linear backward function in a 'GradOf' sub-block so that we can
+// perform optimizations using this information. In particular, specializeUndef
+// will observe if all the inputs to the linear block are Undef, which the
+// autograd uses to represent zeros, and then propagate the undefs to the
+// outputs of the block.
+static std::vector<Value*> linearGradientForNode(
+ Node* node,
+ ArrayRef<Value*> grad_values) {
+ auto& graph = *node->owningGraph();
auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
// to make reading gradient graphs easier, remember the name of the forward op
linear->s_(attr::name, node->kind().toDisplayString());
auto block = linear->addBlock();
WithInsertPoint guard(block);
- auto results = gradientForNode(node, grad_values);
- return fmap(results, [block, linear](Value *grad) -> Value* {
- if (!grad) return nullptr;
+ auto results = GradientHelper(node).gradient(grad_values);
+ return fmap(results, [block, linear](Value* grad) -> Value* {
+ if (!grad)
+ return nullptr;
block->registerOutput(grad);
return linear->addOutput()->copyMetadata(grad);
});
}
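// Roughly, the emitted IR wraps each node's backward in its own block (a
// hedged sketch; the exact textual dump format may differ):
//
//   %grad_self = prim::GradOf[name="aten::mul"](%grad_output)
//     block0():
//       %r = aten::mul(%grad_output, %other)
//       -> (%r)
//
// so specializeUndef can replace the whole GradOf with Undef outputs whenever
// every input to it is Undef.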
struct ReverseDetails {
- ReverseDetails(value_map&& grad_map, Block * reverse_block)
- : grad_map(std::move(grad_map))
- , reverse_block(reverse_block) {}
+ ReverseDetails(value_map&& grad_map, Block* reverse_block)
+ : grad_map(std::move(grad_map)), reverse_block(reverse_block) {}
value_map grad_map;
- Block * reverse_block;
+ Block* reverse_block;
};
// AutogradAdd is a special addition function that handles Undef
// - grad_desc has df_input_vjps and df_output_vjps set
// (but df_input_vjps will be modified later as well)
static ReverseDetails addReverseInline(Gradient& grad_desc) {
- auto & graph = *grad_desc.f;
+ auto& graph = *grad_desc.f;
// note: reverse_node is intentionally not inserted to avoid
// accidentally acting on it (e.g. in eliminate dead code);
// use std::cout << *reverse_node to view its state.
}
return it->second;
};
- const auto set_grad = [&](Value *x, Value *dx) {
- if (Value * prev_grad = grad_map[x]) {
+ const auto set_grad = [&](Value* x, Value* dx) {
+ if (Value* prev_grad = grad_map[x]) {
grad_map[x] = createAutogradAdd(prev_grad, dx);
} else {
grad_map[x] = dx;
auto outputs = graph.outputs();
for (size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) {
- Value * output = outputs[i];
+ Value* output = outputs[i];
if (!output->requires_grad())
continue;
- Value * output_grad = reverse_block->addInput()->setType(output->type());
+ Value* output_grad = reverse_block->addInput()->setType(output->type());
set_grad(output, output_grad);
grad_desc.df_input_vjps.push_back(i);
}
- for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end; ++it) {
- Node *node = *it;
+ for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end;
+ ++it) {
+ Node* node = *it;
auto inputs = node->inputs();
auto outputs = node->outputs();
- if (std::all_of(outputs.begin(), outputs.end(), [](Value *v) { return !v->requires_grad(); })) {
+ if (std::all_of(outputs.begin(), outputs.end(), [](Value* v) {
+ return !v->requires_grad();
+ })) {
continue;
}
- value_list grad_inputs = linearGradientForNode(node, fmap(node->outputs(), get_grad));
+ value_list grad_inputs =
+ linearGradientForNode(node, fmap(node->outputs(), get_grad));
LowerSimpleTuples(reverse_block);
JIT_ASSERT(grad_inputs.size() == node->inputs().size());
for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
- if (!inputs[i]->requires_grad()) continue;
- // NB: Not returning a gradient w.r.t. a value that requires grad is normal if the
- // input is non-differentiable. This happens e.g. in the aten::type_as case.
- if (!grad_inputs[i]) continue;
+ if (!inputs[i]->requires_grad())
+ continue;
+ // NB: Not returning a gradient w.r.t. a value that requires grad is
+ // normal if the input is non-differentiable. This happens e.g. in the
+ // aten::type_as case.
+ if (!grad_inputs[i])
+ continue;
set_grad(inputs[i], grad_inputs[i]);
}
}
auto inputs = graph.inputs();
for (size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
- Value * input = inputs[i];
+ Value* input = inputs[i];
if (!input->requires_grad())
continue;
- // NB: Not having a gradient defined w.r.t. an input to the graph which requires grad
- // can happen and is not an error. It might have been used only in non-differentiable
- // contexts (e.g. as second input to aten::type_as). In that case we simply ignore it
- // as an output, because it won't ever produce any meaningful values.
- if (grad_map.count(input) == 0) continue;
+ // NB: Not having a gradient defined w.r.t. an input to the graph which
+ // requires grad can happen and is not an error. It might have been used
+ // only in non-differentiable contexts (e.g. as second input to
+ // aten::type_as). In that case we simply ignore it as an output, because it
+ // won't ever produce any meaningful values.
+ if (grad_map.count(input) == 0)
+ continue;
reverse_block->registerOutput(get_grad(input));
grad_desc.df_output_vjps.push_back(i);
}
return ReverseDetails(std::move(grad_map), reverse_block);
}
-// Returns a topologically-sorted list of values produced in f, and used in its reverse program.
+// Returns a topologically-sorted list of values produced in f, and used in its
+// reverse program.
static value_list getReverseCaptures(Gradient& grad_desc) {
- auto & graph = *grad_desc.f;
+ auto& graph = *grad_desc.f;
auto primal_block = graph.block();
value_set reverse_captures_set;
value_list reverse_captures; // Invariant: topo sorted
- auto check_uses = [&](Value *v) {
+ auto check_uses = [&](Value* v) {
for (auto use : v->uses()) {
if (use.user->owningBlock() == primal_block)
continue;
}
}
};
- for (Value * input : graph.inputs()) {
+ for (Value* input : graph.inputs()) {
check_uses(input);
}
- for (Node * node : graph.nodes()) {
- for (Value * output : node->outputs())
+ for (Node* node : graph.nodes()) {
+ for (Value* output : node->outputs())
check_uses(output);
}
return reverse_captures;
}
-// Any temporary value from the primal graphs needs to be captured for later use in the
-// reverse graph, to avoid costly recomputations. However, a lot of the nodes we have
-// in our graphs are simply constants, which are cheap to execute and replicate, and so
-// it's better to just copy them into the reverse graph, without polluting the output
-// lists unnecessarily.
+// Any temporary value from the primal graphs needs to be captured for later use
+// in the reverse graph, to avoid costly recomputations. However, a lot of the
+// nodes we have in our graphs are simply constants, which are cheap to execute
+// and replicate, and so it's better to just copy them into the reverse graph,
+// without polluting the output lists unnecessarily.
static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) {
static const auto err = [](Value*) -> Value* {
throw std::runtime_error("unexpected input");
};
- auto & graph = *grad_desc.f;
+ auto& graph = *grad_desc.f;
Block* reverse_block = rev_info.reverse_block;
- for (Node *top_node : reverse_block->nodes()) {
- JIT_ASSERT(top_node->kind() == prim::GradOf ||
- top_node->kind() == prim::AutogradAdd ||
- top_node->kind() == prim::Undefined);
- if (top_node->kind() != prim::GradOf) continue;
- Block * grad_body = top_node->blocks().at(0);
- for (Node *node : grad_body->nodes()) {
- for (Value * input : node->inputs()) {
- if (input->node()->kind() != prim::Constant) continue;
- if (input->node()->owningBlock() == grad_body) continue;
- Node *lifted_constant = graph.createClone(input->node(), err);
+ for (Node* top_node : reverse_block->nodes()) {
+ JIT_ASSERT(
+ top_node->kind() == prim::GradOf ||
+ top_node->kind() == prim::AutogradAdd ||
+ top_node->kind() == prim::Undefined);
+ if (top_node->kind() != prim::GradOf)
+ continue;
+ Block* grad_body = top_node->blocks().at(0);
+ for (Node* node : grad_body->nodes()) {
+ for (Value* input : node->inputs()) {
+ if (input->node()->kind() != prim::Constant)
+ continue;
+ if (input->node()->owningBlock() == grad_body)
+ continue;
+ Node* lifted_constant = graph.createClone(input->node(), err);
reverse_block->prependNode(lifted_constant);
node->replaceInputWith(input, lifted_constant->output());
}
}
}
-static void deduplicateSizeCaptures(Gradient& grad_desc, ReverseDetails& rev_info) {
- Block * primal_block = grad_desc.f->block();
- const auto usedOnlyInReverse = [primal_block](Value * v) {
- const auto & uses = v->uses();
- return std::all_of(uses.begin(), uses.end(),
- [primal_block](const Use& u) { return u.user->owningBlock() != primal_block; });
+static void deduplicateSizeCaptures(
+ Gradient& grad_desc,
+ ReverseDetails& rev_info) {
+ Block* primal_block = grad_desc.f->block();
+ const auto usedOnlyInReverse = [primal_block](Value* v) {
+ const auto& uses = v->uses();
+ return std::all_of(uses.begin(), uses.end(), [primal_block](const Use& u) {
+ return u.user->owningBlock() != primal_block;
+ });
};
auto captures = getReverseCaptures(grad_desc);
- value_set capture_set (captures.begin(), captures.end());
- for (Value * capture : captures) {
- Node * node = capture->node();
+ value_set capture_set(captures.begin(), captures.end());
+ for (Value* capture : captures) {
+ Node* node = capture->node();
if (!node->matches("aten::size(Tensor self) -> int[]")) {
continue;
}
if (usedOnlyInReverse(capture) && capture_set.count(node->input())) {
- WithInsertPoint insert_guard { *rev_info.reverse_block->nodes().begin() };
+ WithInsertPoint insert_guard{*rev_info.reverse_block->nodes().begin()};
capture->replaceAllUsesWith(SymbolicVariable(node->input()).size());
node->destroy();
}
}
static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
- // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x, s1), s2),
- // which are equivalent to SumToSize(x, s2), and could save us some captures, but I'm
- // not 100% sure how to optimize this at this stage, since we don't know which
- // GradOf blocks will be stitched together to form the derivative. I guess a smart
- // analysis could implement this, but I didn't have time before the 1.0 release,
- // so I put this only as a peephole optimization.
+ // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x,
+ // s1), s2), which are equivalent to SumToSize(x, s2), and could save us some
+ // captures, but I'm not 100% sure how to optimize this at this stage, since
+ // we don't know which GradOf blocks will be stitched together to form the
+ // derivative. I guess a smart analysis could implement this, but I didn't
+ // have time before the 1.0 release, so I put this only as a peephole
+ // optimization.
liftConstants(grad_desc, rev_info);
// We generally add a lot of aten::size calls (for derivatives of broadcasting
- // operators), and they often end up duplicated, and would get captured multiple
- // times. Make sure we deduplicate them before lifting.
+ // operators), and they often end up duplicated, and would get captured
+ // multiple times. Make sure we deduplicate them before lifting.
EliminateCommonSubexpression(grad_desc.f);
deduplicateSizeCaptures(grad_desc, rev_info);
eliminateDeadCode(rev_info);
// All intermediates needed in the second stage are added to
// outputs of f, and taken as inputs in df. For a more
// detailed description see Note [Gradient graphs] in autodiff.h.
-// This function also initializes the fields in grad_desc that were undefined after
-// `addReverseInline` (and extends `df_input_vjps` with vjps for captured temporaries).
+// This function also initializes the fields in grad_desc that were undefined
+// after `addReverseInline` (and extends `df_input_vjps` with vjps for captured
+// temporaries).
static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
- auto & graph = *grad_desc.f;
+ auto& graph = *grad_desc.f;
auto primal_block = graph.block();
auto reverse_block = rev_info.reverse_block;
std::unordered_map<Value*, size_t> orig_primal_outputs_idx;
std::unordered_map<Value*, size_t> orig_primal_inputs_idx;
- // NOTE: we use emplace to avoid replacing an existing index if an output is repeated
+ // NOTE: we use emplace to avoid replacing an existing index if an output is
+ // repeated
for (size_t i = 0, num_outputs = graph.outputs().size(); i < num_outputs; ++i)
orig_primal_outputs_idx.emplace(graph.outputs()[i], i);
for (size_t i = 0, num_inputs = graph.inputs().size(); i < num_inputs; ++i)
orig_primal_inputs_idx[graph.inputs()[i]] = i;
// NB: reverse_captures are already deduplicated, and in topo order
- for (Value * capture_val : reverse_captures) {
+ for (Value* capture_val : reverse_captures) {
// If it's already an output we don't have to add anything,
// but register the fact that it needs to be captured.
if (orig_primal_outputs_idx.count(capture_val) > 0) {
- grad_desc.df_input_captured_outputs.push_back(orig_primal_outputs_idx[capture_val]);
- // If it's an input, we could add it as an output but in fact it's
- // more efficient to use a special kind of capture.
+ grad_desc.df_input_captured_outputs.push_back(
+ orig_primal_outputs_idx[capture_val]);
+ // If it's an input, we could add it as an output but in fact it's
+ // more efficient to use a special kind of capture.
} else if (orig_primal_inputs_idx.count(capture_val) > 0) {
- grad_desc.df_input_captured_inputs.push_back(orig_primal_inputs_idx.at(capture_val));
- // Otherwise it's just a regular intermediate value that we need to add as an output
+ grad_desc.df_input_captured_inputs.push_back(
+ orig_primal_inputs_idx.at(capture_val));
+ // Otherwise it's just a regular intermediate value that we need to add as
+ // an output
} else {
- // we need to create a new temporary output for this capture because it wasn't availiable.
+ // we need to create a new temporary output for this capture because it
+ // wasn't available.
graph.registerOutput(capture_val);
- grad_desc.df_input_captured_outputs.emplace_back(graph.outputs().size() - 1);
+ grad_desc.df_input_captured_outputs.emplace_back(
+ graph.outputs().size() - 1);
}
}
// -- Add VJPs for temporaries, adjust df_input_vjps -------------------------
- // NB [possible optimization]: use the newly added vjp input as soon as the first
- // vjp for that value is generated, to reduce the lifespan of this input
+ // NB [possible optimization]: use the newly added vjp input as soon as the
+ // first vjp for that value is generated, to reduce the lifespan of this input
// (currently we add it to the final vjp after all adds).
for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) {
- Value * tmp = graph.outputs().at(i);
+ Value* tmp = graph.outputs().at(i);
// Add VJP inputs only for intermediates that actually required grad.
- // Note that we check the contents of the grad_map instead of tmp->requires_grad(),
- // becuase it's actually a more faithful source. tmp->requires_grad() is really an
- // overapproximation (i.e. it can have false positives), while the gradients we will
- // emit for this value can get DCE-d in the optimization pass (because it has no
- // influence on the real f's outputs that we differentiate).
- if (rev_info.grad_map.count(tmp) == 0) continue;
- Value * tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
- Value * tmp_vjp_prev = rev_info.grad_map.at(tmp);
- // This is quite weird because we can't first make a sum and then replace all uses
- // of tmp_vjp_prev (that would replace its use in the sum too!), so we create an
- // incorrect sum that doesn't use prev vjp, replace uses, and fix the sum.
- Value * new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
+ // Note that we check the contents of the grad_map instead of
+ // tmp->requires_grad(), because it's actually a more faithful source.
+ // tmp->requires_grad() is really an overapproximation (i.e. it can have
+ // false positives), while the gradients we will emit for this value can get
+ // DCE-d in the optimization pass (because it has no influence on the real
+ // f's outputs that we differentiate).
+ if (rev_info.grad_map.count(tmp) == 0)
+ continue;
+ Value* tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
+ Value* tmp_vjp_prev = rev_info.grad_map.at(tmp);
+ // This is quite weird because we can't first make a sum and then replace
+ // all uses of tmp_vjp_prev (that would replace its use in the sum too!), so
+ // we create an incorrect sum that doesn't use prev vjp, replace uses, and
+ // fix the sum.
+ Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
new_vjp->node()->moveAfter(tmp_vjp_prev->node());
tmp_vjp_prev->replaceAllUsesWith(new_vjp);
new_vjp->node()->replaceInput(1, tmp_vjp_prev);
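// Concretely: new_vjp starts as AutogradAdd(tmp_vjp_in, tmp_vjp_in); every
// consumer of tmp_vjp_prev is rewired to new_vjp (safe, because new_vjp does
// not yet mention tmp_vjp_prev); then the second operand is patched so that
// new_vjp = AutogradAdd(tmp_vjp_in, tmp_vjp_prev).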
// construct a map from captured 'value' to the index in the input list
// used to extract this block into its own function
std::unordered_map<Value*, size_t> capture_to_formal_index;
- const auto & add_capture = [&](Value * captured) {
+ const auto& add_capture = [&](Value* captured) {
capture_to_formal_index[captured] = reverse_block->inputs().size();
reverse_block->addInput()->copyMetadata(captured);
};
- for(auto & offset : grad_desc.df_input_captured_inputs)
+ for (auto& offset : grad_desc.df_input_captured_inputs)
add_capture(graph.inputs()[offset]);
- for(auto & offset : grad_desc.df_input_captured_outputs)
+ for (auto& offset : grad_desc.df_input_captured_outputs)
add_capture(graph.outputs()[offset]);
grad_desc.df = std::make_shared<Graph>();
reverse_block->owningNode()->destroy();
}
-
Gradient differentiate(std::shared_ptr<Graph>& graph) {
Gradient grad_desc;
// Take ownership of the graph
- JIT_ASSERTM(graph.use_count() == 1,
- "differentiate will mutate and destroy the graph, so it requires "
- "graph.use_count() == 1, but found %d", graph.use_count());
+ JIT_ASSERTM(
+ graph.use_count() == 1,
+ "differentiate will mutate and destroy the graph, so it requires "
+ "graph.use_count() == 1, but found %d",
+ graph.use_count());
std::swap(graph, grad_desc.f);
// XXX: Take care when handling outputs - they can be duplicated!
return grad_desc;
}
-}}
+} // namespace jit
+} // namespace torch
#include <ATen/ATen.h>
-#include <vector>
#include <memory>
+#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
using value_list = std::vector<Value*>;
+// clang-format off
// Example showcasing how Gradient is constructed:
//
// Let's assume we have a function f, `m` and `n` do not require grad
// df_output_vjps = {0} // i.e. connect next_edge[0] of grad_fn to x's (grad_fn, output_nr).
//
// Terminology: vjp = vector-jacobian product
+// clang-format on
struct Gradient {
explicit operator bool() const {
// Describes how to construct outputs of f from what its graph will return.
// This is necessary because some trailing outputs are intermediates produced
// only to be saved for df (and should be ignored).
- size_t f_real_outputs = 0; // initialized for safety.
+ size_t f_real_outputs = 0; // initialized for safety.
- // df inputs are split into two sections: vjps (aka grad_outputs) and captures.
- // VJPs are "seeds" for the gradient computation given for each input capture
- // of an Output kind.
- // Captures are values the need to be saved when f is run. We handle inputs
- // specially, because this allows us to avoid adding extra vjps as df inputs.
+ // df inputs are split into two sections: vjps (aka grad_outputs) and
+ // captures. VJPs are "seeds" for the gradient computation given for each
+ // input capture of an Output kind. Captures are values that need to be saved
+ // when f is run. We handle inputs specially, because this allows us to avoid
+ // adding extra vjps as df inputs.
std::vector<size_t> df_input_vjps; // Offsets into f's outputs.
// capture can come from inputs or outputs
std::vector<size_t> df_input_captured_inputs; // Offsets into f's inputs
std::vector<size_t> df_input_captured_outputs; // Offsets into f's outputs
-
// df will produce vjps for a subset of inputs of f that required grad.
- // df_output_vjps[idx] == inp_idx means that idx-th output of df produces a vjp
- // for inp_idx-th input of f.
+ // df_output_vjps[idx] == inp_idx means that idx-th output of df produces a
+ // vjp for inp_idx-th input of f.
std::vector<size_t> df_output_vjps; // Offsets into f's inputs.
// How to use gradient to implement a differentiable autograd function:
// - Use df_output_vjps to connect next_edges of grad_fn:
// for idx in df_output_vjps:
// grad_fn.add_next_edge(inputs[idx].gradient_edge())
- // - Save captures for df (care needs to be taken to use SavedVariables for inputs and
- // outputs that we will actually return)
+ // - Save captures for df (care needs to be taken to use SavedVariables for
+ // inputs and outputs that we will actually return)
// - Return outputs[:f_real_outputs]
//
// When running df:
TORCH_API Gradient differentiate(std::shared_ptr<Graph>& graph);
// can we take a derivative of this node symbolically?
-TORCH_API bool isDifferentiable(Node * n);
-TORCH_API bool isDifferentiable(Graph & g);
-TORCH_API bool isZero(Value * v);
+TORCH_API bool isDifferentiable(Node* n);
+TORCH_API bool isDifferentiable(Graph& g);
+TORCH_API bool isZero(Value* v);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/batched/BatchTensor.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims){
- if(data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1){
- throw std::runtime_error("malformed MaskedBatch with data.dim(): "
- + std::to_string(data.dim()) + ", mask.dim(): " + std::to_string(mask.dim())
- + ", dims.size(0): " + std::to_string(dims.size(0)));
+BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims) {
+ if (data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1) {
+ throw std::runtime_error(
+ "malformed MaskedBatch with data.dim(): " + std::to_string(data.dim()) +
+ ", mask.dim(): " + std::to_string(mask.dim()) +
+ ", dims.size(0): " + std::to_string(dims.size(0)));
}
this->data = std::move(data);
this->mask = std::move(mask);
this->dims = std::move(dims);
}
-BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size){
+BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size) {
dims = at::empty(data.dim(), data.options().dtype(at::kByte));
dims.fill_(0);
std::vector<int64_t> sizes(data.dim() + 1, -1);
mask.fill_(1);
}
-BatchTensor::BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dims) {
+BatchTensor::BatchTensor(
+ const std::vector<at::Tensor>& datalist,
+ at::Tensor dims) {
auto bs = datalist.size();
- std::vector<int64_t> sizes(dims.size(0) + 1, 0), mask_sizes(dims.size(0) + 1, 0);
+ std::vector<int64_t> sizes(dims.size(0) + 1, 0),
+ mask_sizes(dims.size(0) + 1, 0);
sizes[0] = bs;
mask_sizes[0] = bs;
- for(int64_t i = 1; i < dims.size(0) + 1; i++){
- for(const auto& x : datalist){
+ for (int64_t i = 1; i < dims.size(0) + 1; i++) {
+ for (const auto& x : datalist) {
sizes[i] = std::max(sizes[i], x.size(i));
}
mask_sizes[i] = *dims[i - 1].data<uint8_t>() ? sizes[i] : 1;
data.fill_(0);
mask = at::empty(mask_sizes, datalist[0].options().dtype(at::kByte));
mask.fill_(0);
- for(std::size_t i = 0; i < datalist.size(); i++){
+ for (std::size_t i = 0; i < datalist.size(); i++) {
auto data_item = data.narrow(0, i, 1);
auto mask_item = mask.narrow(0, i, 1);
- for(int64_t j = 0; j < dims.size(0); j++){
- if(*dims[j].data<uint8_t>()){
+ for (int64_t j = 0; j < dims.size(0); j++) {
+ if (*dims[j].data<uint8_t>()) {
data_item = data_item.narrow(j + 1, 0, datalist[i].size(j + 1));
mask_item = mask_item.narrow(j + 1, 0, datalist[i].size(j + 1));
}
std::vector<at::Tensor> BatchTensor::examples() {
std::vector<at::Tensor> result;
// calculate number of valid entries in dth dimension of data
- auto mask_sum = [](at::Tensor data, int d) -> int64_t{
+ auto mask_sum = [](at::Tensor data, int d) -> int64_t {
data = data.sum(d, /*keepdim=*/true);
- while(data.dim() >= 1)
+ while (data.dim() >= 1)
data = data[0];
return *data.data<int64_t>();
};
- for(int64_t i = 0; i < data.size(0); i++){
+ for (int64_t i = 0; i < data.size(0); i++) {
auto data_tmp = data.narrow(0, i, 1);
- for(int64_t d = 0; d < dims.size(0); d++){
- if(*dims[d].data<uint8_t>()){
+ for (int64_t d = 0; d < dims.size(0); d++) {
+ if (*dims[d].data<uint8_t>()) {
data_tmp = data_tmp.narrow(d + 1, 0, mask_sum(mask[i], d));
}
}
.def("get_dims", &BatchTensor::get_dims);
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#pragma once
+#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <torch/csrc/jit/pybind.h>
-#include <ATen/ATen.h>
#include <iostream>
#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct BatchTensor {
-public:
+ public:
BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims);
// expand a tensor to a batchtensor given batch_size
BatchTensor(const at::Tensor& data, int64_t batch_size);
BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dims);
- const char * toString() const {
+ const char* toString() const {
return "BatchTensor";
}
at::IntList sizes() const {
return data.dim();
}
std::vector<at::Tensor> examples();
- at::Tensor get_data(){
+ at::Tensor get_data() {
return data;
}
- at::Tensor get_mask(){
+ at::Tensor get_mask() {
return mask;
}
- at::Tensor get_dims(){
+ at::Tensor get_dims() {
return dims;
}
-public:
+ public:
// data is a Tensor whose size is the batch size in the batch dimension,
// the size of all examples in static dimensions,
- // and at least as large as the largest example in the batch in dynamic dimensions.
+ // and at least as large as the largest example in the batch in dynamic
+ // dimensions.
at::Tensor data;
// mask is a Tensor whose size is the batch size in the batch dimension,
// one in static dimensions,
- // and at least as large as the largest example in the batch in dynamic dimensions.
- // Each entry in the mask corresponds to one or more entries in the data array (singleton, i.e., static, dimensions are broadcasted),
- // with a one in the mask denoting that the corresponding data entries represent valid, meaningful data and a zero denoting that they do not.
+ // and at least as large as the largest example in the batch in dynamic
+ // dimensions. Each entry in the mask corresponds to one or more entries in
+ // the data array (singleton, i.e., static, dimensions are broadcasted), with
+ // a one in the mask denoting that the corresponding data entries represent
+ // valid, meaningful data and a zero denoting that they do not.
at::Tensor mask;
// dims is a 1-dimensional tensor with a bool for each non-batch dimension,
// representing whether that dimension is static (False) or dynamic (True).
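//
// For example (hypothetical): batching two 1-d sequences of lengths 3 and 5
// gives data of size {2, 5} (zero-padded), a mask of size {2, 5} with rows
// [1,1,1,0,0] and [1,1,1,1,1], and dims = [True], marking the single
// non-batch dimension as dynamic.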
};
void initBatchTensorBindings(PyObject* module);
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#define CATCH_CONFIG_PREFIX_ALL
#include <catch.hpp>
-// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes warning;
-// define our own version that doesn't warn.
-#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
+// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes
+// warning; define our own version that doesn't warn.
+#define _CATCH_REQUIRE_THROWS(...) \
+ INTERNAL_CATCH_THROWS( \
+ "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__)
#pragma once
+#include <sstream>
#include <string>
-#include <vector>
#include <unordered_map>
-#include <sstream>
+#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// A template environment is a mapping from template variable names, e.g.,
// identifier (corresponding to $identifier) to their expansions.
// in the top level environment, and then recurses into a parent
// environment if the key is not found.)
struct TemplateEnv {
- TemplateEnv()
- : parent(nullptr) {}
- TemplateEnv(TemplateEnv & parent)
- : parent(&parent) {}
+ TemplateEnv() : parent(nullptr) {}
+ TemplateEnv(TemplateEnv& parent) : parent(&parent) {}
using string_list = std::vector<std::string>;
// Add a string 'v' to the map at key 'k'.
- void s(const std::string & k, const std::string & v) {
+ void s(const std::string& k, const std::string& v) {
strings_[k] = v;
lists_.erase(k);
}
// Add a number 'v' to the map at key 'k'
- template<typename T>
- void d(const std::string & k, const T & v) {
+ template <typename T>
+ void d(const std::string& k, const T& v) {
strings_[k] = std::to_string(v);
lists_.erase(k);
}
// Retrieve the string representation of the value stored at 'k' from the map.
// Raises an exception if the key is not found.
- const std::string & s(const std::string & k) const {
- if(strings_.count(k) == 0) {
- if(parent) {
+ const std::string& s(const std::string& k) const {
+ if (strings_.count(k) == 0) {
+ if (parent) {
return parent->s(k);
}
notFound(k);
}
// Store a list of strings 'v' in the map at 'k'.
- void v(const std::string & k, const string_list & v) {
+ void v(const std::string& k, const string_list& v) {
lists_[k] = v;
strings_.erase(k);
}
// Retrieve a list of strings stored at 'k' from the map.
// Raises an exception if the key is not found.
- const string_list & v(const std::string & k) const {
- if(lists_.count(k) == 0) {
- if(parent) {
+ const string_list& v(const std::string& k) const {
+ if (lists_.count(k) == 0) {
+ if (parent) {
return parent->v(k);
}
notFound(k);
}
// Test if the value stored at 'k' is a string (as opposed to a list).
- bool keyIsString(const std::string & k) const {
- if(strings_.count(k) > 0)
+ bool keyIsString(const std::string& k) const {
+ if (strings_.count(k) > 0)
return true;
- if(lists_.count(k) > 0)
+ if (lists_.count(k) > 0)
return false;
- if(parent)
+ if (parent)
return parent->keyIsString(k);
notFound(k);
}
-private:
- [[ noreturn ]]
- void notFound(const std::string & k) const {
+
+ private:
+ [[noreturn]] void notFound(const std::string& k) const {
std::stringstream ss;
ss << "key not found: " << k;
throw std::logic_error(ss.str());
}
- std::unordered_map<std::string,std::string> strings_;
- std::unordered_map<std::string,string_list> lists_;
- TemplateEnv * parent;
+ std::unordered_map<std::string, std::string> strings_;
+ std::unordered_map<std::string, string_list> lists_;
+ TemplateEnv* parent;
};
/*
# if this list is not empty and ${foo,} will insert one after.
*/
struct CodeTemplate {
- /* implicit */ CodeTemplate(std::string t)
- : template_text(std::move(t)) {}
+ /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
- std::string format(const TemplateEnv & env) {
+ std::string format(const TemplateEnv& env) {
std::stringstream out;
size_t pos = 0;
size_t indent = 0;
bool all_whitespace = true;
- while(pos < template_text.size()) {
+ while (pos < template_text.size()) {
char c = template_text[pos];
- if(c == '$') {
+ if (c == '$') {
std::stringstream kss;
bool comma_before;
bool comma_after;
- size_t new_pos = parseKey(pos,kss,comma_before,comma_after);
+ size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
std::string k = kss.str();
bool is_string = env.keyIsString(k);
- if(all_whitespace) {
- if(is_string)
+ if (all_whitespace) {
+ if (is_string)
emitStringWithIndents(out, indent, env.s(k));
else
emitLinesIndented(out, indent, env.v(k));
} else {
- if(is_string)
+ if (is_string)
out << env.s(k);
else
emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
pos = new_pos;
} else {
out << c;
- if(!isspace(c))
+ if (!isspace(c))
all_whitespace = false;
indent++;
- if(c == '\n') {
+ if (c == '\n') {
indent = 0;
all_whitespace = true;
}
}
return out.str();
}
-private:
+
+ private:
using string_list = std::vector<std::string>;
char charAt(size_t p) {
if (p >= template_text.size())
throw std::logic_error("EOS found in key");
return template_text[p];
}
- size_t parseKey(size_t pos, std::ostream & k, bool & comma_before, bool & comma_after) {
+ size_t parseKey(
+ size_t pos,
+ std::ostream& k,
+ bool& comma_before,
+ bool& comma_after) {
comma_before = false;
comma_after = false;
pos++;
- if(charAt(pos) == '{') {
+ if (charAt(pos) == '{') {
pos++;
- if(charAt(pos) == ',') {
+ if (charAt(pos) == ',') {
comma_before = true;
pos++;
}
pos = parseIdent(pos, k);
- if(charAt(pos) == ',') {
+ if (charAt(pos) == ',') {
comma_after = true;
pos++;
}
- if(charAt(pos) != '}')
+ if (charAt(pos) != '}')
throw std::logic_error("missing terminating '}'");
pos++;
return pos;
return parseIdent(pos, k);
}
}
- size_t parseIdent(size_t pos, std::ostream & k) {
- while(pos < template_text.size() &&
- (isalnum(template_text[pos]) || template_text[pos] == '_')) {
+ size_t parseIdent(size_t pos, std::ostream& k) {
+ while (pos < template_text.size() &&
+ (isalnum(template_text[pos]) || template_text[pos] == '_')) {
k << template_text[pos];
pos++;
}
return pos;
}
- void emitCommaSeparatedList(std::ostream & out, const string_list & strings, bool comma_before, bool comma_after) {
- if(comma_before && strings.size() > 0)
+ void emitCommaSeparatedList(
+ std::ostream& out,
+ const string_list& strings,
+ bool comma_before,
+ bool comma_after) {
+ if (comma_before && strings.size() > 0)
out << ", ";
- for(size_t i = 0; i < strings.size(); ++i) {
- if(i > 0)
+ for (size_t i = 0; i < strings.size(); ++i) {
+ if (i > 0)
out << ", ";
out << strings[i];
}
- if(comma_after && strings.size() > 0)
+ if (comma_after && strings.size() > 0)
out << ", ";
}
// These indentation functions follow the convention that they never emit
// leading or trailing newlines when the input string does not have leading
// or trailing newlines. It's the responsibility of the calling function
// to indent correctly in the context.
- void emitIndent(std::ostream & out, size_t indent) {
- for(size_t i = 0; i < indent; ++i) {
+ void emitIndent(std::ostream& out, size_t indent) {
+ for (size_t i = 0; i < indent; ++i) {
out << " ";
}
}
- void emitStringWithIndents(std::ostream & out, size_t indent, const std::string & str) {
- for(auto c : str) {
+ void emitStringWithIndents(
+ std::ostream& out,
+ size_t indent,
+ const std::string& str) {
+ for (auto c : str) {
out << c;
- if(c == '\n') {
+ if (c == '\n') {
emitIndent(out, indent);
}
}
}
- void emitLinesIndented(std::stringstream & out, size_t indent, const string_list & strings) {
- for(size_t i = 0; i < strings.size(); ++i) {
- if(i > 0)
+ void emitLinesIndented(
+ std::stringstream& out,
+ size_t indent,
+ const string_list& strings) {
+ for (size_t i = 0; i < strings.size(); ++i) {
+ if (i > 0)
emitIndent(out, indent);
- emitStringWithIndents(out,indent,strings[i]);
- if(i+1 != strings.size())
+ emitStringWithIndents(out, indent, strings[i]);
+ if (i + 1 != strings.size())
out << "\n";
}
}
std::string template_text;
};
-static inline std::string format(const std::string & fmt, TemplateEnv & env) {
+static inline std::string format(const std::string& fmt, TemplateEnv& env) {
return CodeTemplate(fmt).format(env);
}
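// Illustrative usage (not part of this diff): keys bound with s()/d() expand
// as strings; keys bound with v() expand as comma-separated lists.
//
//   TemplateEnv env;
//   env.s("name", "relu");
//   env.d("arity", 1);
//   env.v("args", {"x"});
//   std::string code = format("${name}_${arity}(${args})", env);
//   // code == "relu_1(x)"; "${args,}" would instead expand to "x, ",
//   // emitting the trailing comma only when the list is non-empty.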
-}}
+} // namespace jit
+} // namespace torch
+#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/constants.h>
-#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/custom_operator.h>
-#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/utils/functional.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// IValue -> Constant node
Value* insertConstant(
const IValue& val,
c10::optional<SourceRange> loc,
c10::optional<ScopePtr> scope) {
- Node * n = g.create(prim::Constant);
- if(val.isTensor()) {
+ Node* n = g.create(prim::Constant);
+ if (val.isTensor()) {
at::Tensor ref = val.toTensor();
- if(!ref.defined()) {
+ if (!ref.defined()) {
n->destroy();
return g.insertNode(g.createUndefined())->output();
}
if (ref.is_variable()) {
ref = autograd::Variable(ref).data();
}
- n->output()->inferTypeFrom(ref); // note: before t_ because of std::move(ref)
+ n->output()->inferTypeFrom(
+ ref); // note: before t_ because of std::move(ref)
n->t_(attr::value, std::move(ref));
- } else if(val.isInt()) {
+ } else if (val.isInt()) {
n->i_(attr::value, val.toInt());
n->output()->setType(IntType::get());
- } else if(val.isDouble()) {
+ } else if (val.isDouble()) {
n->f_(attr::value, val.toDouble());
n->output()->setType(FloatType::get());
} else if (val.isBool()) {
n->output()->setType(BoolType::get());
} else if (val.isBoolList()) {
auto bool_list = val.toBoolList()->elements();
- n->is_(attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
+ n->is_(
+ attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
n->output()->setType(ListType::ofBools());
- } else if(val.isIntList()) {
+ } else if (val.isIntList()) {
n->is_(attr::value, val.toIntList()->elements());
n->output()->setType(ListType::ofInts());
- } else if(val.isTensorList()) {
- n->ts_(attr::value, fmap(val.toTensorList()->elements(), [](const at::Tensor & t) {
- return autograd::Variable(t).data();
- }));
+ } else if (val.isTensorList()) {
+ n->ts_(
+ attr::value,
+ fmap(val.toTensorList()->elements(), [](const at::Tensor& t) {
+ return autograd::Variable(t).data();
+ }));
n->output()->setType(ListType::ofTensors());
- } else if(val.isString()) {
+ } else if (val.isString()) {
n->s_(attr::value, val.toString()->string());
n->output()->setType(StringType::get());
- } else if(val.isDevice()) {
+ } else if (val.isDevice()) {
std::stringstream ss;
ss << val.toDevice();
n->s_(attr::value, ss.str());
n->output()->setType(DeviceObjType::get());
- } else if(val.isNone()) {
+ } else if (val.isNone()) {
n->destroy();
n = g.create(prim::None);
n->output()->setType(NoneType::get());
} else {
- throw constant_not_supported_error("Unsupported value kind: " + val.tagKind());
+ throw constant_not_supported_error(
+ "Unsupported value kind: " + val.tagKind());
}
- if(loc)
+ if (loc)
n->setSourceLocation(std::make_shared<SourceRange>(*loc));
- if(scope)
+ if (scope)
n->setScope(*scope);
return g.insertNode(n)->output();
}
RegisterOperators reg({
- // Implementation of constant node, computes and IValue
- Operator(
- FunctionSchema(prim::Constant, {}, {}, /*is_vararg=*/false, /*is_varret=*/true),
- [](const Node* node) -> Operation {
- TypePtr type = node->output()->type();
- if(type->isSubtypeOf(DynamicType::get())) {
- auto t = autograd::make_variable(node->t(attr::value));
- return [t](Stack& stack) {
- push(stack, t);
- return 0;
- };
- } else if (type->isSubtypeOf(BoolType::get())) {
- bool b = node->i(attr::value);
- return [b](Stack& stack) {
- push(stack, b);
- return 0;
- };
- } else if (
- type->isSubtypeOf(NumberType::get()) &&
- node->kindOf(attr::value) == AttributeKind::i) {
- auto i = node->i(attr::value);
- return [i](Stack& stack) {
- push(stack, i);
- return 0;
- };
- } else if (
- type->isSubtypeOf(NumberType::get()) &&
- node->kindOf(attr::value) == AttributeKind::f) {
- auto f = node->f(attr::value);
- return [f](Stack& stack) {
- push(stack, f);
- return 0;
- };
- } else if(type->isSubtypeOf(ListType::ofInts())) {
- const auto& is = node->is(attr::value);
- return [is](Stack& stack) {
- push(stack, is);
- return 0;
- };
- } else if(type->isSubtypeOf(ListType::ofBools())) {
- const auto& bs = node->is(attr::value);
- return [bs](Stack& stack) {
- push(stack, bs);
- return 0;
- };
- } else if(type->isSubtypeOf(ListType::ofTensors())) {
- const auto& ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor {
- return autograd::make_variable(t);
- });
- return [ts](Stack& stack) {
- push(stack, ts);
- return 0;
- };
- } else if (type == StringType::get()) {
- const auto& s = node->s(attr::value);
- return [s](Stack& stack) {
- push(stack, s);
- return 0;
- };
- } else if (type == DeviceObjType::get()) {
- auto d = c10::Device(node->s(attr::value));
- return [d](Stack& stack) {
- push(stack, d);
- return 0;
- };
- } else {
- std::stringstream ss;
- ss << "constant literal not supported for: " << type->str();
- throw std::runtime_error(ss.str());
- }
- }),
+    // Implementation of constant node, computes an IValue
+ Operator(
+ FunctionSchema(
+ prim::Constant,
+ {},
+ {},
+ /*is_vararg=*/false,
+ /*is_varret=*/true),
+ [](const Node* node) -> Operation {
+ TypePtr type = node->output()->type();
+ if (type->isSubtypeOf(DynamicType::get())) {
+ auto t = autograd::make_variable(node->t(attr::value));
+ return [t](Stack& stack) {
+ push(stack, t);
+ return 0;
+ };
+ } else if (type->isSubtypeOf(BoolType::get())) {
+ bool b = node->i(attr::value);
+ return [b](Stack& stack) {
+ push(stack, b);
+ return 0;
+ };
+ } else if (
+ type->isSubtypeOf(NumberType::get()) &&
+ node->kindOf(attr::value) == AttributeKind::i) {
+ auto i = node->i(attr::value);
+ return [i](Stack& stack) {
+ push(stack, i);
+ return 0;
+ };
+ } else if (
+ type->isSubtypeOf(NumberType::get()) &&
+ node->kindOf(attr::value) == AttributeKind::f) {
+ auto f = node->f(attr::value);
+ return [f](Stack& stack) {
+ push(stack, f);
+ return 0;
+ };
+ } else if (type->isSubtypeOf(ListType::ofInts())) {
+ const auto& is = node->is(attr::value);
+ return [is](Stack& stack) {
+ push(stack, is);
+ return 0;
+ };
+ } else if (type->isSubtypeOf(ListType::ofBools())) {
+ const auto& bs = node->is(attr::value);
+ return [bs](Stack& stack) {
+ push(stack, bs);
+ return 0;
+ };
+ } else if (type->isSubtypeOf(ListType::ofTensors())) {
+ const auto& ts = fmap(
+ node->ts(attr::value), [](const at::Tensor& t) -> at::Tensor {
+ return autograd::make_variable(t);
+ });
+ return [ts](Stack& stack) {
+ push(stack, ts);
+ return 0;
+ };
+ } else if (type == StringType::get()) {
+ const auto& s = node->s(attr::value);
+ return [s](Stack& stack) {
+ push(stack, s);
+ return 0;
+ };
+ } else if (type == DeviceObjType::get()) {
+ auto d = c10::Device(node->s(attr::value));
+ return [d](Stack& stack) {
+ push(stack, d);
+ return 0;
+ };
+ } else {
+ std::stringstream ss;
+ ss << "constant literal not supported for: " << type->str();
+ throw std::runtime_error(ss.str());
+ }
+ }),
});
c10::optional<IValue> toIValue(const Value* v) {
op(stack);
return stack.back();
}
-}}
+} // namespace jit
+} // namespace torch
#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ivalue.h>
#include <torch/csrc/jit/scope.h>
#include <torch/csrc/jit/source_range.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
// helpers for handling constants in the IR
// - create constant nodes from ints, floats, intlist, Tensors, and other types
// - implement primitive constant ops.
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct Graph;
struct Value;
// note: prefer g.insertConstant(val, loc) which does exactly the same thing
// this function is only declared/defined here because its implementation is
-// closely related to the implementation of prim::Constant that is also in constants.cpp
+// closely related to the implementation of prim::Constant that is also in
+// constants.cpp
TORCH_API Value* insertConstant(
Graph& g,
const IValue& val,
c10::optional<SourceRange> loc = c10::nullopt,
c10::optional<ScopePtr> scope = c10::nullopt);
-
//////////////////////////////////////////////////////////////////////////////////
// Helper for retrieving constants
//////////////////////////////////////////////////////////////////////////////////
// same rules as the interpreter
template <typename T>
c10::optional<T> constant_as(const Value* v) {
- if(auto ivalue = toIValue(v)) {
+ if (auto ivalue = toIValue(v)) {
return ivalue->to<T>();
}
return c10::nullopt;
}
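// Illustrative round trip (not part of this diff): insert an int constant
// into a graph and read it back through the helper above.
//
//   Graph g;
//   Value* c = insertConstant(g, IValue(static_cast<int64_t>(42)));
//   c10::optional<int64_t> i = constant_as<int64_t>(c); // contains 42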
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/utils/functional.h>
#include <torch/csrc/utils/memory.h>
-namespace torch { namespace jit { namespace detail {
+namespace torch {
+namespace jit {
+namespace detail {
// DynamicDAG is a simple directed acyclic graph that dynamically maintains a
// topological order as edges/vertices are added and removed.
// merge black nodes that are directly connected by contracting the
// edge between them while still maintaining the DAG and a topological order?
// Use contractEdge().
-// - Let's say you have a DAG where each vertex is a Node* and the edges represent
-// data dependencies. We wish to determine if adding a new Node* with certain
-// data dependencies (or moving an existing one to use new dependencies) is valid.
-// Use DynamicDAG::addEdge() to add the new data dependencies to the DAG:
-// it will either find a valid reordering of the DAG's topological order or throw
-// if the resulting DAG is invalid.
+// - Let's say you have a DAG where each vertex is a Node* and the edges
+// represent data dependencies. We wish to determine if adding a new Node*
+// with certain data dependencies (or moving an existing one to use new
+// dependencies) is valid. Use DynamicDAG::addEdge() to add the new data
+// dependencies to the DAG: it will either find a valid reordering of the
+// DAG's topological order or throw if the resulting DAG is invalid.
//
// The implementation is based off of the PK algorithm in the following paper:
// "A Dynamic Topsort Algorithm for Directed Acyclic Graphs"
// https://www.doc.ic.ac.uk/~phjk/Publications/DynamicTopoSortAlg-JEA-07.pdf
// It is summarized in [Edge addition] (see DynamicDAG<T>::addEdge)
-template <typename T> struct Vertex;
-template <typename T> struct DynamicDAG;
-template <typename T> using vertex_list = std::vector<Vertex<T>*>;
-template <typename T> using unique_vertex = std::unique_ptr<Vertex<T>>;
+template <typename T>
+struct Vertex;
+template <typename T>
+struct DynamicDAG;
+template <typename T>
+using vertex_list = std::vector<Vertex<T>*>;
+template <typename T>
+using unique_vertex = std::unique_ptr<Vertex<T>>;
-enum class DFSDirection {forward, backward};
+enum class DFSDirection { forward, backward };
// Used to represent adjacency lists in DynamicDAG.
// Has set semantics: stores distinct elements.
return a->ord < b->ord;
});
}
- size_t size() const { return data_.size(); }
- iterator begin() { return data_.begin(); }
- iterator end() { return data_.end(); }
- reverse_iterator rbegin() { return data_.rbegin(); }
- reverse_iterator rend() { return data_.rend(); }
+ size_t size() const {
+ return data_.size();
+ }
+ iterator begin() {
+ return data_.begin();
+ }
+ iterator end() {
+ return data_.end();
+ }
+ reverse_iterator rbegin() {
+ return data_.rbegin();
+ }
+ reverse_iterator rend() {
+ return data_.rend();
+ }
private:
std::vector<Vertex<T>*> data_;
});
}
- const vertex_list<T>& vector() { return data_; }
+ const vertex_list<T>& vector() {
+ return data_;
+ }
private:
vertex_list<T> data_;
template <typename T>
struct Vertex {
- Vertex(size_t ord, T datum)
- : ord(ord), visited_(false) { data.push_back(datum); }
+ Vertex(size_t ord, T datum) : ord(ord), visited_(false) {
+ data.push_back(datum);
+ }
std::vector<T> data;
size_t ord; // unique topological index
std::string toString();
- vertex_set<T>& in_edges() { return edges_.in_edges; }
- vertex_set<T>& out_edges() { return edges_.out_edges; }
- IOEdges<T>&& move_edges() { return std::move(edges_); }
+ vertex_set<T>& in_edges() {
+ return edges_.in_edges;
+ }
+ vertex_set<T>& out_edges() {
+ return edges_.out_edges;
+ }
+ IOEdges<T>&& move_edges() {
+ return std::move(edges_);
+ }
- bool visited() { return visited_; }
+ bool visited() {
+ return visited_;
+ }
-private:
+ private:
IOEdges<T> edges_;
friend visited_list<T>;
// max_size() >= the number of live vertices.
// for all vertices v, v.ord < max_size()
- size_t max_size() const { return vertices_.size(); };
+ size_t max_size() const {
+ return vertices_.size();
+ };
c10::optional<Vertex<T>*> at(size_t ord) const;
std::string toString();
// O(vertices_.size()). Used for testing, don't call this often.
template <typename T>
size_t DynamicDAG<T>::debugNumVertices() const {
- return std::count_if(vertices_.begin(), vertices_.end(),
- [](const unique_vertex<T>& v) {
- if (v) return true;
+ return std::count_if(
+ vertices_.begin(), vertices_.end(), [](const unique_vertex<T>& v) {
+ if (v)
+ return true;
return false;
});
}
void DynamicDAG<T>::debugCheckInvariants() {
for (size_t ord = 0; ord < vertices_.size(); ++ord) {
const auto& vertex = vertices_.at(ord);
- if (!vertex) continue;
+ if (!vertex)
+ continue;
AT_ASSERTM(vertex->ord == ord, toString());
for (auto* v : vertex->in_edges()) {
*
* Assume we are adding an edge x -> y and that ord(x) > ord(y).
* First, if there is a path y ----> x through some other vertices, then this
- * edge addition would create a cycle. Figure this out via DFS and throw if necessary.
+ * edge addition would create a cycle. Figure this out via DFS and throw if
+ * necessary.
*
* Now, consider the set of all vertices v such that ord(x) > ord(v) > ord(y).
* Call this set the affected region (AR) -- these are the only vertices we
* need to consider for reordering to make the resulting graph valid.
*
- * Find all children of y (through DFS) in AR (call this set deltaF and add y to it)
- * Find all parents of x in AR (call this set deltaB and add x to it).
+ * Find all children of y (through DFS) in AR (call this set deltaF and add y
+ * to it). Find all parents of x in AR (call this set deltaB and add x to it).
*
* Move y and all the children of y to after x and all the parents of x. The
* result topological ordering is valid.
* deltaB (sorted) = {c(2), d(4), x(6)}. deltaB ords = { 2, 4, 6 }
* deltaF (sorted) = {y(1), a(3), b(5)}. deltaF ords = { 1, 3, 5 }
*
- * 2) append the two lists: the result is the order we want these vertices to have.
+ * 2) append the two lists: the result is the order we want these vertices to
+ * have.
* L = {c(2), d(4), x(6), y(1), a(3), b(5)}.
*
* 3) Merge the sorted ords: R = { 1, 2, 3, 4, 5, 6 }.
*
* [Analysis]
* This is O(|AR| log |AR|). |AR| is equal to ord(consumer) - ord(producer).
- * AR is the "affected region": { v s.t. ord(v) in [ord(producer), ord(consumer)] }
- * consisting of the only vertices that can possibly be moved around due to this
- * edge addition.
+ * AR is the "affected region": { v s.t. ord(v) in [ord(producer),
+ * ord(consumer)] } consisting of the only vertices that can possibly be moved
+ * around due to this edge addition.
*
* NB: Pearce and Kelly give a complexity bound of <<delta>> where
* delta = union(deltaF, deltaB) and <<S>> on a set S is
void DynamicDAG<T>::addEdge(Vertex<T>* producer, Vertex<T>* consumer) {
JIT_ASSERT(producer != consumer);
- // NB: DynamicDAG is a simple graph. If an edge exists already, don't do anything.
+ // NB: DynamicDAG is a simple graph. If an edge exists already, don't do
+ // anything.
bool is_distinct = producer->out_edges().insert(consumer);
- if (!is_distinct) return;
+ if (!is_distinct)
+ return;
is_distinct = consumer->in_edges().insert(producer);
JIT_ASSERT(is_distinct);
visited_list<T> deltaF;
visited_list<T> deltaB;
- // Search for vertices that are reachable from consumer that have a now incorrect
- // topological ordering.
- if (dfsSearch(DFSDirection::forward, consumer, producer,
- /*bound=*/producer->ord, deltaF)) {
+ // Search for vertices that are reachable from consumer that have a now
+ // incorrect topological ordering.
+ if (dfsSearch(
+ DFSDirection::forward,
+ consumer,
+ producer,
+ /*bound=*/producer->ord,
+ deltaF)) {
// Path found! This means there's a cycle.
AT_ERROR("Cycle detected while trying to add edge.");
}
// Search for vertices that can reach producer that have a now incorrect
// topological ordering
- JIT_ASSERT(!dfsSearch(DFSDirection::backward, producer, consumer,
- /*bound=*/consumer->ord, deltaB));
+ JIT_ASSERT(!dfsSearch(
+ DFSDirection::backward,
+ producer,
+ consumer,
+ /*bound=*/consumer->ord,
+ deltaB));
// Reorder the vertices that are reachable from consumer to occur BEFORE
// the vertices that can reach producer.
// These are the only vertices that can possibly be moved around
// during edge contraction.
//
-// contractEdge is O(|AR| log |AR| * min(|out_edges(producer)|, |in_edges(consumer)|))
+// contractEdge is O(|AR| log |AR| * min(|out_edges(producer)|,
+// |in_edges(consumer)|))
template <typename T>
bool DynamicDAG<T>::contractEdge(Vertex<T>* producer, Vertex<T>* consumer) {
JIT_ASSERT(producer != consumer);
}
template <typename T>
-void DynamicDAG<T>::mergeProducerIntoConsumer(Vertex<T>* producer, Vertex<T>* consumer) {
+void DynamicDAG<T>::mergeProducerIntoConsumer(
+ Vertex<T>* producer,
+ Vertex<T>* consumer) {
// Optimization: we want to concat lists [producer.data, consumer.data].
// Instead of inserting into the beginning of consumer.data, do a swap.
- producer->data.insert(producer->data.end(), consumer->data.begin(), consumer->data.end());
+ producer->data.insert(
+ producer->data.end(), consumer->data.begin(), consumer->data.end());
std::swap(consumer->data, producer->data);
auto edges = removeVertex(producer);
}
template <typename T>
-void DynamicDAG<T>::mergeConsumerIntoProducer(Vertex<T>* producer, Vertex<T>* consumer) {
- producer->data.insert(producer->data.end(), consumer->data.begin(), consumer->data.end());
+void DynamicDAG<T>::mergeConsumerIntoProducer(
+ Vertex<T>* producer,
+ Vertex<T>* consumer) {
+ producer->data.insert(
+ producer->data.end(), consumer->data.begin(), consumer->data.end());
auto edges = removeVertex(consumer);
for (auto* parent : edges.in_edges) {
addEdge(parent, producer);
}
-
}
template <typename T>
-bool DynamicDAG<T>::contractionProducesCycle(Vertex<T>* producer, Vertex<T>* consumer) {
+bool DynamicDAG<T>::contractionProducesCycle(
+ Vertex<T>* producer,
+ Vertex<T>* consumer) {
visited_list<T> visited;
  // If there are multiple paths from producer to consumer, then contracting
  // the producer -> consumer edge would create a cycle.
size_t upper_bound = consumer->ord;
for (auto* child : producer->out_edges()) {
- if (child == consumer) continue;
- if (child->visited()) continue; // already visited by dfs
- if (dfsSearch(DFSDirection::forward, child, consumer, upper_bound, visited)) {
+ if (child == consumer)
+ continue;
+ if (child->visited())
+ continue; // already visited by dfs
+ if (dfsSearch(
+ DFSDirection::forward, child, consumer, upper_bound, visited)) {
return true;
}
}
return false;
}
-
-static bool is_within_bound(DFSDirection direction, size_t value, size_t bound) {
+static bool is_within_bound(
+ DFSDirection direction,
+ size_t value,
+ size_t bound) {
if (direction == DFSDirection::forward) {
return value < bound; // upper bound
} else {
auto* vertex = stack.back();
stack.pop_back();
- auto& next_edges = (direction == DFSDirection::forward) ?
- vertex->out_edges() :
- vertex->in_edges();
+ auto& next_edges = (direction == DFSDirection::forward)
+ ? vertex->out_edges()
+ : vertex->in_edges();
for (Vertex<T>* next : next_edges) {
if (next == end) {
return false;
}
-
// Reorder deltaB vertices to occur before deltaF vertices.
template <typename T>
void DynamicDAG<T>::reorder(visited_list<T> deltaF, visited_list<T> deltaB) {
}
// Sort the ords by merging two already sorted lists into a large sorted list.
- // input (example): deltaB = { v(1), v(4), v(7) } , deltaF = { v(0), v(2), v(5) }.
+  // input (example): deltaB = { v(1), v(4), v(7) },
+  //                   deltaF = { v(0), v(2), v(5) }.
// output: { 0, 1, 2, 4, 5, 7 }.
std::vector<size_t> gathered_ords;
gathered_ords.reserve(num_affected);
for (const auto* v : deltaF_) {
gathered_ords.push_back(v->ord);
}
- std::inplace_merge(gathered_ords.begin(), gathered_ords.begin() + middle, gathered_ords.end());
+ std::inplace_merge(
+ gathered_ords.begin(),
+ gathered_ords.begin() + middle,
+ gathered_ords.end());
// Return the vertices back into the vertices_ storage.
for (size_t i = 0; i < num_affected; ++i) {
ss << " " << d;
}
}
- ss << "} ("<< ord << ") -> [";
+ ss << "} (" << ord << ") -> [";
for (auto* c : out_edges()) {
ss << c->ord << " ";
}
return ss.str();
}
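// Illustrative usage (not part of this diff; assumes the newVertex(T)
// factory declared elsewhere in this header):
//
//   DynamicDAG<int> dag;
//   auto* a = dag.newVertex(0);
//   auto* b = dag.newVertex(1);
//   auto* c = dag.newVertex(2);
//   dag.addEdge(a, b);
//   dag.addEdge(b, c);
//   // dag.addEdge(c, a) would throw: the DFS in addEdge detects the cycle.
//   dag.contractEdge(a, b); // merges the two vertices; returns false only
//                           // when contraction would create a cycle.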
-}}}
+} // namespace detail
+} // namespace jit
+} // namespace torch
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>
-#include <torch/csrc/jit/export.h>
#include <torch/csrc/autograd/symbolic.h>
+#include <torch/csrc/jit/export.h>
#include <torch/csrc/onnx/onnx.h>
-#include <torch/csrc/utils/functional.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
-
+#include <torch/csrc/utils/functional.h>
#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
#include <string>
#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
namespace onnx_torch = ::torch::onnx;
return ss.str();
}
-void validateBlock(Block *b, onnx_torch::OperatorExportTypes operator_export_type) {
+void validateBlock(
+ Block* b,
+ onnx_torch::OperatorExportTypes operator_export_type) {
for (auto node : b->nodes()) {
- for (Block *sub_block : node->blocks()) {
+ for (Block* sub_block : node->blocks()) {
validateBlock(sub_block, operator_export_type);
}
// Macro'ed so we get a marginally better line number on failed export
-#define FAIL_EXPORT(name) \
- throw std::runtime_error(std::string("ONNX export failed: ") + name + "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
+#define FAIL_EXPORT(name) \
+ throw std::runtime_error( \
+ std::string("ONNX export failed: ") + name + \
+ "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
IR_IF(node, PythonOp)
- auto py_node = static_cast<torch::jit::PythonOp*>(value);
- FAIL_EXPORT(
- "Couldn't export Python operator " + py_node->name() +
- "\n\nDefined at:\n" + getNodeStackTraceString(node))
+ auto py_node = static_cast<torch::jit::PythonOp*>(value);
+ FAIL_EXPORT(
+ "Couldn't export Python operator " + py_node->name() +
+ "\n\nDefined at:\n" + getNodeStackTraceString(node))
IR_ELSE()
- // Special error messages for certain types of operators
- if (node->kind() == aten::expand) {
- if (operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
- WithInsertPoint guard(node);
- auto* new_node = b->owningGraph()->insertNode(
- b->owningGraph()->create(Symbol(::torch::jit::onnx::ATen), node->inputs(), node->outputs().size()));
- for (size_t i = 0; i < node->outputs().size(); ++i) {
- node->output(i)->replaceAllUsesWith(new_node->output(i));
- }
- new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
+ // Special error messages for certain types of operators
+ if (node->kind() == aten::expand) {
+ if (operator_export_type ==
+ onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
+ WithInsertPoint guard(node);
+ auto* new_node = b->owningGraph()->insertNode(b->owningGraph()->create(
+ Symbol(::torch::jit::onnx::ATen),
+ node->inputs(),
+ node->outputs().size()));
+ for (size_t i = 0; i < node->outputs().size(); ++i) {
+ node->output(i)->replaceAllUsesWith(new_node->output(i));
}
+ new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
}
- if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
- FAIL_EXPORT(
- "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
- getNodeStackTraceString(node));
- }
- bool is_aten_enabled = operator_export_type ==
- onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
- operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
- if (!node->kind().is_onnx() && !is_aten_enabled &&
- node->kind() != prim::Undefined) {
- FAIL_EXPORT(
- "Couldn't export operator " + node->kind().toDisplayString() + "\n\nDefined at:\n" +
- getNodeStackTraceString(node));
- }
+ }
+ if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
+ FAIL_EXPORT(
+ "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
+ getNodeStackTraceString(node));
+ }
+ bool is_aten_enabled = operator_export_type ==
+ onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
+ operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
+ if (!node->kind().is_onnx() && !is_aten_enabled &&
+ node->kind() != prim::Undefined) {
+ FAIL_EXPORT(
+ "Couldn't export operator " + node->kind().toDisplayString() +
+ "\n\nDefined at:\n" + getNodeStackTraceString(node));
+ }
IR_END()
#undef FAIL_EXPORT
}
}
-void validateGraph(const std::shared_ptr<Graph>& graph, onnx_torch::OperatorExportTypes operator_export_type) {
+void validateGraph(
+ const std::shared_ptr<Graph>& graph,
+ onnx_torch::OperatorExportTypes operator_export_type) {
validateBlock(graph->block(), operator_export_type);
EliminateDeadCode(graph->block());
}
class EncoderBase {
public:
- EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc);
+ EncoderBase(
+ onnx_torch::OperatorExportTypes operator_export_type,
+ bool strip_doc);
onnx::ModelProto get_model_proto() {
return model_proto_;
}
protected:
- void EncodeGraph(onnx::GraphProto *graph_proto,
- const std::shared_ptr<Graph> &graph,
- const std::vector<at::Tensor> &initializers = {});
+ void EncodeGraph(
+ onnx::GraphProto* graph_proto,
+ const std::shared_ptr<Graph>& graph,
+ const std::vector<at::Tensor>& initializers = {});
- void EncodeBlock(onnx::GraphProto *graph_proto,
- const Block *block,
- const std::vector<at::Tensor> &initializers = {});
+ void EncodeBlock(
+ onnx::GraphProto* graph_proto,
+ const Block* block,
+ const std::vector<at::Tensor>& initializers = {});
virtual void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref = {}) = 0;
- virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
- const Value* n) {};
+ virtual void EncodeIntermediateValueInfo(
+ onnx::GraphProto* graph_proto,
+ const Value* n){};
- virtual void EncodeValueInfo(onnx::GraphProto *graph_proto,
- onnx::ValueInfoProto* v,
- const Value* n);
+ virtual void EncodeValueInfo(
+ onnx::GraphProto* graph_proto,
+ onnx::ValueInfoProto* v,
+ const Value* n);
- void AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name);
+ void AddAttribute(
+ onnx::NodeProto* node_proto,
+ const jit::Node* node,
+ const jit::Symbol name);
onnx::ModelProto model_proto_;
size_t num_blocks_;
};
onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
- switch(at_type) {
+ switch (at_type) {
case at::kDouble:
return onnx::TensorProto_DataType_DOUBLE;
case at::kFloat:
}
}
-EncoderBase::EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc)
+EncoderBase::EncoderBase(
+ onnx_torch::OperatorExportTypes operator_export_type,
+ bool strip_doc)
: num_blocks_(0),
operator_export_type_(operator_export_type),
strip_doc_(strip_doc) {
}
void EncoderBase::EncodeValueInfo(
- onnx::GraphProto *graph_proto,
+ onnx::GraphProto* graph_proto,
onnx::ValueInfoProto* v,
const Value* n) {
v->set_name(n->uniqueName());
}
void EncoderBase::EncodeGraph(
- onnx::GraphProto *graph_proto,
- const std::shared_ptr<Graph> &graph,
- const std::vector<at::Tensor> &initializers) {
+ onnx::GraphProto* graph_proto,
+ const std::shared_ptr<Graph>& graph,
+ const std::vector<at::Tensor>& initializers) {
EncodeBlock(graph_proto, graph->block(), initializers);
}
void EncoderBase::EncodeBlock(
- onnx::GraphProto *graph_proto, const Block *block,
- const std::vector<at::Tensor> &initializers) {
+ onnx::GraphProto* graph_proto,
+ const Block* block,
+ const std::vector<at::Tensor>& initializers) {
JIT_ASSERT(graph_proto != nullptr);
std::string block_name = "torch-jit-export";
if (num_blocks_) {
EncodeValueInfo(graph_proto, v, output);
}
for (auto node : block->nodes()) {
- bool is_raw_export = operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
+ bool is_raw_export =
+ operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
if (node->kind() == prim::Undefined && !is_raw_export) {
// Undefined nodes are used to implement optional inputs. One
// way to "not provide" an optional input is to create an
node->getSourceLocation()->highlight(ss);
p_n->set_doc_string(ss.str());
}
- for(auto input : node->inputs()) {
+ for (auto input : node->inputs()) {
if (input->node()->kind() == prim::Undefined && !is_raw_export) {
p_n->add_input("");
} else {
p_n->add_input(input->uniqueName());
}
}
- for(auto output : node->outputs()) {
+ for (auto output : node->outputs()) {
p_n->add_output(output->uniqueName());
EncodeIntermediateValueInfo(graph_proto, output);
}
if (is_raw_export) {
JIT_ASSERT(!node->kind().is_onnx());
p_n->set_domain(node->kind().domainString());
- }
- else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
+ } else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
JIT_ASSERT(node->kind().is_onnx());
}
p_n->set_op_type(node->kind().toUnqualString());
- for(auto attr_name : node->attributeNames()) {
+ for (auto attr_name : node->attributeNames()) {
AddAttribute(p_n, node, attr_name);
}
if (is_raw_export && node->blocks().size() > 0) {
auto num_initializers = initializers.size();
JIT_ASSERT(block->inputs().size() >= num_initializers);
size_t inputs_count = block->inputs().size() - num_initializers;
- for (auto & tensor : initializers) {
+ for (auto& tensor : initializers) {
// TODO: stop using positions to determine which initializers
// match to which inputs
std::string name = graph_proto->input(inputs_count++).name();
}
}
-void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name) {
+void EncoderBase::AddAttribute(
+ onnx::NodeProto* node_proto,
+ const jit::Node* node,
+ const jit::Symbol name) {
auto attr = node_proto->add_attribute();
JIT_ASSERT(name.is_attr());
attr->set_name(name.toUnqualString());
- switch(node->kindOf(name)) {
+ switch (node->kindOf(name)) {
case AttributeKind::f:
attr->set_f(node->f(name));
attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
break;
case AttributeKind::fs:
attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
- for(auto & v : node->fs(name))
+ for (auto& v : node->fs(name))
attr->add_floats(v);
break;
case AttributeKind::i:
break;
case AttributeKind::is:
attr->set_type(onnx::AttributeProto_AttributeType_INTS);
- for(auto & v : node->is(name))
+ for (auto& v : node->is(name))
attr->add_ints(v);
break;
case AttributeKind::s:
break;
case AttributeKind::ss:
attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
- for(auto & v : node->ss(name))
+ for (auto& v : node->ss(name))
attr->add_strings(v);
break;
case AttributeKind::t: {
} break;
case AttributeKind::ts:
attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
- for(auto & v : node->ts(name)) {
+ for (auto& v : node->ts(name)) {
auto t = attr->add_tensors();
EncodeTensor(t, v);
}
} break;
case AttributeKind::gs:
attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
- for(auto & v : node->gs(name)) {
+ for (auto& v : node->gs(name)) {
auto g = attr->add_graphs();
EncodeGraph(g, v);
}
}
}
-class GraphEncoder: public EncoderBase {
+class GraphEncoder : public EncoderBase {
public:
- GraphEncoder(const std::shared_ptr<Graph> &graph,
- int64_t onnx_opset_version,
- onnx_torch::OperatorExportTypes operator_export_type,
- const std::vector<at::Tensor> &initializers,
- bool defer_weight_export,
- bool strip_doc);
+ GraphEncoder(
+ const std::shared_ptr<Graph>& graph,
+ int64_t onnx_opset_version,
+ onnx_torch::OperatorExportTypes operator_export_type,
+ const std::vector<at::Tensor>& initializers,
+ bool defer_weight_export,
+ bool strip_doc);
RawDataExportMap get_raw_data_export_map() {
return raw_data_export_map_;
};
GraphEncoder::GraphEncoder(
- const std::shared_ptr<Graph> &graph,
+ const std::shared_ptr<Graph>& graph,
int64_t onnx_opset_version,
onnx_torch::OperatorExportTypes operator_export_type,
- const std::vector<at::Tensor> &initializers,
+ const std::vector<at::Tensor>& initializers,
bool defer_weight_export,
bool strip_doc)
: EncoderBase(operator_export_type, strip_doc),
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref) {
- for(auto d : tensor.sizes()) {
+ for (auto d : tensor.sizes()) {
tensor_proto->add_dims(d);
}
tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));
tensor_proto->set_raw_data("__EXTERNAL");
} else {
JIT_ASSERT(t.is_contiguous());
- tensor_proto->set_raw_data(std::string(static_cast<char*>(t.data_ptr()), t.type().elementSizeInBytes() * t.numel()));
+ tensor_proto->set_raw_data(std::string(
+ static_cast<char*>(t.data_ptr()),
+ t.type().elementSizeInBytes() * t.numel()));
}
}
model_def->set_producer_version("1.0"); // TODO: set the producer version
// using appropriate function call
model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
- convertModule(module, "", writer_.archiveName(), model_def->mutable_main_module());
+ convertModule(
+ module, "", writer_.archiveName(), model_def->mutable_main_module());
writeTensorTable(model_def);
}
/* stride = */ {1})
.cpu();
AT_ASSERT(
- storage_tensor.type().elementSizeInBytes() * storage_tensor.storage().size() ==
+ storage_tensor.type().elementSizeInBytes() *
+ storage_tensor.storage().size() ==
record_size);
}
std::string name = "tensors/" + std::to_string(tensor_id);
std::stringstream filename;
filename << "code/" << module_name.str() << ".py";
std::string methods_str = methods.str();
- writer_.writeRecord(filename.str(), methods_str.c_str(), methods_str.size());
+ writer_.writeRecord(
+ filename.str(), methods_str.c_str(), methods_str.size());
record->set_key(filename.str());
}
void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
for (int i = 0; i < shape.dim_size(); ++i) {
- auto &dim = shape.dim(i);
+ auto& dim = shape.dim(i);
if (dim.has_dim_value()) {
stream << dim.dim_value();
} else {
}
void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
- stream << "{name: \"" << value_info.name()
- << "\", type:";
+ stream << "{name: \"" << value_info.name() << "\", type:";
dump(value_info.type(), stream);
stream << "}";
}
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
-void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent) {
+void dump(
+ const onnx::AttributeProto& attr,
+ std::ostream& stream,
+ size_t indent) {
stream << "{ name: '" << attr.name() << "', type: ";
if (attr.has_f()) {
stream << "float, value: " << attr.f();
stream << "string, value: '" << attr.s() << "'";
} else if (attr.has_g()) {
stream << "graph, value:\n";
- dump(attr.g(), stream, indent+1);
+ dump(attr.g(), stream, indent + 1);
stream << nlidt(indent);
} else if (attr.has_t()) {
stream << "tensor, value:";
} else if (attr.strings_size()) {
stream << "strings, values: [";
for (int i = 0; i < attr.strings_size(); ++i)
- stream << "'" << attr.strings(i) << "'" << (i == attr.strings_size() - 1 ? "" : " ");
+ stream << "'" << attr.strings(i) << "'"
+ << (i == attr.strings_size() - 1 ? "" : " ");
stream << "]";
} else if (attr.tensors_size()) {
stream << "tensors, values: [";
} else if (attr.graphs_size()) {
stream << "graphs, values: [";
for (auto& g : attr.graphs()) {
- dump(g, stream, indent+1);
+ dump(g, stream, indent + 1);
}
stream << "]";
} else {
}
stream << "], attributes: [";
for (int i = 0; i < node.attribute_size(); ++i) {
- dump(node.attribute(i), stream, indent+1);
+ dump(node.attribute(i), stream, indent + 1);
stream << (i == node.attribute_size() - 1 ? "" : ",");
}
stream << "]}";
}
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
- stream << idt(indent) << "GraphProto {" << nlidt(indent+1)
- << "name: \"" << graph.name() << "\"" << nlidt(indent+1)
- << "inputs: [";
+ stream << idt(indent) << "GraphProto {" << nlidt(indent + 1) << "name: \""
+ << graph.name() << "\"" << nlidt(indent + 1) << "inputs: [";
for (int i = 0; i < graph.input_size(); ++i) {
dump(graph.input(i), stream);
stream << (i == graph.input_size() - 1 ? "" : ",");
}
- stream << "]" << nlidt(indent+1)
- << "outputs: [";
+ stream << "]" << nlidt(indent + 1) << "outputs: [";
for (int i = 0; i < graph.output_size(); ++i) {
dump(graph.output(i), stream);
stream << (i == graph.output_size() - 1 ? "" : ",");
}
- stream << "]" << nlidt(indent+1)
- << "initializers: [";
+ stream << "]" << nlidt(indent + 1) << "initializers: [";
for (int i = 0; i < graph.initializer_size(); ++i) {
dump(graph.initializer(i), stream);
stream << (i == graph.initializer_size() - 1 ? "" : ",");
}
- stream << "]" << nlidt(indent+1)
- << "nodes: [" << nlidt(indent+2);
+ stream << "]" << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2);
for (int i = 0; i < graph.node_size(); ++i) {
- dump(graph.node(i), stream, indent+2);
- if (i != graph.node_size() - 1) stream << "," << nlidt(indent+2);
+ dump(graph.node(i), stream, indent + 2);
+ if (i != graph.node_size() - 1)
+ stream << "," << nlidt(indent + 2);
}
- stream << nlidt(indent+1) << "]\n" << idt(indent) << "}\n";
+ stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n";
}
-void dump(const onnx::OperatorSetIdProto& operator_set_id, std::ostream& stream) {
+void dump(
+ const onnx::OperatorSetIdProto& operator_set_id,
+ std::ostream& stream) {
stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}";
}
void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
- stream << idt(indent)
- << "ModelProto {" << nlidt(indent+1)
- << "producer_name: \"" << model.producer_name() << "\"" << nlidt(indent+1)
- << "domain: \"" << model.domain() << "\"" << nlidt(indent+1)
- << "doc_string: \"" << model.doc_string() << "\"";
+ stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
+ << "producer_name: \"" << model.producer_name() << "\""
+ << nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
+ << nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
if (model.has_graph()) {
- stream << nlidt(indent+1) << "graph:\n";
- dump(model.graph(), stream, indent+2);
+ stream << nlidt(indent + 1) << "graph:\n";
+ dump(model.graph(), stream, indent + 2);
}
if (model.opset_import_size()) {
- stream << idt(indent+1) << "opset_import: [";
- for (auto &opset_imp : model.opset_import()) {
+ stream << idt(indent + 1) << "opset_import: [";
+ for (auto& opset_imp : model.opset_import()) {
dump(opset_imp, stream);
}
stream << "],\n";
} // namespace
std::string pretty_print_onnx(
- const std::shared_ptr<Graph> &graph,
- const std::vector<at::Tensor> &initializers,
- int64_t onnx_opset_version,
- bool defer_weight_export,
- ::torch::onnx::OperatorExportTypes operator_export_type,
- bool google_printer) {
+ const std::shared_ptr<Graph>& graph,
+ const std::vector<at::Tensor>& initializers,
+ int64_t onnx_opset_version,
+ bool defer_weight_export,
+ ::torch::onnx::OperatorExportTypes operator_export_type,
+ bool google_printer) {
auto graph_encoder = GraphEncoder(
- graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export, true);
+ graph,
+ onnx_opset_version,
+ operator_export_type,
+ initializers,
+ defer_weight_export,
+ true);
if (google_printer) {
return graph_encoder.get_model_proto().DebugString();
}
// be interpretable by a ONNX-compatible framework. However, PyTorch or
// libtorch will be able to import the IR and play it back.
std::tuple<std::string, RawDataExportMap> export_onnx(
- const std::shared_ptr<Graph> &graph,
- const std::vector<at::Tensor> &initializers,
- int64_t onnx_opset_version,
- bool defer_weight_export,
- ::torch::onnx::OperatorExportTypes operator_export_type) {
+ const std::shared_ptr<Graph>& graph,
+ const std::vector<at::Tensor>& initializers,
+ int64_t onnx_opset_version,
+ bool defer_weight_export,
+ ::torch::onnx::OperatorExportTypes operator_export_type) {
auto graph_encoder = GraphEncoder(
- graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export, false);
- return std::make_tuple(graph_encoder.get_model_proto().SerializeAsString(),
- graph_encoder.get_raw_data_export_map());
+ graph,
+ onnx_opset_version,
+ operator_export_type,
+ initializers,
+ defer_weight_export,
+ false);
+ return std::make_tuple(
+ graph_encoder.get_model_proto().SerializeAsString(),
+ graph_encoder.get_raw_data_export_map());
}
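// Illustrative call (not part of this diff), assuming a traced `graph` and
// its parameter tensors in `initializers`; opset 9 is just an example value:
//
//   std::string serialized;
//   RawDataExportMap raw_map;
//   std::tie(serialized, raw_map) = export_onnx(
//       graph,
//       initializers,
//       /*onnx_opset_version=*/9,
//       /*defer_weight_export=*/false,
//       ::torch::onnx::OperatorExportTypes::ONNX);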
void ExportModule(const script::Module& module, std::ostream& out) {
serializer.serialize(module);
}
-void ExportModule(const script::Module& module, const std::string &filename) {
+void ExportModule(const script::Module& module, const std::string& filename) {
ScriptModuleSerializer serializer(filename);
serializer.serialize(module);
}
-}}
+} // namespace jit
+} // namespace torch
#include <ostream>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// This map is used to keep track of parameters that should be exported
// externally. When `defer_weight_export` is true, the returned map contains
const std::vector<at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export = false,
- ::torch::onnx::OperatorExportTypes operator_export_type
- = ::torch::onnx::OperatorExportTypes::ONNX);
+ ::torch::onnx::OperatorExportTypes operator_export_type =
+ ::torch::onnx::OperatorExportTypes::ONNX);
// For testing purposes
TORCH_API std::string pretty_print_onnx(
const std::shared_ptr<Graph>& graph,
- const std::vector<at::Tensor> & initializers,
+ const std::vector<at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
- ::torch::onnx::OperatorExportTypes operator_export_type
- = ::torch::onnx::OperatorExportTypes::ONNX,
+ ::torch::onnx::OperatorExportTypes operator_export_type =
+ ::torch::onnx::OperatorExportTypes::ONNX,
bool google_printer = false);
-TORCH_API void ExportModule(
- const script::Module& module,
- std::ostream& out);
+TORCH_API void ExportModule(const script::Module& module, std::ostream& out);
TORCH_API void ExportModule(
const script::Module& module,
const std::string& filename);
-}}
+} // namespace jit
+} // namespace torch
#include <ATen/core/function_schema.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-using ::c10::FunctionSchema;
using ::c10::Argument;
+using ::c10::FunctionSchema;
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/fuser/tensor_desc.h>
#include <torch/csrc/utils/functional.h> // fmap
#include <torch/csrc/utils/hash.h>
-#include <torch/csrc/jit/fuser/tensor_desc.h>
-#include <vector>
#include <cstdint>
+#include <vector>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Describes the (runtime) arguments to a kernel.
// ArgSpecs are also used as keys to lookup instantiated kernels, so
// Note: the device to run on is included in the arg spec because kernels
// are compiled per-device.
struct TORCH_API ArgSpec {
- ArgSpec(
- at::TensorList inputs
- , const int _device)
- : descs_{fmap<TensorDesc>(inputs)}
- , hash_code_{torch::get_hash(_device, inputs.size(), descs_)}
- , device_{_device}
- { }
+ ArgSpec(at::TensorList inputs, const int _device)
+ : descs_{fmap<TensorDesc>(inputs)},
+ hash_code_{torch::get_hash(_device, inputs.size(), descs_)},
+ device_{_device} {}
// (Common) hash function
- static size_t hash(const ArgSpec& spec) { return spec.hash_code_; }
+ static size_t hash(const ArgSpec& spec) {
+ return spec.hash_code_;
+ }
// Comparators
bool operator==(const ArgSpec& other) const {
- return (
- descs_ == other.descs_
- && device_ == other.device_);
+ return (descs_ == other.descs_ && device_ == other.device_);
}
bool operator!=(const ArgSpec& spec) const {
}
// Getters
- size_t hashCode() const { return hash_code_; }
- const std::vector<TensorDesc>& descs() const { return descs_; }
- int device() const { return device_; }
+ size_t hashCode() const {
+ return hash_code_;
+ }
+ const std::vector<TensorDesc>& descs() const {
+ return descs_;
+ }
+ int device() const {
+ return device_;
+ }
-private:
+ private:
std::vector<TensorDesc> descs_;
size_t hash_code_;
int device_;
};
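// Illustrative cache keyed by ArgSpec (not part of this diff; KernelHandle
// is a hypothetical compiled-kernel type). std::unordered_map takes its
// hasher as a type, so the static hash function above is wrapped:
//
//   struct ArgSpecHasher {
//     size_t operator()(const ArgSpec& spec) const {
//       return ArgSpec::hash(spec);
//     }
//   };
//   std::unordered_map<ArgSpec, KernelHandle, ArgSpecHasher> kernel_cache;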
} // namespace fuser
-} // namespace jit
+} // namespace jit
} // namespace torch
#endif // USE_CUDA_FUSER || USE_CPU_FUSER
#include <torch/csrc/jit/fuser/codegen.h>
#include <ATen/ATen.h>
-#include <torch/csrc/jit/code_template.h>
-#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/fuser/compiler.h>
#include <torch/csrc/jit/fuser/config.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/tensor_info.h>
+#include <torch/csrc/jit/ir.h>
#if USE_CUDA_FUSER
- #include <torch/csrc/jit/fuser/cuda/resource_strings.h>
+#include <torch/csrc/jit/fuser/cuda/resource_strings.h>
#endif
#if USE_CPU_FUSER
- #include <torch/csrc/jit/fuser/cpu/resource_strings.h>
+#include <torch/csrc/jit/fuser/cpu/resource_strings.h>
#endif
-#include <tuple>
+#include <cmath>
+#include <cstdint>
#include <iostream>
#include <sstream>
-#include <cstdint>
+#include <tuple>
#include <vector>
-#include <cmath>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Template for computing the offset into the tensor to access a value
static auto dim_calc = CodeTemplate(R"(
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");
-
static std::string valueName(const Value* n) {
return "n" + std::to_string(n->unique());
}
return "half";
}
- switch(type) {
- #define DEFINE_CASE(ctype,name,_) \
- case at::ScalarType::name: return #ctype;
+ switch (type) {
+#define DEFINE_CASE(ctype, name, _) \
+ case at::ScalarType::name: \
+ return #ctype;
AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(DEFINE_CASE)
- #undef DEFINE_CASE
+#undef DEFINE_CASE
default:
throw std::runtime_error("unknown scalar type");
}
return scalarTypeName(type);
}
-
static std::string variableType(const std::shared_ptr<c10::Type>& t) {
if (t->kind() == TypeKind::IntType) {
return "int";
return calcScalarTypeName(tt->scalarType());
}
// something went wrong with the type analysis during shape propagation
- throw std::runtime_error("unknown scalar type during JIT fusion code generation");
+ throw std::runtime_error(
+ "unknown scalar type during JIT fusion code generation");
}
-static std::string typeCastedValueName(const std::shared_ptr<c10::Type>& t, const at::ScalarType outtype, const std::string& vn) {
+static std::string typeCastedValueName(
+ const std::shared_ptr<c10::Type>& t,
+ const at::ScalarType outtype,
+ const std::string& vn) {
if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) {
- if (! isIntegralType(outtype)) {
+ if (!isIntegralType(outtype)) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
return vn;
} else if (t->kind() == TypeKind::FloatType) {
- if (! isFloatingType(outtype)) {
+ if (!isFloatingType(outtype)) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
return vn;
return vn;
}
// something went wrong with the type analysis during shape propagation
- throw std::runtime_error("unknown scalar type during JIT fusion code generation");
+ throw std::runtime_error(
+ "unknown scalar type during JIT fusion code generation");
}
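// E.g. (illustrative): an IntType value named "n2" feeding a float-typed op
// comes back as "((float) n2)", while a FloatType value feeding the same op
// is returned unchanged as "n2".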
// Writes "simple mappable" ops
static std::string encodeRHS(const Node* n) {
static std::unordered_map<NodeKind, std::string> simple_map_ops = {
- // unary
- {aten::_cast_Float, "static_cast<float>(${0})"},
- {aten::abs, "fabs(${0})"},
- {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
- {aten::relu, "${0} < 0 ? 0.f : ${0} "},
- {aten::log, "logf(${0})"},
- {aten::log10, "log10f(${0})"},
- {aten::log1p, "log1pf(${0})"},
- {aten::log2, "log2f(${0})"},
- {aten::lgamma, "lgammaf(${0})"},
- {aten::exp, "expf(${0})"},
- {aten::expm1, "expm1f(${0})"},
- {aten::erf, "erff(${0})"},
- {aten::erfc, "erfcf(${0})"},
- {aten::cos, "cosf(${0})"},
- {aten::acos, "acosf(${0})"},
- {aten::cosh, "coshf(${0})"},
- {aten::sin, "sinf(${0})"},
- {aten::asin, "asinf(${0})"},
- {aten::sinh, "sinhf(${0})"},
- {aten::tan, "tanf(${0})"},
- {aten::atan, "atanf(${0})"},
- {aten::tanh, "tanhf(${0})"},
- {aten::sqrt, "sqrtf(${0})"},
- {aten::rsqrt, "rsqrtf(${0})"},
- {aten::ceil, "ceilf(${0})"},
- {aten::floor, "floorf(${0})"},
- {aten::round, "roundf(${0})"},
- {aten::trunc, "truncf(${0})"},
- {aten::frac, "fracf(${0})"},
- {aten::reciprocal, "1.f/(${0})"},
- {aten::neg, "-${0}"},
- //simple binary
- {aten::atan2, "atan2(${0}, ${1})"},
- {aten::min, "fminf(${0}, ${1})"},
- {aten::max, "fmaxf(${0}, ${1})"},
-
- //binary with other
- // TODO: some of these ops will not get generated because
- // we only work on float inputs/outputs, but they are here to record
- // that they are valid mappable ops once we handle more type
-
- {aten::__and__, "${0} && ${1}"},
- {aten::__lshift__, "${0} << ${1}"},
- {aten::__or__, "${0} || ${1}"},
- {aten::__rshift__, "${0} >> ${1}"},
- {aten::__xor__, "${0} ^ ${1}"},
- {aten::div, "${cast_0} / ${cast_1}"},
- {aten::eq, "${0} == ${1}"},
- {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
- {aten::ge, "(${0} >= ${1})"},
- {aten::gt, "${0} > ${1}"},
- {aten::le, "(${0} <= ${1})"},
- {aten::lt, "${0} < ${1}"},
- {aten::type_as, "(${cast_0})"},
- {aten::mul, "${cast_0} * ${cast_1}"},
- {aten::ne, "${0} != ${1}"},
- {aten::remainder, "remainderf(${0}, ${1})"},
- {aten::pow, "powf(${cast_0}, ${cast_1})"},
-
- //alpha
- {aten::add, "${cast_0} + ${cast_2}*${cast_1}"},
- {aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"},
- {aten::rand_like, "uniform(rnd())"},
-
- // min, max
- // It may seem unusual to have the bounds as the first case below,
- // this is so that if min or max is NaN, they are "ignored"
- // and when the input is NaN, the output is, too
- {aten::clamp, "(${0}<${1}?${1}:(${0}>${2}?${2}:${0}))"},
-
- //where
- {aten::where, "(${0} ? ${1} : ${2})"},
-
- // simple derivatives
- {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"},
- {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"},
+ // unary
+ {aten::_cast_Float, "static_cast<float>(${0})"},
+ {aten::abs, "fabs(${0})"},
+ {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
+ {aten::relu, "${0} < 0 ? 0.f : ${0} "},
+ {aten::log, "logf(${0})"},
+ {aten::log10, "log10f(${0})"},
+ {aten::log1p, "log1pf(${0})"},
+ {aten::log2, "log2f(${0})"},
+ {aten::lgamma, "lgammaf(${0})"},
+ {aten::exp, "expf(${0})"},
+ {aten::expm1, "expm1f(${0})"},
+ {aten::erf, "erff(${0})"},
+ {aten::erfc, "erfcf(${0})"},
+ {aten::cos, "cosf(${0})"},
+ {aten::acos, "acosf(${0})"},
+ {aten::cosh, "coshf(${0})"},
+ {aten::sin, "sinf(${0})"},
+ {aten::asin, "asinf(${0})"},
+ {aten::sinh, "sinhf(${0})"},
+ {aten::tan, "tanf(${0})"},
+ {aten::atan, "atanf(${0})"},
+ {aten::tanh, "tanhf(${0})"},
+ {aten::sqrt, "sqrtf(${0})"},
+ {aten::rsqrt, "rsqrtf(${0})"},
+ {aten::ceil, "ceilf(${0})"},
+ {aten::floor, "floorf(${0})"},
+ {aten::round, "roundf(${0})"},
+ {aten::trunc, "truncf(${0})"},
+ {aten::frac, "fracf(${0})"},
+ {aten::reciprocal, "1.f/(${0})"},
+ {aten::neg, "-${0}"},
+ // simple binary
+ {aten::atan2, "atan2(${0}, ${1})"},
+ {aten::min, "fminf(${0}, ${1})"},
+ {aten::max, "fmaxf(${0}, ${1})"},
+
+ // binary with other
+ // TODO: some of these ops will not get generated because
+ // we only work on float inputs/outputs, but they are here to record
+ // that they are valid mappable ops once we handle more types
+
+ {aten::__and__, "${0} && ${1}"},
+ {aten::__lshift__, "${0} << ${1}"},
+ {aten::__or__, "${0} || ${1}"},
+ {aten::__rshift__, "${0} >> ${1}"},
+ {aten::__xor__, "${0} ^ ${1}"},
+ {aten::div, "${cast_0} / ${cast_1}"},
+ {aten::eq, "${0} == ${1}"},
+ {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
+ {aten::ge, "(${0} >= ${1})"},
+ {aten::gt, "${0} > ${1}"},
+ {aten::le, "(${0} <= ${1})"},
+ {aten::lt, "${0} < ${1}"},
+ {aten::type_as, "(${cast_0})"},
+ {aten::mul, "${cast_0} * ${cast_1}"},
+ {aten::ne, "${0} != ${1}"},
+ {aten::remainder, "remainderf(${0}, ${1})"},
+ {aten::pow, "powf(${cast_0}, ${cast_1})"},
+
+ // alpha
+ {aten::add, "${cast_0} + ${cast_2}*${cast_1}"},
+ {aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"},
+ {aten::rand_like, "uniform(rnd())"},
+
+ // min, max
+ // It may seem unusual to have the bounds as the first case below,
+ // this is so that if min or max is NaN, they are "ignored"
+ // and when the input is NaN, the output is, too
+ {aten::clamp, "(${0}<${1}?${1}:(${0}>${2}?${2}:${0}))"},
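+ // Illustrative traces of the clamp expression above:
+ //   clamp(NaN, 0.f, 1.f) -> NaN  (both comparisons are false)
+ //   clamp(2.f, 0.f, NaN) -> 2.f  (the NaN bound is skipped)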
+
+ // where
+ {aten::where, "(${0} ? ${1} : ${2})"},
+
+ // simple derivatives
+ {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"},
+ {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"},
};
if (n->kind() == prim::Constant) {
TemplateEnv env;
size_t i = 0;
- auto outtype = n->output()->type()->expect<c10::TensorType const>()->scalarType();
- for(auto in : n->inputs()) {
- // PyTorch converts (scalar) argument types to result before applying the operator
- // e.g. 1.4-torch.tensor(3) = -2
+ auto outtype =
+ n->output()->type()->expect<c10::TensorType const>()->scalarType();
+ for (auto in : n->inputs()) {
+ // PyTorch converts (scalar) argument types to the result type before
+ // applying the operator, e.g. 1.4 - torch.tensor(3) = -2
env.s(std::to_string(i), valueName(in));
- env.s(std::string("cast_")+std::to_string(i), typeCastedValueName(in->type(), outtype, valueName(in)));
+ env.s(
+ std::string("cast_") + std::to_string(i),
+ typeCastedValueName(in->type(), outtype, valueName(in)));
i++;
}
- const auto & str = simple_map_ops.at(n->kind());
+ const auto& str = simple_map_ops.at(n->kind());
return format(str, env);
}
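+ // Illustrative expansion (value names assumed): for an aten::add node whose
+ // inputs print as n1, n2, n3, the template "${cast_0} + ${cast_2}*${cast_1}"
+ // becomes "n1 + n3*n2", with casts spliced in whenever an input's scalar
+ // type differs from the output's.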
static Node* usedInFusedChunk(const Value* input) {
auto uses = input->uses();
if (uses.size() == 1) {
- Node *user = uses[0].user;
+ Node* user = uses[0].user;
if (user->kind() == prim::ConstantChunk) {
return user;
}
}
static void emitIndexingFor(
- std::ostream& out
-, const std::string& tensor
-, const int ndim
-, const bool last_is_cont) {
+ std::ostream& out,
+ const std::string& tensor,
+ const int ndim,
+ const bool last_is_cont) {
TemplateEnv env;
- env.s("tensor",tensor);
- out << format("IndexType ${tensor}_offset = 0;\n",env);
- out << format("IndexType ${tensor}_linearIndex = linearIndex;\n",env);
+ env.s("tensor", tensor);
+ out << format("IndexType ${tensor}_offset = 0;\n", env);
+ out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
for (int d = ndim - 1; d >= 0; --d) {
- env.d("d",d);
- env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]",env) : "");
- env.s("times_stride",(d < ndim - 1 || !last_is_cont) ?
- format("* ${tensor}.strides[${d}]",env) : "");
+ env.d("d", d);
+ env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : "");
+ env.s(
+ "times_stride",
+ (d < ndim - 1 || !last_is_cont)
+ ? format("* ${tensor}.strides[${d}]", env)
+ : "");
out << dim_calc.format(env);
if (d > 0) {
- out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n",env);
+ out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
}
}
}
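+ // For intuition: for a 2-D tensor "t0" the loop above peels the innermost
+ // index off t0_linearIndex with "% t0.sizes[1]", scales it by t0.strides[1]
+ // (skipped when the last dim is contiguous), divides the linear index by
+ // t0.sizes[1], and finally adds the remaining index times t0.strides[0]
+ // into t0_offset.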
// TODO: handle cases where we need to generate > 2^32 element tensors
std::tuple<
- std::string
-, std::vector<PartitionDesc>
-, std::vector<PartitionDesc>
-, bool>
+ std::string,
+ std::vector<PartitionDesc>,
+ std::vector<PartitionDesc>,
+ bool>
generateKernel(
- const std::string& name
-, const Graph& graph
-, const std::vector<TensorDesc>& input_desc
-, const std::vector<TensorDesc>& output_desc
-, const bool use_cuda) {
+ const std::string& name,
+ const Graph& graph,
+ const std::vector<TensorDesc>& input_desc,
+ const std::vector<TensorDesc>& output_desc,
+ const bool use_cuda) {
TemplateEnv env;
env.s("kernelName", name);
- env.s("IndexType","unsigned int"); // Note: not uint32_t to avoid including cstdint
+ env.s(
+ "IndexType",
+ "unsigned int"); // Note: not uint32_t to avoid including cstdint
std::stringstream body;
std::stringstream tensorOffsets;
// Lambda for writing arguments
auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
- std::string tensor = "t" + std::to_string(formals.size()); //can't be unique() because Param may be an output
+ std::string tensor =
+ "t" +
+ std::to_string(
+ formals.size()); // can't be unique() because Param may be an output
const auto nDim = desc.nDim();
- emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
+ emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
env.s("tensor", tensor);
- env.d("formal_index", formals.size() + 1); // + 1 because the first argument is the linearIndex
+ env.d(
+ "formal_index",
+ formals.size() +
+ 1); // + 1 because the first argument is the linearIndex
env.d("nDim", nDim);
env.s("scalar_type", scalarTypeName(desc.scalar_type));
- formals.push_back(format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
- argument_loads.push_back(format("*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])", env));
+ formals.push_back(
+ format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
+ argument_loads.push_back(format(
+ "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
+ env));
};
// Writes input parameters and creates flattened inputs
std::vector<std::pair<const Value*, const TensorDesc&>> flat_inputs;
{
size_t input_index = 0;
- for(const auto& p : graph.inputs()) {
+ for (const auto& p : graph.inputs()) {
if (const Node* chunk = usedInFusedChunk(p)) {
int64_t dim = chunk->i(attr::dim);
int64_t chunks = chunk->i(attr::chunks);
} else {
const auto cat = o->node();
concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
- for(const auto& c : cat->inputs()) {
+ for (const auto& c : cat->inputs()) {
emitFormal(c, *concat_desc.back().subTensorDesc());
flat_output_nodes.emplace_back(c, desc);
}
if (is_half) {
JIT_ASSERT(use_cuda);
env.s(
- "access"
- , format("__half2float(t${formal}.data[t${formal}_offset])", env));
+ "access",
+ format("__half2float(t${formal}.data[t${formal}_offset])", env));
has_half_tensor = true;
} else {
env.s("access", format("t${formal}.data[t${formal}_offset]", env));
// Note: Concat and Chunk are implicitly generated
// Note: Random number generation is only supported for CUDA kernels.
for (const auto& n : graph.nodes()) {
- // Note: FusedConcat nodes work by narrowing the output Tensors before the kernel runs
- if (n->kind() == prim::FusedConcat) continue;
- if (n->kind() == prim::ConstantChunk) continue;
+ // Note: FusedConcat nodes work by narrowing the output Tensors before the
+ // kernel runs
+ if (n->kind() == prim::FusedConcat)
+ continue;
+ if (n->kind() == prim::ConstantChunk)
+ continue;
if (n->kind() == aten::rand_like) {
JIT_ASSERT(use_cuda);
has_random = true;
env.s("node", valueName(n->output()));
env.s("rhs", encodeRHS(n));
env.s("lhs_type", variableType(n->output()->type()));
- body << format("${lhs_type} ${node} = ${rhs};\n",env);
+ body << format("${lhs_type} ${node} = ${rhs};\n", env);
}
// Generates writes to output tensors
for (const auto& output : flat_output_nodes) {
const auto& o = output.first;
env.d("formal", formal_count++);
- env.s("access", format("t${formal}.data[t${formal}_offset]",env));
+ env.s("access", format("t${formal}.data[t${formal}_offset]", env));
env.s("node", valueName(o));
// Acquires and converts (if needed) outputs
const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
if (is_half) {
JIT_ASSERT(use_cuda);
- body << format("${access} = __float2half(${node});\n",env);
+ body << format("${access} = __float2half(${node});\n", env);
has_half_tensor = true;
} else {
- body << format("${access} = ${node};\n",env);
+ body << format("${access} = ${node};\n", env);
}
}
- // Includes headers
- // Note: CUDA kernels support halfs and random generation, CPU kernels do not
- #if USE_CUDA_FUSER
- if (has_half_tensor) {
- env.s("HalfHeader", cuda::half_support_literal);
- } else {
- env.s("HalfHeader", "");
- }
+// Includes headers
+// Note: CUDA kernels support halfs and random generation, CPU kernels do not
+#if USE_CUDA_FUSER
+ if (has_half_tensor) {
+ env.s("HalfHeader", cuda::half_support_literal);
+ } else {
+ env.s("HalfHeader", "");
+ }
- if (has_random) {
- env.s("RandHeader", cuda::rand_support_literal);
- env.s("RandParam", cuda::rand_param);
- env.s("RandInit", cuda::rand_init);
- } else {
- env.s("RandHeader", "");
- env.s("RandParam", "");
- env.s("RandInit", "");
- }
- #endif // USE_CUDA_FUSER
+ if (has_random) {
+ env.s("RandHeader", cuda::rand_support_literal);
+ env.s("RandParam", cuda::rand_param);
+ env.s("RandInit", cuda::rand_init);
+ } else {
+ env.s("RandHeader", "");
+ env.s("RandParam", "");
+ env.s("RandInit", "");
+ }
+#endif // USE_CUDA_FUSER
// Instantiates the CUDA or CPU-specific templates
env.s("tensorOffsets", tensorOffsets.str());
env.v("argument_loads", argument_loads);
std::string code_string;
if (use_cuda) {
- #if USE_CUDA_FUSER
- env.s("type_declarations", cuda::type_declarations_template.format(env));
- code_string = cuda::cuda_compilation_unit_template.format(env);
- #else
- throw std::runtime_error("CUDA Fusion requested but not supported.");
- #endif // USE_CUDA_FUSER
+#if USE_CUDA_FUSER
+ env.s("type_declarations", cuda::type_declarations_template.format(env));
+ code_string = cuda::cuda_compilation_unit_template.format(env);
+#else
+ throw std::runtime_error("CUDA Fusion requested but not supported.");
+#endif // USE_CUDA_FUSER
} else {
- #if USE_CPU_FUSER
- env.s("type_declarations", cpu::type_declarations_template.format(env));
- code_string = cpu::cpu_compilation_unit_template.format(env);
- #else
- throw std::runtime_error("CPU Fusion requested but not supported");
- #endif // USE_CPU_FUSER
+#if USE_CPU_FUSER
+ env.s("type_declarations", cpu::type_declarations_template.format(env));
+ code_string = cpu::cpu_compilation_unit_template.format(env);
+#else
+ throw std::runtime_error("CPU Fusion requested but not supported");
+#endif // USE_CPU_FUSER
}
if (debugFuser()) {
std::cerr << "fusion code:" << code_string << std::endl;
}
- return std::make_tuple(code_string, std::move(chunk_desc), std::move(concat_desc), has_random);
+ return std::make_tuple(
+ code_string, std::move(chunk_desc), std::move(concat_desc), has_random);
}
} // namespace fuser
#if USE_CUDA_FUSER || USE_CPU_FUSER
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/fuser/arg_spec.h>
#include <torch/csrc/jit/fuser/partition_desc.h>
#include <torch/csrc/jit/fuser/tensor_desc.h>
+#include <torch/csrc/jit/ir.h>
-#include <tuple>
-#include <vector>
#include <iostream>
#include <string>
+#include <tuple>
+#include <vector>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Creates a CPU or CUDA kernel for the given graph.
// Returns a tuple consisting of the generated code (as a string),
-// two vectors of PartitionDescs, the chunk and concat descriptions,
-// respectively, and a bool indicating whether the generated code
+// two vectors of PartitionDescs, the chunk and concat descriptions,
+// respectively, and a bool indicating whether the generated code
// generates random numbers.
// TODO: the partition descriptions should be generated by the executor.
TORCH_API std::tuple<
- std::string
-, std::vector<PartitionDesc>
-, std::vector<PartitionDesc>
-, bool>
+ std::string,
+ std::vector<PartitionDesc>,
+ std::vector<PartitionDesc>,
+ bool>
generateKernel(
- const std::string& name
-, const Graph& graph
-, const std::vector<TensorDesc>& input_desc
-, const std::vector<TensorDesc>& output_desc
-, const bool use_cuda);
+ const std::string& name,
+ const Graph& graph,
+ const std::vector<TensorDesc>& input_desc,
+ const std::vector<TensorDesc>& output_desc,
+ const bool use_cuda);
} // namespace fuser
} // namespace jit
#include <torch/csrc/jit/fuser/compiler.h>
#include <ATen/ATen.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/passes/canonicalize.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/code_template.h>
+#include <torch/csrc/jit/fuser/codegen.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
-#include <torch/csrc/jit/fuser/codegen.h>
#include <torch/csrc/jit/fuser/tensor_desc.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/passes/canonicalize.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/type.h>
#include "torch/csrc/jit/fuser/interface.h"
#if USE_CUDA_FUSER
- #include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
+#include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
#endif // USE_CUDA_FUSER
#if USE_CPU_FUSER
- #include <torch/csrc/jit/fuser/cpu/fused_kernel.h>
+#include <torch/csrc/jit/fuser/cpu/fused_kernel.h>
#endif // USE_CUDA_FUSER
+#include <atomic>
#include <iostream>
#include <memory>
-#include <unordered_set>
-#include <utility>
-#include <string>
-#include <atomic>
#include <sstream>
#include <stdexcept>
+#include <string>
#include <tuple>
+#include <unordered_set>
+#include <utility>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Counter for number of kernels compiled, used for debugging and
// creating arbitrary kernel names.
static std::atomic<size_t> next_kernel_id{0};
static int debug_fusion{-1};
-size_t nCompiledKernels() { return next_kernel_id.load(); }
+size_t nCompiledKernels() {
+ return next_kernel_id.load();
+}
int debugFuser() {
if (debug_fusion < 0) {
static const Node* usedInFusedChunk(const Value* input) {
const auto& uses = input->uses();
if (uses.size() == 1) {
- const Node *user = uses[0].user;
+ const Node* user = uses[0].user;
if (user->kind() == prim::ConstantChunk) {
return user;
}
spec.inputChunks().reserve((spec.graph())->inputs().size());
for (const Value* input : (spec.graph())->inputs()) {
if (const Node* chunk = usedInFusedChunk(input)) {
- spec.inputChunks().emplace_back(chunk->i(attr::chunks), chunk->i(attr::dim));
+ spec.inputChunks().emplace_back(
+ chunk->i(attr::chunks), chunk->i(attr::dim));
} else {
spec.inputChunks().emplace_back(1, 0);
}
std::unordered_set<const Value*> inputs;
std::unordered_set<const Value*> seen;
while (!queue.empty()) {
- const Value* val = queue.back(); queue.pop_back();
+ const Value* val = queue.back();
+ queue.pop_back();
const Node* producer = val->node();
if (producer->kind() == prim::Param) {
inputs.insert(val);
continue;
}
for (const Value* input : producer->inputs()) {
- if (/*bool inserted = */seen.insert(input).second) {
+ if (/*bool inserted = */ seen.insert(input).second) {
queue.push_back(input);
}
}
}
static void setInputBroadcastGroups(KernelSpec& spec) {
- std::unordered_set<std::vector<int64_t>, torch::hash<std::vector<int64_t>>> broadcast_groups;
+ std::unordered_set<std::vector<int64_t>, torch::hash<std::vector<int64_t>>>
+ broadcast_groups;
for (const Value* output : (spec.graph())->outputs()) {
if (output->node()->kind() == prim::FusedConcat) {
for (const Value* concat_input : output->node()->inputs()) {
}
}
std::copy(
- broadcast_groups.begin()
- , broadcast_groups.end()
- , std::back_inserter(spec.inputBroadcastGroups()));
+ broadcast_groups.begin(),
+ broadcast_groups.end(),
+ std::back_inserter(spec.inputBroadcastGroups()));
}
// Performs "upfront" compilation where storage is known but shapes are not.
}
std::shared_ptr<FusedKernel> compileKernel(
- const KernelSpec& spec
-, const ArgSpec& arg_spec
-, const std::vector<int64_t>& map_size
-, const at::Device device) {
+ const KernelSpec& spec,
+ const ArgSpec& arg_spec,
+ const std::vector<int64_t>& map_size,
+ const at::Device device) {
const std::vector<TensorDesc>& input_desc = arg_spec.descs();
auto graph = spec.graph()->copy();
c10::optional<at::ScalarType> scalar_type;
for (size_t i = 0; i < input_desc.size(); i++) {
const auto& desc = input_desc[i];
- graph->inputs()[i]->setType(TensorType::create(desc.scalar_type, device, desc.nDim())); // TODO: nDim is bad, as it is collapsed
+ graph->inputs()[i]->setType(TensorType::create(
+ desc.scalar_type,
+ device,
+ desc.nDim())); // TODO: nDim is bad, as it is collapsed
}
PropagateInputShapes(graph);
if (output->node()->kind() == prim::FusedConcat) {
sizes.at(output->node()->i(attr::dim)) *= output->node()->inputs().size();
}
- auto scalar_type = output->type()->expect<c10::TensorType const>()->scalarType();
+ auto scalar_type =
+ output->type()->expect<c10::TensorType const>()->scalarType();
auto type = CompleteTensorType::create(scalar_type, device, sizes);
output_desc.emplace_back(std::move(type));
}
std::vector<PartitionDesc> chunk_desc;
std::vector<PartitionDesc> concat_desc;
bool has_random;
- std::tie(code, chunk_desc, concat_desc, has_random)
- = generateKernel(
- name
- , *graph
- , input_desc
- , output_desc
- , use_cuda);
+ std::tie(code, chunk_desc, concat_desc, has_random) =
+ generateKernel(name, *graph, input_desc, output_desc, use_cuda);
std::shared_ptr<FusedKernel> fused_kernel;
if (use_cuda) {
- #if USE_CUDA_FUSER
- fused_kernel = std::make_shared<cuda::FusedKernelCUDA>(
- device.index()
- , name
- , code
- , input_desc
- , output_desc
- , chunk_desc
- , concat_desc
- , has_random);
- #else
- throw std::runtime_error("CUDA Fusion is not supported on this build.");
- #endif // USE_CUDA_FUSER
+#if USE_CUDA_FUSER
+ fused_kernel = std::make_shared<cuda::FusedKernelCUDA>(
+ device.index(),
+ name,
+ code,
+ input_desc,
+ output_desc,
+ chunk_desc,
+ concat_desc,
+ has_random);
+#else
+ throw std::runtime_error("CUDA Fusion is not supported on this build.");
+#endif // USE_CUDA_FUSER
} else {
- #if USE_CPU_FUSER
- fused_kernel = std::make_shared<cpu::FusedKernelCPU>(
- name
- , code
- , input_desc
- , output_desc
- , chunk_desc
- , concat_desc
- , has_random);
- #else
- throw std::runtime_error("CPU Fusion is not supported on this build.");
- #endif // USE_CPU_FUSER
+#if USE_CPU_FUSER
+ fused_kernel = std::make_shared<cpu::FusedKernelCPU>(
+ name,
+ code,
+ input_desc,
+ output_desc,
+ chunk_desc,
+ concat_desc,
+ has_random);
+#else
+ throw std::runtime_error("CPU Fusion is not supported on this build.");
+#endif // USE_CPU_FUSER
}
return fused_kernel;
#if USE_CUDA_FUSER || USE_CPU_FUSER
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/jit/fuser/arg_spec.h>
#include <torch/csrc/jit/fuser/config.h>
+#include <torch/csrc/jit/fuser/fused_kernel.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_spec.h>
-#include <torch/csrc/jit/fuser/arg_spec.h>
-#include <torch/csrc/jit/fuser/fused_kernel.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/stack.h>
#include <cstdint>
#include <vector>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Performs device-independent "upfront" compilation of the given fusion_group,
// if it has not been registered already.
// with the runtime arguments specified in ArgSpec.
// Outputs are allocated using map_size on the specified device.
TORCH_API std::shared_ptr<FusedKernel> compileKernel(
- const KernelSpec& spec
-, const ArgSpec& arg_spec
-, const std::vector<int64_t>& map_size
-, const at::Device device);
+ const KernelSpec& spec,
+ const ArgSpec& arg_spec,
+ const std::vector<int64_t>& map_size,
+ const at::Device device);
TORCH_API size_t nCompiledKernels();
#pragma once
+// clang-format off
#define USE_CPU_FUSER @USE_CPU_FUSER@
#define USE_CUDA_FUSER @USE_CUDA_FUSER@
+// clang-format on
#if USE_CPU_FUSER
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/utils/disallow_copy.h>
#include <dlfcn.h>
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
static void* checkDL(void* x) {
if (!x) {
}
~DynamicLibrary() {
- if (!handle) return;
+ if (!handle)
+ return;
dlclose(handle);
}
-private:
+ private:
void* handle = nullptr;
};
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/fuser/compiler.h>
-#include <torch/csrc/jit/fuser/cpu/temp_file.h>
#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
+#include <torch/csrc/jit/fuser/cpu/temp_file.h>
#include <torch/csrc/utils/memory.h>
-#include <sstream>
#include <cstdlib>
#include <iostream>
-#include <string>
+#include <sstream>
#include <stdexcept>
+#include <string>
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so";
static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp";
// optimization can be re-enabled by tracking down the platforms where
// this error occurs and only selectively disabling it.
static const std::string compile_string =
- "\"${cxx}\" -O3 -g "
+ "\"${cxx}\" -O3 -g "
#ifndef __PPC64__
// "-march=native "
#endif
- "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm";
+ "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm";
static void runCompiler(
- const std::string& cpp_file
-, const std::string& so_file) {
+ const std::string& cpp_file,
+ const std::string& so_file) {
auto& config = getConfig();
TemplateEnv env;
env.s("cxx", config.cxx);
std::string result = format(compile_string, env);
int r = system(result.c_str());
if (config.openmp && r != 0) {
- std::cerr << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n";
+ std::cerr
+ << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n";
config.openmp = false; // disable for future compiles
return runCompiler(cpp_file, so_file);
}
JIT_ASSERTM(r == 0, "Failed to compile a fused CPU kernel");
}
-static const std::string disas_string =
- "objdump -M intel -d \"${so_file}\"";
+static const std::string disas_string = "objdump -M intel -d \"${so_file}\"";
static void disas(const std::string& so_file) {
TemplateEnv env;
env.s("so_file", so_file);
cpp_file.write(code_);
cpp_file.sync();
runCompiler(cpp_file.name(), so_file.name());
- if (debugFuser() >= 2) disas(so_file.name());
+ if (debugFuser() >= 2)
+ disas(so_file.name());
so_lib = make_unique<DynamicLibrary>(so_file.name().c_str());
- #pragma GCC diagnostic ignored "-Wpedantic"
- kernel = reinterpret_cast<void(*)(uint32_t, void**)>(so_lib->sym(name_.c_str()));
- #pragma GCC diagnostic pop
+#pragma GCC diagnostic ignored "-Wpedantic"
+ kernel =
+ reinterpret_cast<void (*)(uint32_t, void**)>(so_lib->sym(name_.c_str()));
+#pragma GCC diagnostic pop
}
} // namespace cpu
#include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/utils/disallow_copy.h>
#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
#include <torch/csrc/jit/fuser/fused_kernel.h>
+#include <torch/csrc/utils/disallow_copy.h>
-#include <string>
#include <cstdint>
#include <memory>
+#include <string>
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
// Represents a compiled CPU kernel and the metadata necessary to run it
struct TORCH_API FusedKernelCPU : public ::torch::jit::fuser::FusedKernel {
kernel(numel, arguments.data());
}
-private:
+ private:
std::unique_ptr<DynamicLibrary> so_lib;
void (*kernel)(uint32_t, void**) = nullptr;
};
#include <torch/csrc/jit/code_template.h>
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
-/*with type_as not checking type of its input, a fusion group can have non-fp32 tensor as input.
-Correct code for this case is generated, however, nvrtc does not know how to handle int*_t integer types,
-so typedefs help it handle those cases*/
+/* With type_as not checking the type of its input, a fusion group can have a
+non-fp32 tensor as input. Correct code for this case is generated; however,
+nvrtc does not know how to handle int*_t integer types, so typedefs help it
+handle those cases. */
static auto type_declarations_template = CodeTemplate(R"(
)");
static auto cpu_compilation_unit_template = CodeTemplate(R"(
+#include <math.h>
#include <cstddef>
#include <cstdint>
-#include <math.h>
template <typename scalar_t>
scalar_t rsqrtf(scalar_t x) {
} // namespace cpu
} // namespace fuser
-} // namespace jit
+} // namespace jit
} // namespace torch
#endif // USE_CPU_FUSER
#include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/utils/disallow_copy.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/utils/disallow_copy.h>
#include <unistd.h>
#include <string>
#include <vector>
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
struct TempFile {
TH_DISALLOW_COPY_AND_ASSIGN(TempFile);
fflush(file_);
}
- void write(const std::string & str) {
+ void write(const std::string& str) {
size_t result = fwrite(str.c_str(), 1, str.size(), file_);
JIT_ASSERT(str.size() == result);
}
- FILE* file() {
+ FILE* file() {
return file_;
}
fclose(file_);
}
}
-private:
+
+ private:
FILE* file_ = nullptr;
std::string name_;
};
} // namespace cpu
} // namespace fuser
-} // namespace jit
+} // namespace jit
} // namespace torch
#endif // USE_CPU_FUSER
#include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
#include <THC/THC.h>
+#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/cuda/cuda_check.h>
#include <torch/csrc/jit/resource_guard.h>
// Note: unclear why this forward declaration is necessary
-#include <THC/THCTensorRandom.h>
#include <THC/THCGenerator.hpp>
+#include <THC/THCTensorRandom.h>
THCGenerator* THCRandom_getGenerator(THCState* state);
-#include <nvrtc.h>
#include <cuda.h>
#include <cuda_runtime.h>
+#include <nvrtc.h>
-#include <stdexcept>
+#include <algorithm>
+#include <cmath>
#include <sstream>
+#include <stdexcept>
#include <tuple>
#include <vector>
-#include <algorithm>
-#include <cmath>
-namespace torch { namespace jit { namespace fuser { namespace cuda {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
-void checkCUDAVersion(
- const cudaDeviceProp& prop) {
+void checkCUDAVersion(const cudaDeviceProp& prop) {
if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
(prop.major >= 7 && CUDA_VERSION < 9000)) {
std::stringstream err_string;
- err_string << "In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: "
- << CUDA_VERSION << " for the current GPU device " << prop.name
- << " with device capability " << prop.major << "." << prop.minor;
+ err_string
+ << "In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: "
+ << CUDA_VERSION << " for the current GPU device " << prop.name
+ << " with device capability " << prop.major << "." << prop.minor;
throw std::runtime_error(err_string.str());
}
}
-static void getMajorMinor(const cudaDeviceProp* const prop, int& major, int& minor) {
+static void getMajorMinor(
+ const cudaDeviceProp* const prop,
+ int& major,
+ int& minor) {
int nvrtc_major, nvrtc_minor;
TORCH_NVRTC_CHECK(nvrtcVersion(&nvrtc_major, &nvrtc_minor));
minor = 0;
} else if (nvrtc_major <= 9 && prop->major >= 7) { // 9 supports 3-7.2
major = 7;
- if (prop->major == 7 && prop->minor <= 2) minor = prop->minor;
- else minor = 0;
+ if (prop->major == 7 && prop->minor <= 2)
+ minor = prop->minor;
+ else
+ minor = 0;
} else if (nvrtc_major <= 10 && prop->major >= 7) { // 10 supports 3-7.5
major = 7;
- if (prop->major == 7 && prop->minor <= 5) minor = prop->minor;
- else minor = 0;
+ if (prop->major == 7 && prop->minor <= 5)
+ minor = prop->minor;
+ else
+ minor = 0;
}
}
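+ // Example (informative): under NVRTC 9, a capability 7.5 device yields
+ // (major, minor) = (7, 0), while a 7.2 device keeps (7, 2).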
CUcontext pctx = 0;
TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
- std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(THCCachingAllocator_getCudaFreeMutex()));
- cudaFree(0);
+ std::unique_lock<std::mutex> cudaFreeMutexLock(
+ *(THCCachingAllocator_getCudaFreeMutex()));
+ cudaFree(0);
}
// Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
const auto prior_device = at::cuda::current_device();
at::cuda::set_device(device_);
- // Acquires device and NVRTC properties (for compile arch and occupancy calculations)
+ // Acquires device and NVRTC properties (for compile arch and occupancy
+ // calculations)
prop_ = at::cuda::getCurrentDeviceProperties();
int major, minor;
getMajorMinor(prop_, major, minor);
// Creates the NVRTC program
nvrtcProgram program;
TORCH_NVRTC_CHECK(nvrtcCreateProgram(
- &program
- , code_.c_str()
- , nullptr
- , 0
- , nullptr
- , nullptr));
-
- const std::string compute = "--gpu-architecture=compute_" + std::to_string(major) + std::to_string(minor);
- const std::vector<const char *> args = {"--std=c++11", compute.c_str(), "-default-device"};
+ &program, code_.c_str(), nullptr, 0, nullptr, nullptr));
+
+ const std::string compute = "--gpu-architecture=compute_" +
+ std::to_string(major) + std::to_string(minor);
+ const std::vector<const char*> args = {
+ "--std=c++11", compute.c_str(), "-default-device"};
const auto result = nvrtcCompileProgram(program, args.size(), args.data());
if (result == NVRTC_ERROR_COMPILATION) {
size_t logsize;
cu << log.data();
throw std::runtime_error(cu.str());
}
- ResourceGuard holdProgram([&] {
- TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program));
- });
+ ResourceGuard holdProgram(
+ [&] { TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program)); });
TORCH_NVRTC_CHECK(result);
size_t ptx_size;
TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
// Computes max blocks
TORCH_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor(
- &maxBlocks_, function_, 128, 0));
+ &maxBlocks_, function_, 128, 0));
maxBlocks_ *= prop_->multiProcessorCount;
// Resets device (end of hacked at::DeviceGuard)
}
void FusedKernelCUDA::launch_raw(
- const uint32_t numel
-, std::vector<void*>& arguments) const {
+ const uint32_t numel,
+ std::vector<void*>& arguments) const {
at::cuda::CUDAGuard{device_};
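+ // NB: the temporary guard above is destroyed at the end of its statement;
+ // the manual save/restore below is what actually pins the device.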
// Hacked at::DeviceGuard (see note above)
const auto prior_device = at::cuda::current_device();
// Note: offset defined here so its lifetime extends to the launch
uint64_t offset;
if (has_random_) {
- const auto rand_offset = 4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1);
+ const auto rand_offset =
+ 4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1);
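+ // Informal reading: each Philox round yields 4 values per thread, so the
+ // offset advances by 4 * (rounds-per-thread + 1 round of slack).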
auto gen = THCRandom_getGenerator(at::globalContext().getTHCState());
offset = gen->state.philox_seed_offset.fetch_add(rand_offset);
arguments.push_back(&gen->state.initial_seed);
// Launches kernel on current stream (device was set by executor)
auto stream = at::cuda::getCurrentCUDAStream();
TORCH_CU_CHECK(cuLaunchKernel(
- function_,
- nBlocks, 1, 1,
- kBlockSize, 1, 1,
- 0, stream,
- arguments.data(),
- nullptr));
+ function_,
+ nBlocks,
+ 1,
+ 1,
+ kBlockSize,
+ 1,
+ 1,
+ 0,
+ stream,
+ arguments.data(),
+ nullptr));
// Resets device (see at::DeviceGuard notes above)
at::cuda::set_device(prior_device);
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/fuser/fused_kernel.h>
-#include <nvrtc.h>
#include <cuda.h>
#include <cuda_runtime.h>
+#include <nvrtc.h>
#include <cstdint>
-#include <vector>
#include <string>
+#include <vector>
-namespace torch { namespace jit { namespace fuser { namespace cuda {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
// A class holding metadata for an actual CUDA function.
// Note: CUDA functions are per device.
return at::Backend::CUDA;
}
-private:
+ private:
static constexpr auto kBlockSize = 128;
// Note: per device to store device properties and compute launch heuristics
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/code_template.h>
-namespace torch { namespace jit { namespace fuser { namespace cuda {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
-/*with type_as not checking type of its input, a fusion group can have non-fp32 tensor as input.
-Correct code for this case is generated, however, nvrtc does not know how to handle int*_t integer types,
-so typedefs help it handle those cases*/
+/* With type_as not checking the type of its input, a fusion group can have a
+non-fp32 tensor as input. Correct code for this case is generated; however,
+nvrtc does not know how to handle int*_t integer types, so typedefs help it
+handle those cases. */
static auto type_declarations_template = CodeTemplate(R"(
typedef unsigned char uint8_t;
}
)";
-constexpr auto rand_param = ",unsigned long long seed, unsigned long long offset";
+constexpr auto rand_param =
+ ",unsigned long long seed, unsigned long long offset";
constexpr auto rand_init = R"(
int idx = blockIdx.x*blockDim.x + threadIdx.x;
}
)");
-
// This snippet enables half support in the jit. Following the pattern for
// reductions, fp16 input data is immediately upconverted to float
// with __half2float(). All mathematical operations are done on float
// values, and if needed the intermediate float representation is
// converted to half with __float2half() when writing to a half tensor.
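+ // A sketch of the resulting pattern in the generated code (names
+ // illustrative):
+ //   float v = __half2float(t0.data[t0_offset]);  // upconvert on load
+ //   float r = tanhf(v);                          // math in float
+ //   t1.data[t1_offset] = __float2half(r);        // downconvert on store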
-constexpr auto half_support_literal = R"(
+constexpr auto half_support_literal = R"(
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
#if defined(__cplusplus)
} // namespace cuda
} // namespace fuser
-} // namespace jit
+} // namespace jit
} // namespace torch
#endif // USE_CUDA_FUSER
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <c10/util/Optional.h>
-#include <torch/csrc/utils/functional.h>
-#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/jit/fuser/compiler.h>
#include <torch/csrc/jit/fuser/config.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/fuser/kernel_spec.h>
-#include <torch/csrc/jit/fuser/compiler.h>
#include <torch/csrc/jit/fuser/tensor_info.h>
+#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/utils/functional.h>
-#include <vector>
-#include <tuple>
-#include <stdexcept>
#include <algorithm>
-#include <map>
#include <iostream> // TODO: remove, debugging only
+#include <map>
+#include <stdexcept>
+#include <tuple>
+#include <vector>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Returns the "map size" for this run, which is the common size for all
// intermediate tensors.
static c10::optional<std::vector<int64_t>> getMapSize(
- const KernelSpec& spec
-, at::TensorList args
-, at::IntList arg_subset) {
-
+ const KernelSpec& spec,
+ at::TensorList args,
+ at::IntList arg_subset) {
// TODO: this keeps reallocating map_size at every iteration, but we know
// exactly how much storage we need, so this could be fixed in-place at
// every step. We're just missing a few functions for ATen, but the fix
} else {
auto tensor_sizes = arg.sizes().vec();
const auto num_chunks = chunk_desc.nSubTensors();
- const auto dim = at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size());
+ const auto dim =
+ at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size());
if (tensor_sizes[dim] % num_chunks != 0) {
return c10::nullopt;
}
// Tries to determine a map size for the instantiated kernel (see above)
static c10::optional<std::vector<int64_t>> canRunKernel(
- const KernelSpec& spec
-, at::TensorList args) {
+ const KernelSpec& spec,
+ at::TensorList args) {
// Short-circuits on size mismatch
AT_CHECK(
- args.size() == spec.inputChunks().size()
- , "Expected ", spec.inputChunks().size(), " arguments, but got ", args.size());
+ args.size() == spec.inputChunks().size(),
+ "Expected ",
+ spec.inputChunks().size(),
+ " arguments, but got ",
+ args.size());
c10::optional<std::vector<int64_t>> map_size;
for (const auto& broadcast_group : spec.inputBroadcastGroups()) {
if (!map_size) {
map_size = getMapSize(spec, args, broadcast_group);
- if (!map_size) return c10::nullopt;
+ if (!map_size)
+ return c10::nullopt;
} else {
const auto group_map_size = getMapSize(spec, args, broadcast_group);
// Note: this checks that group_map_size is defined AND equal to map_size
- if (map_size != group_map_size) return c10::nullopt;
+ if (map_size != group_map_size)
+ return c10::nullopt;
}
}
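+ // Informal example: with broadcast groups {{0, 1}, {2}}, inputs 0 and 1
+ // must broadcast to one common map size, and input 2 must independently
+ // produce that same size; any disagreement returns c10::nullopt.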
// Note: Arguments are mutated by this call, although map_size is restored
// to its original value.
static void expandArgs(
- const KernelSpec& spec
-, std::vector<at::Tensor>& args
-, std::vector<int64_t>& map_size) {
+ const KernelSpec& spec,
+ std::vector<at::Tensor>& args,
+ std::vector<int64_t>& map_size) {
for (size_t i = 0; i < args.size(); ++i) {
auto& arg = args[i];
const auto& pdesc = spec.inputChunks()[i];
if (pdesc.nSubTensors() == 1) {
- if (arg.sizes().equals(map_size)) continue;
+ if (arg.sizes().equals(map_size))
+ continue;
arg = arg.expand(map_size);
} else {
map_size.at(pdesc.dim()) *= pdesc.nSubTensors();
// Note: Assumes that after at::chunk, all inputs are the same size
static std::vector<int64_t> computeMapSize(
- const at::Tensor& tensor
-, const PartitionDesc& chunkDesc) {
+ const at::Tensor& tensor,
+ const PartitionDesc& chunkDesc) {
std::vector<int64_t> sizes(tensor.sizes().begin(), tensor.sizes().end());
JIT_ASSERT(sizes[chunkDesc.dim()] % chunkDesc.nSubTensors() == 0);
sizes[chunkDesc.dim()] /= chunkDesc.nSubTensors();
// Tries to compress sizes and strides according to cont. Emits the result to
// c_sizes, c_strides and throws an error on failure (if can't compress)
static void compressContiguous(
- const at::IntList& sizes
-, const at::IntList& strides
-, const std::vector<bool>& cont
-, uint32_t* c_sizes
-, uint32_t* c_strides) {
+ const at::IntList& sizes,
+ const at::IntList& strides,
+ const std::vector<bool>& cont,
+ uint32_t* c_sizes,
+ uint32_t* c_strides) {
size_t compressed_dims = 0;
size_t cur = 0;
size_t ndim = sizes.size();
while (cur < ndim) {
size_t total_size = sizes[cur];
cur++;
- while (cont[cur-1] && cur < ndim) {
- JIT_ASSERT(strides[cur-1] == sizes[cur]*strides[cur]);
+ while (cont[cur - 1] && cur < ndim) {
+ JIT_ASSERT(strides[cur - 1] == sizes[cur] * strides[cur]);
total_size *= sizes[cur];
cur++;
}
c_sizes[compressed_dims] = total_size;
- c_strides[compressed_dims] = strides[cur-1];
+ c_strides[compressed_dims] = strides[cur - 1];
compressed_dims++;
}
- if (ndim > 0) JIT_ASSERT(!cont.back() || strides.back() == 1);
+ if (ndim > 0)
+ JIT_ASSERT(!cont.back() || strides.back() == 1);
}
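+ // Worked example (illustrative): sizes [2, 3, 4] and strides [12, 4, 1]
+ // with cont = {true, true, true} compress to a single dim with
+ // c_sizes[0] = 24 and c_strides[0] = 1, since each stride equals
+ // sizes[d + 1] * strides[d + 1].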
// Launches the requested fusion on the given device with the given inputs.
// Output pointers are stored in outputs (to be put on the stack later).
void launchFusion(
- const FusedKernel& fusion
-, const at::Device device
-, const at::ArrayRef<at::Tensor>& inputs
-, std::vector<at::Tensor>& outputs) {
+ const FusedKernel& fusion,
+ const at::Device device,
+ const at::ArrayRef<at::Tensor>& inputs,
+ std::vector<at::Tensor>& outputs) {
// Fails if fusion and given inputs disagree
JIT_ASSERT(inputs.size() == fusion.inputDesc().size());
numel = computeNumel(map_size);
}
- // Computes the storage needed to store TensorInfo structs for inputs and outputs.
+ // Computes the storage needed to store TensorInfo structs for inputs and
+ // outputs.
size_t uncompressedDim = fusion.inputDesc().at(0).contiguity.size();
- size_t maxPossibleTensorInfoSize = sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
- size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size);
+ size_t maxPossibleTensorInfoSize =
+ sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
+ size_t maxPossibleBufferSize =
+ maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size);
std::vector<char> buffer(maxPossibleBufferSize);
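+ // Illustrative arithmetic: with uncompressedDim == 4 each slot reserves
+ // sizeof(TensorInfo) + 2 * sizeof(uint32_t) * 4 bytes (sizes plus strides),
+ // and one slot is reserved per flattened input and output.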
char* buffer_next = buffer.data();
arguments.reserve(3 + flat_inputs_size + flat_outputs_size);
arguments.push_back(&numel);
- auto addTensorInfoRaw = [&](
- const TensorDesc& desc
- , void* data_ptr
- , at::IntList sizes
- , at::IntList strides) {
+ auto addTensorInfoRaw = [&](const TensorDesc& desc,
+ void* data_ptr,
+ at::IntList sizes,
+ at::IntList strides) {
const auto nDim = desc.nDim(); // NOTE: this is the compressed dim
JIT_ASSERT(nDim <= uncompressedDim); // We'd overflow the space otherwise
auto ti = reinterpret_cast<TensorInfo*>(buffer_next);
ti->data = data_ptr;
compressContiguous(
- sizes
- , strides
- , desc.contiguity
- , ti->sizes(nDim)
- , ti->strides(nDim));
+ sizes, strides, desc.contiguity, ti->sizes(nDim), ti->strides(nDim));
buffer_next += maxPossibleTensorInfoSize;
arguments.push_back(ti);
};
if (chunk.isNoop()) {
addTensorInfo(fusion.inputDesc()[i], tensor);
} else {
- size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) * elementSize(tensor.type().scalarType());
+ size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) *
+ elementSize(tensor.type().scalarType());
char* data_ptr = reinterpret_cast<char*>(tensor.data_ptr());
for (size_t chunks = 0; chunks < chunk.nSubTensors(); ++chunks) {
- addTensorInfoRaw(*chunk.subTensorDesc(), data_ptr, map_size, tensor.strides());
+ addTensorInfoRaw(
+ *chunk.subTensorDesc(), data_ptr, map_size, tensor.strides());
data_ptr += chunk_offset;
}
}
for (size_t i = 0; i < fusion.outputDesc().size(); ++i) {
const auto& c = fusion.concatDesc()[i];
if (c.isNoop()) {
- outputs.push_back(at::empty(map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type)));
+ outputs.push_back(at::empty(
+ map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type)));
addTensorInfo(fusion.outputDesc()[i], outputs[i]);
} else {
size_t small_size = map_size[c.dim()];
fusion.launch_raw(numel, arguments);
}
-
-bool runFusion(
- const int64_t key
-, Stack& stack) {
+bool runFusion(const int64_t key, Stack& stack) {
// Short-circuits if fusion isn't enabled
- if (!canFuseOnCPU() && !canFuseOnGPU()) return false;
+ if (!canFuseOnCPU() && !canFuseOnGPU())
+ return false;
// Acquires the FusionSpec
auto maybe_spec = retrieve(key);
return i.toTensor();
});
- // Determines device to dispatch to. If there's a device mismatch in the inputs,
- // we use the fallback (which should give a nice error message).
+ // Determines device to dispatch to. If there's a device mismatch in the
+ // inputs, we use the fallback (which should give a nice error message).
at::Device device = inputs.at(0).device();
at::ScalarType dtype = inputs[0].type().scalarType();
for (const auto& t : at::TensorList(inputs).slice(1)) {
}
// Attempts to run fallback if device fusion is disabled
- if (device.is_cuda() && !canFuseOnGPU()) return false;
- if (device.is_cpu() && !canFuseOnCPU()) return false;
+ if (device.is_cuda() && !canFuseOnGPU())
+ return false;
+ if (device.is_cpu() && !canFuseOnCPU())
+ return false;
// Validates sizes and expands inputs as needed
auto maybe_map_size = canRunKernel(spec, inputs);
// Tries to run fallback if map size can't be computed
- if (!maybe_map_size) return false;
+ if (!maybe_map_size)
+ return false;
expandArgs(spec, inputs, *maybe_map_size);
// Retrieves the kernel, compiling (and caching) if necessary
// Updates stack
drop(stack, spec.nInputs());
stack.insert(
- stack.end()
- , std::make_move_iterator(outputs.begin())
- , std::make_move_iterator(outputs.end()));
+ stack.end(),
+ std::make_move_iterator(outputs.begin()),
+ std::make_move_iterator(outputs.end()));
return true;
}
#include <cstdint>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Runs the fusion associated with the key (see registerFusion() in interface.h)
// on the inputs taken from the given Stack.
-TORCH_API bool runFusion(
- const int64_t key
-, Stack& stack);
+TORCH_API bool runFusion(const int64_t key, Stack& stack);
} // namespace fuser
} // namespace jit
#include <torch/csrc/jit/fuser/fallback.h>
-#include <torch/csrc/utils/functional.h> //fmap
+#include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/stack.h>
-#include <torch/csrc/jit/custom_operator.h>
-#include <torch/csrc/jit/fuser/kernel_cache.h>
+#include <torch/csrc/utils/functional.h> //fmap
#include <stdexcept>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
-// Registers fused operators so that fused graphs can properly generate fallback code.
-RegisterOperators reg_fused_operators({
- Operator(
- prim::FusedConcat
- , [](const Node* node) {
- int64_t dim = node->i(attr::dim);
- int64_t num_inputs = node->inputs().size();
- return [dim, num_inputs](Stack& stack) {
- auto result = at::cat(
- fmap(last(stack, num_inputs), [](const IValue& i) { return i.toTensor(); })
- , dim
- );
- drop(stack, num_inputs);
- pack(stack, std::move(result));
- return 0;
- };
- })
-});
+// Registers fused operators so that fused graphs can properly generate fallback
+// code.
+RegisterOperators reg_fused_operators(
+ {Operator(prim::FusedConcat, [](const Node* node) {
+ int64_t dim = node->i(attr::dim);
+ int64_t num_inputs = node->inputs().size();
+ return [dim, num_inputs](Stack& stack) {
+ auto result = at::cat(
+ fmap(
+ last(stack, num_inputs),
+ [](const IValue& i) { return i.toTensor(); }),
+ dim);
+ drop(stack, num_inputs);
+ pack(stack, std::move(result));
+ return 0;
+ };
+ })});
void runFallback(int64_t key, Stack& stack) {
auto maybe_spec = retrieve(key);
if (!maybe_spec)
throw std::runtime_error("Failed to find fusion spec to run fallback.");
-
+
InterpreterState{(*maybe_spec)->code()}.run(stack);
}
#include <cstdlib>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
void runFallback(int64_t key, Stack& stack);
#if USE_CUDA_FUSER || USE_CPU_FUSER
#include <ATen/ATen.h>
-#include <torch/csrc/utils/disallow_copy.h>
-#include <torch/csrc/jit/fuser/tensor_desc.h>
#include <torch/csrc/jit/fuser/partition_desc.h>
+#include <torch/csrc/jit/fuser/tensor_desc.h>
+#include <torch/csrc/utils/disallow_copy.h>
-#include <string>
#include <cstdint>
+#include <string>
#include <vector>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
struct FusedKernel {
TH_DISALLOW_COPY_AND_ASSIGN(FusedKernel);
virtual ~FusedKernel() = default;
-
// arguments is a list of pointers to the arguments for the compiled CUDA/CPU
// code.
// The format of arguments is suitable for directly passing to a call to
// CUDA code), and the remainder are pointers to the TensorInfo<T> structs
// that compiled code uses to load Tensor data.
// launch_with_tensors handles packing at::Tensors into this arguments array.
- // CPU code uses the same convension so that launch_with_tensors can be shared.
- virtual void launch_raw(
- const uint32_t numel
- , std::vector<void*>& arguments) const = 0;
+ // CPU code uses the same convention so that launch_with_tensors can be
+ // shared.
+ virtual void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
+ const = 0;
virtual at::Backend backend() const = 0;
// Getters
- const std::string& name() const { return name_; }
- const std::string& code() const { return code_; }
- const std::vector<TensorDesc>& inputDesc() const { return input_desc_; }
- const std::vector<TensorDesc>& outputDesc() const { return output_desc_; }
- const std::vector<PartitionDesc>& chunkDesc() const { return chunk_desc_; }
- const std::vector<PartitionDesc>& concatDesc() const { return concat_desc_; }
- bool hasRandom() const { return has_random_; }
-
-
-protected:
+ const std::string& name() const {
+ return name_;
+ }
+ const std::string& code() const {
+ return code_;
+ }
+ const std::vector<TensorDesc>& inputDesc() const {
+ return input_desc_;
+ }
+ const std::vector<TensorDesc>& outputDesc() const {
+ return output_desc_;
+ }
+ const std::vector<PartitionDesc>& chunkDesc() const {
+ return chunk_desc_;
+ }
+ const std::vector<PartitionDesc>& concatDesc() const {
+ return concat_desc_;
+ }
+ bool hasRandom() const {
+ return has_random_;
+ }
+
+ protected:
const std::string name_;
const std::string code_;
const std::vector<TensorDesc> input_desc_;
#include <torch/csrc/jit/fuser/config.h>
#if USE_CUDA_FUSER || USE_CPU_FUSER
- #include <torch/csrc/jit/fuser/compiler.h>
- #include <torch/csrc/jit/fuser/executor.h>
- #include <torch/csrc/jit/fuser/fallback.h>
+#include <torch/csrc/jit/fuser/compiler.h>
+#include <torch/csrc/jit/fuser/executor.h>
+#include <torch/csrc/jit/fuser/fallback.h>
#endif // USE_CUDA_FUSER || USE_CPU_FUSER
#include <stdexcept>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace detail {
} // namespace detail
int64_t registerFusion(const Node* fusion_group) {
- #if USE_CUDA_FUSER || USE_CPU_FUSER
- return fuser::registerFusion(fusion_group);
- #else
- throw std::runtime_error("Fusion not supported for this build.");
- #endif // USE_CUDA_FUSER || USE_CPU_FUSER
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+ return fuser::registerFusion(fusion_group);
+#else
+ throw std::runtime_error("Fusion not supported for this build.");
+#endif // USE_CUDA_FUSER || USE_CPU_FUSER
}
void runFusion(const int64_t key, Stack& stack) {
- #if USE_CUDA_FUSER || USE_CPU_FUSER
- const auto result = fuser::runFusion(key, stack);
- if (!result) fuser::runFallback(key, stack);
- #else
- throw std::runtime_error("Fusion not supported for this build.");
- #endif // USE_CUDA_FUSER || USE_CPU_FUSER
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+ const auto result = fuser::runFusion(key, stack);
+ if (!result)
+ fuser::runFallback(key, stack);
+#else
+ throw std::runtime_error("Fusion not supported for this build.");
+#endif // USE_CUDA_FUSER || USE_CPU_FUSER
}
bool canFuseOnCPU() {
- #if USE_CPU_FUSER
- return detail::cpu_fuser_enabled;
- #endif // USE_CPU_FUSER
+#if USE_CPU_FUSER
+ return detail::cpu_fuser_enabled;
+#endif // USE_CPU_FUSER
return false;
}
bool canFuseOnGPU() {
- #if USE_CUDA_FUSER
- return true;
- #endif // USE_CUDA_FUSER
+#if USE_CUDA_FUSER
+ return true;
+#endif // USE_CUDA_FUSER
return false;
}
// Uses the above interface by stuffing the graph into a node and treating that
// node as a fusion group.
std::vector<at::Tensor> debugLaunchGraph(
- Graph& graph
-, at::ArrayRef<at::Tensor> inputs) {
- #if USE_CUDA_FUSER || USE_CPU_FUSER
- // Creates a fusion group node
- auto wrapper_graph = std::make_shared<Graph>();
- Node* fusion_group = wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
- fusion_group->g_(attr::Subgraph, graph.copy());
- for (size_t i = 0; i < graph.inputs().size(); ++i) {
- fusion_group->addInput(wrapper_graph->addInput());
- }
- for (size_t i = 0; i < graph.outputs().size(); ++i) {
- wrapper_graph->registerOutput(fusion_group->addOutput());
- }
-
- // Creates the stack, registers and runs the fusion
- Stack stack = fmap<IValue>(inputs);
- const auto key = fuser::registerFusion(fusion_group);
- fuser::runFusion(key, stack);
- return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
- #else
- throw std::runtime_error("Fusion not supported for this build.");
- #endif // USE_CUDA_FUSER || USE_CPU_FUSER
+ Graph& graph,
+ at::ArrayRef<at::Tensor> inputs) {
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+ // Creates a fusion group node
+ auto wrapper_graph = std::make_shared<Graph>();
+ Node* fusion_group =
+ wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
+ fusion_group->g_(attr::Subgraph, graph.copy());
+ for (size_t i = 0; i < graph.inputs().size(); ++i) {
+ fusion_group->addInput(wrapper_graph->addInput());
+ }
+ for (size_t i = 0; i < graph.outputs().size(); ++i) {
+ wrapper_graph->registerOutput(fusion_group->addOutput());
+ }
+
+ // Creates the stack, registers and runs the fusion
+ Stack stack = fmap<IValue>(inputs);
+ const auto key = fuser::registerFusion(fusion_group);
+ fuser::runFusion(key, stack);
+ return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
+#else
+ throw std::runtime_error("Fusion not supported for this build.");
+#endif // USE_CUDA_FUSER || USE_CPU_FUSER
}
-size_t nCompiledKernels() {
- #if USE_CUDA_FUSER || USE_CPU_FUSER
- return fuser::nCompiledKernels();
- #else
- return 0;
- #endif // USE_CUDA_FUSER || USE_CPU_FUSER
+size_t nCompiledKernels() {
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+ return fuser::nCompiledKernels();
+#else
+ return 0;
+#endif // USE_CUDA_FUSER || USE_CPU_FUSER
}
} // namespace jit
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/stack.h>
+#include <cstdint>
#include <memory>
#include <vector>
-#include <cstdint>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
constexpr int kCPUDevice = -1;
TORCH_API bool canFuseOnCPU();
TORCH_API bool canFuseOnGPU();
-// Sets whether fusion on the CPU is allowed (disabled by default due to flakiness)
+// Sets whether fusion on the CPU is allowed (disabled by default due to
+// flakiness)
TORCH_API void overrideCanFuseOnCPU(bool value);
// Treats the given graph as a fusion group and launches it on the
// specified device with the given inputs.
// Returns the outputs.
TORCH_API std::vector<at::Tensor> debugLaunchGraph(
- Graph& graph
-, at::ArrayRef<at::Tensor> inputs);
+ Graph& graph,
+ at::ArrayRef<at::Tensor> inputs);
TORCH_API size_t nCompiledKernels();
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
-#include <unordered_map>
-#include <mutex>
#include <cstdint>
+#include <mutex>
+#include <unordered_map>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
struct KernelCacheImpl {
// Note: std::unordered_map does not invalidate references even if rehashing
- // occurs. This is a critical property for thread-safety.
+ // occurs. This is a critical property for thread-safety.
std::mutex mutex_;
int64_t kernel_counter{0};
return cache.specMap_.size();
}
-std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph) {
+std::shared_ptr<Graph> normalizeGraphForCache(
+ const std::shared_ptr<Graph>& graph) {
auto result = Canonicalize(graph, /*keep_unique_names=*/false);
EraseShapeInformation(result);
return result;
std::lock_guard<std::mutex> guard{cache.mutex_};
const auto key = cache.kernel_counter++;
cache.specMap_.emplace(
- std::piecewise_construct
- , std::forward_as_tuple(key)
- , std::forward_as_tuple(key, graph));
+ std::piecewise_construct,
+ std::forward_as_tuple(key),
+ std::forward_as_tuple(key, graph));
cache.graphToKey_.emplace(std::make_pair(std::move(repr), key));
return key;
}
// XXX: Does not grab mutex
static at::optional<KernelSpec*> nolock_retrieve(
- KernelCacheImpl& cache, const int64_t key) {
+ KernelCacheImpl& cache,
+ const int64_t key) {
auto it = cache.specMap_.find(key);
- if (it == cache.specMap_.end()) return at::nullopt;
+ if (it == cache.specMap_.end())
+ return at::nullopt;
return &(it->second);
}
-at::optional<KernelSpec*> retrieve(const int64_t key) {
+at::optional<KernelSpec*> retrieve(const int64_t key) {
auto& cache = getKernelCache();
std::lock_guard<std::mutex> guard{cache.mutex_};
return nolock_retrieve(cache, key);
std::lock_guard<std::mutex> guard{cache.mutex_};
auto it = cache.graphToKey_.find(repr);
- if (it == cache.graphToKey_.end()) return at::nullopt;
+ if (it == cache.graphToKey_.end())
+ return at::nullopt;
return nolock_retrieve(cache, it->second);
}
#include <c10/util/Optional.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/fuser/kernel_spec.h>
+#include <torch/csrc/jit/ir.h>
-#include <cstdint>
+#include <cstdint>
#include <functional>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// A thread-safe cache interface.
// Normalizes the graph by canonicalizing and erasing shape information
-TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph);
+TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(
+ const std::shared_ptr<Graph>& graph);
// Stores the given graph, returning the key used to access it
TORCH_API int64_t store(std::shared_ptr<Graph> graph);
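Taken together these entry points give a normalize/store/retrieve workflow. A minimal usage sketch, assuming the fuser headers above are on the include path and that retrieve is exposed alongside store, as the definitions above suggest (cacheFusionGroup is an illustrative name):

    #include <torch/csrc/jit/fuser/kernel_cache.h>

    using namespace torch::jit;

    void cacheFusionGroup(const std::shared_ptr<Graph>& fusion_group) {
      // Canonicalize names and erase shape info so structurally equivalent
      // graphs map to the same cache entry.
      auto normalized = fuser::normalizeGraphForCache(fusion_group);
      // Store the graph; the returned key identifies its KernelSpec.
      const int64_t key = fuser::store(normalized);
      // Retrieval by key returns nullopt on a cache miss.
      if (auto spec = fuser::retrieve(key)) {
        (*spec)->graph(); // KernelSpec* for the stored graph
      }
    }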
#if USE_CUDA_FUSER || USE_CPU_FUSER
#include <ATen/ATen.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <c10/util/Optional.h>
-#include <torch/csrc/jit/stack.h>
-#include <torch/csrc/jit/interpreter.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/fuser/interface.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/fuser/arg_spec.h>
#include <torch/csrc/jit/fuser/fused_kernel.h>
+#include <torch/csrc/jit/fuser/interface.h>
+#include <torch/csrc/jit/interpreter.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/stack.h>
-#include <memory>
#include <cstdint>
-#include <vector>
-#include <unordered_map>
+#include <memory>
#include <mutex>
+#include <unordered_map>
+#include <vector>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Helper struct containing partition information: the number of tensors
// created and the dimension the partitioning is performed on.
// At runtime the partition info is logically combined with the tensor
// descriptions to create PartitionDesc objects.
struct TORCH_API PartitionInfo {
- PartitionInfo(
- const int64_t _nSubTensors
- , const int64_t _dim)
- : nSubTensors_{_nSubTensors}
- , dim_{_dim}
- { };
+ PartitionInfo(const int64_t _nSubTensors, const int64_t _dim)
+ : nSubTensors_{_nSubTensors}, dim_{_dim} {}
- int64_t nSubTensors() const { return nSubTensors_; }
- int64_t dim() const { return dim_; }
+ int64_t nSubTensors() const {
+ return nSubTensors_;
+ }
+ int64_t dim() const {
+ return dim_;
+ }
-private:
+ private:
int64_t nSubTensors_;
int64_t dim_;
};
- // "Kernel Specification." - Contains device-independent fusion information.
- // Each kernel specification contains a map of instantiated generated functions
- // that implement some or most of its functionality. Multiple generated
- // functions are needed by each abstract specification because of different
- // devices (cpu vs gpu, different gpus) and different inputs (int vs float,
- // contiguous vs discontiguous).
- // Note: uses a mutex to control access to its kernel store
- // Note: unordered containers do not invalidate references/pointers on
- // rehashing, which is critical for thread-safety.
- // TODO: allow abstract kernels to use multiple generated kernels
- // TODO: allow abstract kernels to reuse generated kernels from common pool
+// "Kernel Specification." - Contains device-independent fusion information.
+// Each kernel specification contains a map of instantiated generated functions
+// that implement some or most of its functionality. Multiple generated
+// functions are needed by each abstract specification because of different
+// devices (cpu vs gpu, different gpus) and different inputs (int vs float,
+// contiguous vs discontiguous).
+// Note: uses a mutex to control access to its kernel store
+// Note: unordered containers do not invalidate references/pointers on
+// rehashing, which is critical for thread-safety.
+// TODO: allow abstract kernels to use multiple generated kernels
+// TODO: allow abstract kernels to reuse generated kernels from common pool
struct TORCH_API KernelSpec {
KernelSpec(const int64_t _key, const std::shared_ptr<Graph>& _graph)
: key_{_key},
kernels_{} {}
// Getters
- int64_t key() const { return key_; }
- std::shared_ptr<Graph> graph() const { return graph_; }
- const Code& code() const { return code_; }
- int64_t nInputs() const { return nInputs_; }
+ int64_t key() const {
+ return key_;
+ }
+ std::shared_ptr<Graph> graph() const {
+ return graph_;
+ }
+ const Code& code() const {
+ return code_;
+ }
+ int64_t nInputs() const {
+ return nInputs_;
+ }
std::vector<std::vector<int64_t>>& inputBroadcastGroups() {
return inputBroadcastGroups_;
return inputBroadcastGroups_;
}
- std::vector<PartitionInfo>& inputChunks() { return inputChunks_; }
- const std::vector<PartitionInfo>& inputChunks() const { return inputChunks_; }
+ std::vector<PartitionInfo>& inputChunks() {
+ return inputChunks_;
+ }
+ const std::vector<PartitionInfo>& inputChunks() const {
+ return inputChunks_;
+ }
// Cache functions
- c10::optional<std::shared_ptr<FusedKernel>> findKernel(const ArgSpec& arg_spec) const {
+ c10::optional<std::shared_ptr<FusedKernel>> findKernel(
+ const ArgSpec& arg_spec) const {
std::lock_guard<std::mutex> guard{mutex_};
const auto it = kernels_.find(arg_spec);
- if (it == kernels_.end()) return c10::nullopt;
+ if (it == kernels_.end())
+ return c10::nullopt;
return it->second;
}
- void cacheKernel(
- const ArgSpec& arg_spec
- , std::shared_ptr<FusedKernel> kernel) const {
+ void cacheKernel(const ArgSpec& arg_spec, std::shared_ptr<FusedKernel> kernel)
+ const {
std::lock_guard<std::mutex> guard{mutex_};
kernels_.emplace(arg_spec, kernel);
}
-private:
+ private:
int64_t key_;
std::shared_ptr<Graph> graph_;
Code code_;
std::vector<std::vector<int64_t>> inputBroadcastGroups_;
std::vector<PartitionInfo> inputChunks_;
mutable std::mutex mutex_;
- mutable std::unordered_map<
- ArgSpec
- , std::shared_ptr<FusedKernel>
- , torch::hash<ArgSpec>> kernels_;
+ mutable std::
+ unordered_map<ArgSpec, std::shared_ptr<FusedKernel>, torch::hash<ArgSpec>>
+ kernels_;
};
} // namespace fuser
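The findKernel/cacheKernel pair supports the usual compile-on-miss pattern. A hedged sketch, where compileKernelFor stands in for a device-specific compilation step that is not part of this diff:

    std::shared_ptr<FusedKernel> getOrCompile(
        const KernelSpec& spec,
        const ArgSpec& arg_spec) {
      // Fast path: reuse a kernel already generated for this argument
      // configuration (both methods are const; the kernel store is guarded
      // by the spec's own mutex).
      if (auto kernel = spec.findKernel(arg_spec)) {
        return *kernel;
      }
      auto kernel = compileKernelFor(spec, arg_spec); // hypothetical helper
      spec.cacheKernel(arg_spec, kernel);
      return kernel;
    }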
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/fuser/tensor_desc.h>
-#include <memory>
#include <cstdint>
+#include <memory>
#include <vector>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Descriptor for chunk-ing an input tensor into subtensors
// OR concat-ing an output tensor from subtensors
// Note: a default-constructed PartitionDesc is used for tensors that do not
// participate in chunk or cat operations.
struct TORCH_API PartitionDesc {
- PartitionDesc()
- : nSubTensors_{1}
- , dim_{0}
- { }
-
- PartitionDesc(
- const TensorDesc& _desc
- , size_t _nSubTensors
- , size_t _dim)
- : nSubTensors_{_nSubTensors}
- , dim_{_dim} {
+ PartitionDesc() : nSubTensors_{1}, dim_{0} {}
+
+ PartitionDesc(const TensorDesc& _desc, size_t _nSubTensors, size_t _dim)
+ : nSubTensors_{_nSubTensors}, dim_{_dim} {
JIT_ASSERT(nSubTensors_ > 1);
std::vector<bool> cont = _desc.contiguity;
if (dim_ > 0) {
subTensorDesc_.reset(new TensorDesc(_desc.scalar_type, cont));
}
- bool isNoop() const { return (nSubTensors_ == 1);}
- size_t nSubTensors() const { return nSubTensors_; }
- size_t dim() const { return dim_; }
- std::shared_ptr<TensorDesc> subTensorDesc() { return subTensorDesc_; }
- const std::shared_ptr<TensorDesc> subTensorDesc() const { return subTensorDesc_; }
+ bool isNoop() const {
+ return (nSubTensors_ == 1);
+ }
+ size_t nSubTensors() const {
+ return nSubTensors_;
+ }
+ size_t dim() const {
+ return dim_;
+ }
+ std::shared_ptr<TensorDesc> subTensorDesc() {
+ return subTensorDesc_;
+ }
+ const std::shared_ptr<TensorDesc> subTensorDesc() const {
+ return subTensorDesc_;
+ }
-private:
- size_t nSubTensors_; // == 1 for tensors that should not be operated on via chunk/cat
+ private:
+ size_t nSubTensors_; // == 1 for tensors that should not be operated on via
+ // chunk/cat
size_t dim_; // dimension along which the chunk/concat occurs
- std::shared_ptr<TensorDesc> subTensorDesc_; // descriptor for the subtensor, if it exists
+ std::shared_ptr<TensorDesc>
+ subTensorDesc_; // descriptor for the subtensor, if it exists
};
} // namespace fuser
-} // namespace jit
+} // namespace jit
} // namespace torch
#endif // USE_CUDA_FUSER || USE_CPU_FUSER
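A small sketch of both descriptor flavors, assuming the fuser headers above (the concrete values are illustrative):

    // A contiguous 2-D float tensor, chunked into 3 subtensors along dim 1:
    TensorDesc desc(at::kFloat, /*contiguity=*/{true, true});
    PartitionDesc chunk3(desc, /*nSubTensors=*/3, /*dim=*/1);
    // chunk3.isNoop() == false; chunk3.subTensorDesc() describes one piece.

    // Default construction marks a tensor that takes part in no chunk/cat:
    PartitionDesc passthrough; // passthrough.isNoop() == true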
#include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/utils/hash.h>
-#include <torch/csrc/jit/type.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/type.h>
+#include <torch/csrc/utils/hash.h>
-#include <vector>
-#include <iostream>
#include <algorithm>
+#include <iostream>
+#include <vector>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// type information needed by the compiler for input/outputs
// contiguity[i] is true if the dim i is contiguous with dim i + 1.
at::ScalarType scalar_type;
std::vector<bool> contiguity;
- TensorDesc(
- const at::ScalarType& type
- , const std::vector<bool>& contiguity)
- : scalar_type{type}
- , contiguity{contiguity} {
+ TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity)
+ : scalar_type{type}, contiguity{contiguity} {
if (contiguity.size() == 0) {
nDim_ = 0;
} else {
- nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + (lastIsContiguous() ? 1 : 0);
+ nDim_ = std::count(contiguity.begin(), contiguity.end(), false) +
+ (lastIsContiguous() ? 1 : 0);
}
}
// Delegating constructors
TensorDesc(
- const at::ScalarType& type
- , const at::IntList& sizes
- , const at::IntList& strides)
- : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
+ const at::ScalarType& type,
+ const at::IntList& sizes,
+ const at::IntList& strides)
+ : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
TensorDesc(const at::Tensor& t)
- : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
+ : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
TensorDesc(const CompleteTensorTypePtr& type)
- : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}
+ : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}
// number of dimensions after contiguity compression
- size_t nDim() const { return nDim_; }
+ size_t nDim() const {
+ return nDim_;
+ }
// True iff innermost stride is 1
bool lastIsContiguous() const {
}
static std::vector<bool> findContiguous(
- const at::IntList& sizes
- , const at::IntList& strides) {
+ const at::IntList& sizes,
+ const at::IntList& strides) {
JIT_ASSERT(sizes.size() == strides.size());
std::vector<bool> cont(sizes.size());
for (size_t i = 0; i < sizes.size(); ++i) {
- const auto expected_stride = (i + 1 < sizes.size()) ? sizes[i+1]*strides[i+1] : 1;
+ const auto expected_stride =
+ (i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1;
cont[i] = (strides[i] == expected_stride);
}
return cont;
}
static size_t hash(const TensorDesc& spec) {
- return torch::get_hash(spec.scalar_type, spec.nDim_, std::hash<std::vector<bool>>{}(spec.contiguity));
+ return torch::get_hash(
+ spec.scalar_type,
+ spec.nDim_,
+ std::hash<std::vector<bool>>{}(spec.contiguity));
}
-private:
+ private:
size_t nDim_;
};
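A worked example of the contiguity compression above (a sketch; the shapes are illustrative):

    // sizes {2, 3, 4} with dense strides {12, 4, 1}: each dim is contiguous
    // with the next, contiguity is {true, true, true}, and the whole tensor
    // compresses to a single dimension.
    TensorDesc a(at::kFloat, {2, 3, 4}, {12, 4, 1});
    JIT_ASSERT(a.nDim() == 1);

    // An outer stride of 24 leaves padding between slices of dim 0, so
    // contiguity is {false, true, true} and two compressed dims remain.
    TensorDesc b(at::kFloat, {2, 3, 4}, {24, 4, 1});
    JIT_ASSERT(b.nDim() == 2);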
#include <cstdint>
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
// Host-side view of TensorInfo.
// Note the zero-length sizes_strides[0] member: the actual sizes and strides
// are allocated dynamically, immediately past the end of the struct.
struct TORCH_API TensorInfo {
-
- uint32_t* sizes(size_t nDim) { return &sizes_strides[0]; }
- uint32_t* strides(size_t nDim) { return &sizes_strides[nDim]; }
+ uint32_t* sizes(size_t nDim) {
+ return &sizes_strides[0];
+ }
+ uint32_t* strides(size_t nDim) {
+ return &sizes_strides[nDim];
+ }
void* data;
- #pragma GCC diagnostic ignored "-Wpedantic"
- uint32_t sizes_strides[0];
- #pragma GCC diagnostic pop
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpedantic"
+ uint32_t sizes_strides[0];
+#pragma GCC diagnostic pop
};
} // namespace fuser
-} // namespace jit
+} // namespace jit
} // namespace torch
#endif // USE_CUDA_FUSER || USE_CPU_FUSER
// TODO: I'm pretty sure Constness can be done with C++ templates, ala
// std::is_const, but no time to work it out...
-#define GENERIC_IF(Constness, FullKind, x, Kind) \
- auto && __match_key = x; \
- switch(__match_key->kind()) { \
- case FullKind: { \
- auto * value = static_cast<Constness ::torch::jit::Kind*>(__match_key); (void) value;
-#define GENERIC_ELSEIF(Constness, FullKind, Kind) \
- } break; \
- case FullKind: { \
- auto * value = static_cast<Constness ::torch::jit::Kind*>(__match_key); (void) value;
+#define GENERIC_IF(Constness, FullKind, x, Kind) \
+ auto&& __match_key = x; \
+ switch (__match_key->kind()) { \
+ case FullKind: { \
+ auto* value = static_cast<Constness ::torch::jit::Kind*>(__match_key); \
+ (void)value;
+#define GENERIC_ELSEIF(Constness, FullKind, Kind) \
+ } \
+ break; \
+ case FullKind: { \
+ auto* value = static_cast<Constness ::torch::jit::Kind*>(__match_key); \
+ (void)value;
#define GENERIC_ELSE() \
- } break; \
- default: {
+ } \
+ break; \
+ default: {
#define GENERIC_END() \
- } break; \
- };
+ } \
+ break; \
+ } \
+ ;
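For reference, these macros expand to a switch over kind() with a correctly cast `value` in scope in each arm. A hypothetical usage sketch; Select and Constant stand in for Node subtypes, which are assumptions here rather than part of this diff:

    GENERIC_IF(const, prim::Select, node, Select)
      // `value` is a const ::torch::jit::Select* in this arm
    GENERIC_ELSEIF(const, prim::Constant, Constant)
      // `value` is a const ::torch::jit::Constant* in this arm
    GENERIC_ELSE()
      // no kind matched
    GENERIC_END()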
#include <torch/csrc/jit/graph_executor.h>
-#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/autodiff.h>
+#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/tracer.h>
+#include <torch/csrc/jit/ivalue.h>
#include <torch/csrc/jit/passes/batch_mm.h>
+#include <torch/csrc/jit/passes/canonicalize_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
+#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/inplace_check.h>
-#include <torch/csrc/jit/passes/peephole.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/passes/remove_expands.h>
-#include <torch/csrc/jit/passes/canonicalize_ops.h>
-#include <torch/csrc/jit/passes/specialize_undef.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
-#include <torch/csrc/jit/passes/constant_propagation.h>
-#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
+#include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/passes/specialize_undef.h>
#include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/ivalue.h>
-#include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/script/compiler.h>
#include <cstdint>
+#include <iterator>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>
-#include <iterator>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
struct ExecutionPlan {
ExecutionPlan() = default;
ExecutionPlan(std::shared_ptr<Graph> graph)
- : code(graph)
- , graph(std::move(graph)) {}
+ : code(graph), graph(std::move(graph)) {}
void run(Stack& stack) const {
return InterpreterState(code).run(stack);
struct DifferentiableGraphBackward : public autograd::Function {
DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size)
- : executor(std::move(executor)) {
+ : executor(std::move(executor)) {
is_var_capture.reserve(capture_size);
var_captures.reserve(capture_size);
ivalue_captures.reserve(capture_size);
variable_list apply(variable_list&& inputs) override {
Stack stack;
stack.reserve(is_var_capture.size() + inputs.size());
- stack.insert(stack.end(), std::make_move_iterator(inputs.begin()),
- std::make_move_iterator(inputs.end()));
+ stack.insert(
+ stack.end(),
+ std::make_move_iterator(inputs.begin()),
+ std::make_move_iterator(inputs.end()));
auto var_capture_it = var_captures.begin();
auto ivalue_capture_it = ivalue_captures.begin();
for (bool is_var : is_var_capture) {
for (size_t i = 0; i < num_outputs(); ++i) {
if (should_compute_output(i)) {
auto output = std::move(stack[i]).toTensor();
- const auto & edge = next_edge(i);
+ const auto& edge = next_edge(i);
if (output.defined()) {
outputs.emplace_back(std::move(output));
} else if (edge.is_valid()) {
- outputs.emplace_back(edge.function->input_metadata(edge.input_nr).zeros_like());
+ outputs.emplace_back(
+ edge.function->input_metadata(edge.input_nr).zeros_like());
} else {
outputs.emplace_back();
}
return outputs;
}
- void capture(const IValue & val, bool is_output) {
+ void capture(const IValue& val, bool is_output) {
const bool is_tensor = val.isTensor();
is_var_capture.push_back(is_tensor);
if (is_tensor) {
ivalue_captures.push_back(val);
}
}
-private:
+
+ private:
friend struct ExecutionPlan;
GraphExecutor executor;
- // INVARIANT: is_var_capture.size() == var_captures.size() + ivalue_captures.size()
+ // INVARIANT: is_var_capture.size() == var_captures.size() +
+ // ivalue_captures.size()
std::vector<bool> is_var_capture;
std::vector<autograd::SavedVariable> var_captures;
std::vector<IValue> ivalue_captures;
num_outputs(this->grad.f->outputs().size()) {}
// XXX: keep in mind that stack can be larger than the inputs we need!
- int operator()(Stack & stack) const {
- auto grad_fn = std::make_shared<DifferentiableGraphBackward>(grad_executor,
- grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size());
+ int operator()(Stack& stack) const {
+ auto grad_fn = std::make_shared<DifferentiableGraphBackward>(
+ grad_executor,
+ grad.df_input_captured_inputs.size() +
+ grad.df_input_captured_outputs.size());
{
auto inputs = last(stack, num_inputs);
- // hook up the outputs of df to the gradient functions of the inputs that require gradients
- for(auto idx : grad.df_output_vjps) {
+ // hook up the outputs of df to the gradient functions of the inputs that
+ // require gradients
+ for (auto idx : grad.df_output_vjps) {
auto v = Variable(inputs[idx].toTensor());
- grad_fn->add_next_edge(v.defined() ? v.gradient_edge() : autograd::Edge{});
+ grad_fn->add_next_edge(
+ v.defined() ? v.gradient_edge() : autograd::Edge{});
}
captureInputs(*grad_fn, inputs);
}
auto outputs = last(stack, num_outputs);
// hook up the gradients for the output tensors that require gradients
// to the inputs to our gradient function df
- // TODO - XXX - if any output is the same tensor multiple times, views have to be
- // setup here. We need to refactor autograd until it is safe for
- // tensors to be constructed without all the viewing infrastructure.
- // this is currently intentionally not done here so we can get an idea of our
- // perf before introducing overhead for correctness
- for(auto idx : grad.df_input_vjps) {
+ // TODO - XXX - if any output is the same tensor multiple times, views
+ // have to be setup here. We need to refactor autograd until it is safe
+ // for tensors to be constructed without all the viewing infrastructure.
+ // this is currently intentionally not done here so we can get an idea of
+ // our perf before introducing overhead for correctness
+ for (auto idx : grad.df_input_vjps) {
// Note: we have to set this up in place, or we have to throw away and
// reallocate variables that were already created in wrapTensors. We
// should add an API for this.
Variable output = outputs[idx].toTensor();
- // NB: since our requires_grad setting is only a heuristic we might end up
- // wanting to differentiate through integral tensors, which is generally a
- // hard error in autograd.
+ // NB: since our requires_grad setting is only a heuristic we might end
+ // up wanting to differentiate through integral tensors, which is
+ // generally a hard error in autograd.
if (at::isFloatingType(output.type().scalarType())) {
autograd::create_gradient_edge(output, grad_fn);
output.set_requires_grad(true);
return 0;
}
-private:
+ private:
friend GraphExecutor* detail::getGradExecutor(Operation& op);
- void detachVariables(Stack & stack) const {
- // It would be nice to use an ArrayRef here, but unfortunately those can only
- // return const references, so we need to do a bunch of indexing ourselves.
+ void detachVariables(Stack& stack) const {
+ // It would be nice to use an ArrayRef here, but unfortunately those can
+ // only return const references, so we need to do a bunch of indexing
+ // ourselves.
const int64_t stack_size = stack.size();
const int64_t stack_offset = stack_size - num_inputs;
for (int64_t i = stack_offset; i < stack_size; ++i) {
- auto & v = stack[i];
- if (!v.isTensor()) continue;
+ auto& v = stack[i];
+ if (!v.isTensor())
+ continue;
auto t = std::move(v).toTensor();
- v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() : std::move(t)};
+ v = IValue{t.defined() ? autograd::as_variable_ref(t).detach()
+ : std::move(t)};
}
}
// Capture (save) inputs that would be required to subsequently run backwards
- void captureInputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef<IValue> inputs) const {
+ void captureInputs(
+ DifferentiableGraphBackward& grad_fn,
+ at::ArrayRef<IValue> inputs) const {
for (size_t offset : grad.df_input_captured_inputs) {
- grad_fn.capture(inputs[offset], /*is_output*/false);
+ grad_fn.capture(inputs[offset], /*is_output*/ false);
}
}
- void captureOutputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef<IValue> outputs) const {
+ void captureOutputs(
+ DifferentiableGraphBackward& grad_fn,
+ at::ArrayRef<IValue> outputs) const {
for (size_t offset : grad.df_input_captured_outputs) {
- grad_fn.capture(outputs[offset], /*is_output*/true);
+ grad_fn.capture(outputs[offset], /*is_output*/ true);
}
}
const size_t num_outputs;
};
-void packGradient(Gradient gradient, Node *dnode) {
+void packGradient(Gradient gradient, Node* dnode) {
JIT_ASSERT(dnode->kind() == prim::DifferentiableGraph);
dnode->g_(attr::Subgraph, gradient.f)
- ->g_(attr::ReverseSubgraph, gradient.df)
- ->i_(attr::f_real_outputs, gradient.f_real_outputs)
- ->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
- ->is_(attr::df_input_captured_inputs, fmap<int64_t>(gradient.df_input_captured_inputs))
- ->is_(attr::df_input_captured_outputs, fmap<int64_t>(gradient.df_input_captured_outputs))
- ->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
+ ->g_(attr::ReverseSubgraph, gradient.df)
+ ->i_(attr::f_real_outputs, gradient.f_real_outputs)
+ ->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
+ ->is_(
+ attr::df_input_captured_inputs,
+ fmap<int64_t>(gradient.df_input_captured_inputs))
+ ->is_(
+ attr::df_input_captured_outputs,
+ fmap<int64_t>(gradient.df_input_captured_outputs))
+ ->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
}
-Gradient getGradient(const Node *n) {
+Gradient getGradient(const Node* n) {
JIT_ASSERT(n->kind() == prim::DifferentiableGraph);
Gradient grad;
grad.f = n->g(attr::Subgraph);
grad.df = n->g(attr::ReverseSubgraph);
grad.f_real_outputs = n->i(attr::f_real_outputs);
grad.df_input_vjps = fmap<size_t>(n->is(attr::df_input_vjps));
- grad.df_input_captured_inputs = fmap<size_t>(n->is(attr::df_input_captured_inputs));
- grad.df_input_captured_outputs = fmap<size_t>(n->is(attr::df_input_captured_outputs));
+ grad.df_input_captured_inputs =
+ fmap<size_t>(n->is(attr::df_input_captured_inputs));
+ grad.df_input_captured_outputs =
+ fmap<size_t>(n->is(attr::df_input_captured_outputs));
grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
return grad;
}
} // anonymous namespace
-RegisterOperators reg_graph_executor_ops({
- Operator(
- prim::DifferentiableGraph,
- [](const Node *n) -> Operation {
+RegisterOperators reg_graph_executor_ops(
+ {Operator(prim::DifferentiableGraph, [](const Node* n) -> Operation {
return DifferentiableGraphOp(getGradient(n));
- })
-});
+ })});
namespace detail {
// A Graph can be created via tracing, or via a language-based frontend.
// GraphExecutor runs it. It can run the same graph on many different sizes
-// and different requires_grad states, and handles specializations for each situation.
-// GraphExecutor is completely unaware of tracing or module parameters to keep the
-// tracing concerns separated.
+// and different requires_grad states, and handles specializations for each
+// situation. GraphExecutor is completely unaware of tracing or module
+// parameters to keep the tracing concerns separated.
struct GraphExecutorImpl {
-
static std::shared_ptr<Graph> prepareGraph(std::shared_ptr<Graph>& graph) {
auto copy = graph->copy();
EraseShapeInformation(copy);
}
if (auto tuple_type = ptr->cast<TupleType>()) {
size_t total = 0;
- for (auto & elem : tuple_type->elements()) {
+ for (auto& elem : tuple_type->elements()) {
total += countFlatInputs(elem);
}
return total;
static size_t countFlatInputs(const std::shared_ptr<Graph>& graph) {
size_t total = 0;
- for (Value * input : graph->inputs()) {
+ for (Value* input : graph->inputs()) {
total += countFlatInputs(input->type());
}
return total;
}
inline bool hasMutableOperators(Block* block) {
- for(auto n : block->nodes()) {
- if(n->kind().is_aten() && n->schema().is_mutable())
+ for (auto n : block->nodes()) {
+ if (n->kind().is_aten() && n->schema().is_mutable())
return true;
- for(auto b : n->blocks()) {
- if(hasMutableOperators(b))
+ for (auto b : n->blocks()) {
+ if (hasMutableOperators(b))
return true;
}
}
num_outputs(this->graph->outputs().size()) {}
// entry point where execution begins
- void run(Stack & stack) {
- AT_CHECK(stack.size() >= num_inputs, "expected ", num_inputs, " inputs, but got only ", stack.size());
-
- if(tracer::isTracing()) {
+ void run(Stack& stack) {
+ AT_CHECK(
+ stack.size() >= num_inputs,
+ "expected ",
+ num_inputs,
+ " inputs, but got only ",
+ stack.size());
+
+ if (tracer::isTracing()) {
return runTraced(stack);
}
- auto & execution_plan = optimize ? getOrCompile(stack) : getOrCompileFallback();
+ auto& execution_plan =
+ optimize ? getOrCompile(stack) : getOrCompileFallback();
return execution_plan.run(stack);
}
std::shared_ptr<Graph> graphFor(const Stack& stack) const {
JIT_ASSERT(stack.size() >= num_inputs);
auto inputs = last(stack, num_inputs);
- ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
+ ArgumentSpec spec(
+ autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
if (!optimize) {
AT_CHECK(fallback, "No graph found for given inputs");
if (fallback) {
state.fallback = fallback.getDebugState();
}
- for (auto & entry : plan_cache) {
+ for (auto& entry : plan_cache) {
state.execution_plans.emplace(entry.first, entry.second.getDebugState());
}
return state;
autodiffSubgraphInlineThreshold = 1;
}
-private:
+ private:
friend struct GraphExecutor;
- const ExecutionPlan & getOrCompileFallback() {
+ const ExecutionPlan& getOrCompileFallback() {
std::lock_guard<std::mutex> lock(compile_mutex);
- if(!fallback) {
+ if (!fallback) {
auto graph_ = graph->copy();
runRequiredPasses(graph_);
fallback = ExecutionPlan(graph_);
return fallback;
}
- const ExecutionPlan & getOrCompile(const Stack& stack) {
- // outside lock guard, to minimize the time holding the lock on the fast path
- // ArgumentSpec even computes its hashCode here.
- ArgumentSpec spec(autograd::GradMode::is_enabled(), last(stack, num_inputs), num_flat_inputs);
+ const ExecutionPlan& getOrCompile(const Stack& stack) {
+ // outside lock guard, to minimize the time holding the lock on the fast
+ // path. ArgumentSpec even computes its hashCode here.
+ ArgumentSpec spec(
+ autograd::GradMode::is_enabled(),
+ last(stack, num_inputs),
+ num_flat_inputs);
{
std::lock_guard<std::mutex> lock(compile_mutex);
auto it = plan_cache.find(spec);
}
}
- ExecutionPlan compileSpec(const ArgumentSpec & spec) {
+ ExecutionPlan compileSpec(const ArgumentSpec& spec) {
auto opt_graph = graph->copy();
setInputTypes(*opt_graph, spec);
// Phase 2. Propagate detailed information about the spec through the
// graph (enabled more specializations in later passes).
// Shape propagation sometimes depends on certain arguments being
- // constants, and constant propagation doesn't need shape information
- // anyway, so it's better to run it first.
+ // constants, and constant propagation doesn't need shape
+ // information anyway, so it's better to run it first.
ConstantPropagation(opt_graph);
PropagateInputShapes(opt_graph);
PropagateRequiresGrad(opt_graph);
- // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites that
- // we can still execute using autograd).
+ // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
+ // that we can still execute using autograd).
runOptimization(opt_graph, spec);
// Phase 5. Apply non-differentiable optimizations to the graphs we've found
// (or the whole graph if we know we won't need its derivative).
if (needsGradient(opt_graph)) {
- auto diff_nodes = CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
- for (Node * dnode : diff_nodes) {
+ auto diff_nodes =
+ CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
+ for (Node* dnode : diff_nodes) {
auto diff_graph = std::move(dnode->g(attr::Subgraph));
Gradient gradient = differentiate(diff_graph);
runNondiffOptimization(gradient.f);
return ExecutionPlan(opt_graph);
}
- void runOptimization(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
+ void runOptimization(
+ std::shared_ptr<Graph>& graph,
+ const ArgumentSpec& spec) {
// Basic graph preprocessing to eliminate noise.
EliminateDeadCode(graph);
EliminateCommonSubexpression(graph);
return false;
if (mayIntroduceGradient(graph->block()))
return true;
- for (const Value * input : graph->inputs()) {
+ for (const Value* input : graph->inputs()) {
if (input->type()->requires_grad())
return true;
}
return false;
}
- void runTraced(Stack & stack) {
+ void runTraced(Stack& stack) {
const auto& state = tracer::getTracingState();
auto inputs = last(stack, num_inputs);
- auto input_values = fmap(inputs, [](const IValue & v) {
- return tracer::getNestedValueTrace(v);
- });
-
- ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
- // NB: we could just run the fallback in here and call it a day, but that would loose all
- // the control flow information we have in the graph. Thus, we run the fallback to
- // get the correct output values, but we will override the tracing states later.
+ auto input_values = fmap(
+ inputs, [](const IValue& v) { return tracer::getNestedValueTrace(v); });
+
+ ArgumentSpec spec(
+ autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
+ // NB: we could just run the fallback in here and call it a day, but that
+ // would lose all the control flow information we have in the graph. Thus,
+ // we run the fallback to get the correct output values, but we will
+ // override the tracing states later.
{
// No need to trace a script module.
ResourceGuard guard(tracer::pauseTracing());
auto local_graph = this->graph->copy();
setInputTypes(*local_graph, spec);
PropagateInputShapes(local_graph);
- auto output_values = inlineCallTo(*state->graph, *local_graph, input_values);
+ auto output_values =
+ inlineCallTo(*state->graph, *local_graph, input_values);
auto outputs = last(stack, num_outputs);
for (size_t i = 0; i < outputs.size(); ++i) {
}
}
- // The unoptimized starting graph. This field is effectively const, but we can't make it so
- // because Graph::copy() is not const (and making it const is not that easy at this point).
+ // The unoptimized starting graph. This field is effectively const, but we
+ // can't make it so because Graph::copy() is not const (and making it const is
+ // not that easy at this point).
std::shared_ptr<Graph> graph;
- // If false, we'll run the graph as we get it, without any optimizations. Useful
- // for debugging.
+ // If false, we'll run the graph as we get it, without any optimizations.
+ // Useful for debugging.
const bool optimize;
const size_t num_inputs;
- const size_t num_flat_inputs; // Number of inputs, assuming all tuples would be flattened.
+ const size_t num_flat_inputs; // Number of inputs, assuming all tuples would
+ // be flattened.
const size_t num_outputs;
- // Populated only when optimize is false (and in that case plan_cache will be unused).
- // The compiled version of graph.
+ // Populated only when optimize is false (and in that case plan_cache will be
+ // unused). The compiled version of graph.
ExecutionPlan fallback;
- // Mapping from argument configurations to optimized versions of the graph that are
- // specialized to the spec.
+ // Mapping from argument configurations to optimized versions of the graph
+ // that are specialized to the spec.
std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
- // GraphExecutors can be accessed from multiple threads, so this thread needs to be
- // held every time we access the fallback or plan_cache.
+ // GraphExecutors can be accessed from multiple threads, so this mutex needs
+ // to be held every time we access the fallback or plan_cache.
std::mutex compile_mutex;
// Some tunable parameters
};
GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
-: pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
+ : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
-void GraphExecutor::run(Stack & inputs) {
+void GraphExecutor::run(Stack& inputs) {
return pImpl->run(inputs);
}
return pImpl->debugDisableAutodiffSubgraphInlining();
}
-
-void runRequiredPasses(const std::shared_ptr<Graph>& g) {
+void runRequiredPasses(const std::shared_ptr<Graph>& g) {
specializeUndef(*g);
LowerGradOf(*g);
// implicit inserted expand nodes are not necessarily always valid
EliminateDeadCode(g);
}
-}}
+} // namespace jit
+} // namespace torch
#pragma once
-#include <memory>
+#include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/autodiff.h>
+#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/variable_tensor_list.h>
-#include <torch/csrc/jit/interpreter.h>
-#include <torch/csrc/jit/autodiff.h>
-#include <torch/csrc/jit/argument_spec.h>
+#include <memory>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct GraphExecutorState;
struct TORCH_API GraphExecutor {
GraphExecutor() = default;
GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
- void run(Stack & inputs);
+ void run(Stack& inputs);
explicit operator bool() const {
return pImpl != nullptr;
}
std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
GraphExecutorState getDebugState();
void debugDisableAutodiffSubgraphInlining();
-private:
+
+ private:
std::shared_ptr<GraphExecutorImpl> pImpl;
};
} // namespace detail
-
-}}
+} // namespace jit
+} // namespace torch
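A minimal sketch of driving the executor directly, assuming a Graph built elsewhere; run() consumes the inputs from the stack and leaves the outputs in their place:

    void runOnce(std::shared_ptr<Graph> graph, Stack& stack) {
      // optimize=true enables the specializing plan cache; optimize=false
      // runs the graph as received, which is useful for debugging.
      GraphExecutor executor(std::move(graph), /*optimize=*/true);
      executor.run(stack);
    }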
#include <torch/csrc/jit/assertions.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// Intrusive doubly linked lists with sane reverse iterators.
// The header file is named generic_graph_node_list.h because it is ONLY
using graph_node_list = generic_graph_node_list<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
-using const_graph_node_list_iterator = generic_graph_node_list_iterator<const Node>;
+using const_graph_node_list_iterator =
+ generic_graph_node_list_iterator<const Node>;
template <typename T>
struct generic_graph_node_list_iterator {
- generic_graph_node_list_iterator()
- : cur(nullptr), d(kNextDirection) {}
- generic_graph_node_list_iterator(T * cur, int d)
- : cur(cur), d(d) {}
- generic_graph_node_list_iterator(const generic_graph_node_list_iterator & rhs) = default;
- generic_graph_node_list_iterator(generic_graph_node_list_iterator && rhs) = default;
- generic_graph_node_list_iterator& operator=(const generic_graph_node_list_iterator & rhs) = default;
- generic_graph_node_list_iterator& operator=(generic_graph_node_list_iterator && rhs) = default;
- T * operator*() const { return cur; }
- T * operator->() const { return cur; }
- generic_graph_node_list_iterator & operator++() {
+ generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
+ generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {}
+ generic_graph_node_list_iterator(
+ const generic_graph_node_list_iterator& rhs) = default;
+ generic_graph_node_list_iterator(generic_graph_node_list_iterator&& rhs) =
+ default;
+ generic_graph_node_list_iterator& operator=(
+ const generic_graph_node_list_iterator& rhs) = default;
+ generic_graph_node_list_iterator& operator=(
+ generic_graph_node_list_iterator&& rhs) = default;
+ T* operator*() const {
+ return cur;
+ }
+ T* operator->() const {
+ return cur;
+ }
+ generic_graph_node_list_iterator& operator++() {
JIT_ASSERT(cur);
cur = cur->next_in_graph[d];
return *this;
++(*this);
return old;
}
- generic_graph_node_list_iterator & operator--() {
+ generic_graph_node_list_iterator& operator--() {
JIT_ASSERT(cur);
cur = cur->next_in_graph[reverseDir()];
return *this;
// silently cause the wrong one to be called.
// iterator will point to the previous entry after call
void destroyCurrent() {
- T * n = cur;
+ T* n = cur;
cur = cur->next_in_graph[reverseDir()];
n->destroy();
}
generic_graph_node_list_iterator reverse() {
return generic_graph_node_list_iterator(cur, reverseDir());
}
-private:
+
+ private:
int reverseDir() {
return d == kNextDirection ? kPrevDirection : kNextDirection;
}
- T * cur;
- int d; //direction 0 is forward 1 is reverse, see next_in_graph
+ T* cur;
+ int d; // direction 0 is forward 1 is reverse, see next_in_graph
};
template <typename T>
return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
}
generic_graph_node_list_iterator<T> end() {
- return generic_graph_node_list_iterator<T>(head,d);
+ return generic_graph_node_list_iterator<T>(head, d);
}
generic_graph_node_list_iterator<const T> end() const {
- return generic_graph_node_list_iterator<const T>(head,d);
+ return generic_graph_node_list_iterator<const T>(head, d);
}
generic_graph_node_list_iterator<T> rbegin() {
return reverse().begin();
return reverse().end();
}
generic_graph_node_list reverse() {
- return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
+ return generic_graph_node_list(
+ head, d == kNextDirection ? kPrevDirection : kNextDirection);
}
const generic_graph_node_list reverse() const {
- return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
- }
- T* front() { return head->next_in_graph[d]; }
- const T* front() const { return head->next_in_graph[d]; }
- T* back() { return head->next_in_graph[!d]; }
- const T* back() const { return head->next_in_graph[!d]; }
- generic_graph_node_list(T * head, int d)
- : head(head), d(d) {}
-private:
- T * head;
+ return generic_graph_node_list(
+ head, d == kNextDirection ? kPrevDirection : kNextDirection);
+ }
+ T* front() {
+ return head->next_in_graph[d];
+ }
+ const T* front() const {
+ return head->next_in_graph[d];
+ }
+ T* back() {
+ return head->next_in_graph[!d];
+ }
+ const T* back() const {
+ return head->next_in_graph[!d];
+ }
+ generic_graph_node_list(T* head, int d) : head(head), d(d) {}
+
+ private:
+ T* head;
int d;
};
template <typename T>
-static inline bool operator==(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
+static inline bool operator==(
+ generic_graph_node_list_iterator<T> a,
+ generic_graph_node_list_iterator<T> b) {
return *a == *b;
}
template <typename T>
-static inline bool operator!=(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
+static inline bool operator!=(
+ generic_graph_node_list_iterator<T> a,
+ generic_graph_node_list_iterator<T> b) {
return *a != *b;
}
-}}
+} // namespace jit
+} // namespace torch
namespace std {
-template<typename T>
+template <typename T>
struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> {
using difference_type = int64_t;
using value_type = T*;
using iterator_category = bidirectional_iterator_tag;
};
-}
+} // namespace std
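Typical traversals over such a list, as a sketch (assumes a Block* b from the IR headers; shouldRemove is a hypothetical predicate):

    for (Node* n : b->nodes()) {
      // visit nodes in program order
    }
    for (Node* n : b->nodes().reverse()) {
      // visit nodes in reverse program order
    }
    // destroyCurrent() steps the iterator back to the previous entry, so a
    // node can be deleted while iterating:
    for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
      if (shouldRemove(*it))
        it.destroyCurrent();
    }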
namespace torch {
namespace jit {
-static std::function<void(std::shared_ptr<script::Module> module)> emit_module_callback;
+static std::function<void(std::shared_ptr<script::Module> module)>
+ emit_module_callback;
TORCH_API void didFinishEmitModule(std::shared_ptr<script::Module> module) {
- if(emit_module_callback) {
+ if (emit_module_callback) {
emit_module_callback(std::move(module));
}
}
-TORCH_API void setEmitModuleHook(std::function<void(std::shared_ptr<script::Module> module)> cb) {
+TORCH_API void setEmitModuleHook(
+ std::function<void(std::shared_ptr<script::Module> module)> cb) {
emit_module_callback = std::move(cb);
}
} // namespace jit
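Installing the hook might look like the following sketch (the lambda body is illustrative):

    torch::jit::setEmitModuleHook(
        [](std::shared_ptr<torch::jit::script::Module> module) {
          // observe each module as the compiler finishes emitting it
        });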
#pragma once
-#include <functional>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <functional>
#include <memory>
namespace torch {
struct Module;
}
TORCH_API void didFinishEmitModule(std::shared_ptr<script::Module> module);
-TORCH_API void setEmitModuleHook(std::function<void(std::shared_ptr<script::Module> module)> cb);
+TORCH_API void setEmitModuleHook(
+ std::function<void(std::shared_ptr<script::Module> module)> cb);
} // namespace jit
} // namespace torch
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>
+#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/import.h>
+#include <torch/csrc/jit/import_method.h>
#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/utils/functional.h>
-#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/import_method.h>
-
+#include <torch/csrc/utils/functional.h>
#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
#include <ATen/ATen.h>
+#include <fstream>
+#include <string>
#include <unordered_map>
#include <vector>
-#include <string>
-#include <fstream>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
ScriptModuleDeserializer(std::istream* is);
- void deserialize(ModuleLookup module_lookup,
+ void deserialize(
+ ModuleLookup module_lookup,
c10::optional<at::Device> device);
-private:
- at::Tensor loadTensor(
- const torch::TensorDef& tensor_proto,
- std::unordered_map<std::string, at::Storage>& storageMap);
+ private:
+ at::Tensor loadTensor(
+ const torch::TensorDef& tensor_proto,
+ std::unordered_map<std::string, at::Storage>& storageMap);
- void convertModule(const torch::ModuleDef& module_def);
+ void convertModule(const torch::ModuleDef& module_def);
- void loadTensorTable(torch::ModelDef* model_def);
+ void loadTensorTable(torch::ModelDef* model_def);
- PyTorchStreamReader reader_;
- // this is a hack to make sure the script module created in C++ is the
- // same as created in Python
- ModuleLookup moduleLookup_;
- c10::optional<at::Device> device_;
- std::vector<std::string> moduleStack_;
+ PyTorchStreamReader reader_;
+ // this is a hack to make sure the script module created in C++ is the
+ // same as created in Python
+ ModuleLookup moduleLookup_;
+ c10::optional<at::Device> device_;
+ std::vector<std::string> moduleStack_;
- std::vector<at::Tensor> tensor_table_;
+ std::vector<at::Tensor> tensor_table_;
};
ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
: reader_(is) {}
-void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup,
+void ScriptModuleDeserializer::deserialize(
+ ModuleLookup module_lookup,
c10::optional<at::Device> device) {
torch::ModelDef model_def;
at::DataPtr data_ptr;
void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
std::unordered_map<std::string, at::Storage> storageMap;
- for(const torch::TensorDef& tensor : model_def->tensors()) {
+ for (const torch::TensorDef& tensor : model_def->tensors()) {
tensor_table_.emplace_back(loadTensor(tensor, storageMap));
}
}
-at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_proto,
- std::unordered_map<std::string, at::Storage>& storageMap) {
- std::vector<int64_t> dims(tensor_proto.dims().begin(), tensor_proto.dims().end());
- std::vector<int64_t> strides(tensor_proto.strides().begin(), tensor_proto.strides().end());
+at::Tensor ScriptModuleDeserializer::loadTensor(
+ const torch::TensorDef& tensor_proto,
+ std::unordered_map<std::string, at::Storage>& storageMap) {
+ std::vector<int64_t> dims(
+ tensor_proto.dims().begin(), tensor_proto.dims().end());
+ std::vector<int64_t> strides(
+ tensor_proto.strides().begin(), tensor_proto.strides().end());
auto type = at::typeMetaToScalarType(
caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
const std::string& record_key = tensor_proto.data().key();
record_size / at::CPU(type).typeMeta().itemsize(),
nullptr); // NB: we didn't set any allocator for the tensor
if (device.type() == at::DeviceType::CPU) {
- storage_it = storageMap.insert(std::make_pair(
- record_key, cpu_storage)).first;
+ storage_it =
+ storageMap.insert(std::make_pair(record_key, cpu_storage)).first;
} else if (device.type() == at::DeviceType::CUDA) {
- at::Tensor cpu_tensor = at::empty({0}, at::CPU(type).options()).set_(
- cpu_storage, tensor_proto.offset(), dims, strides);
- at::Storage cuda_storage = cpu_tensor.to(device,
- cpu_tensor.scalar_type()).storage();
- storage_it = storageMap.insert(std::make_pair(
- record_key, cuda_storage)).first;
+ at::Tensor cpu_tensor =
+ at::empty({0}, at::CPU(type).options())
+ .set_(cpu_storage, tensor_proto.offset(), dims, strides);
+ at::Storage cuda_storage =
+ cpu_tensor.to(device, cpu_tensor.scalar_type()).storage();
+ storage_it =
+ storageMap.insert(std::make_pair(record_key, cuda_storage)).first;
} else {
- AT_ERROR("supported devices include CPU and CUDA, however got ",
+ AT_ERROR(
+ "supported devices include CPU and CUDA, however got ",
at::DeviceTypeName(device.type(), false));
}
}
storage_it->second.device().index() != device.index())) {
std::stringstream oss;
oss << "storage previously was specified with device "
- << storage_it->second.device()
- << "but now is specified with device "
- << device << std::endl;
+ << storage_it->second.device() << " but now is specified with device "
+ << device << std::endl;
AT_ERROR(oss.str());
}
at::Tensor result;
if (device.type() == at::DeviceType::CPU) {
- result = at::empty({0}, at::CPU(type).options()).set_(
- storage_it->second, tensor_proto.offset(), dims, strides);
+ result =
+ at::empty({0}, at::CPU(type).options())
+ .set_(storage_it->second, tensor_proto.offset(), dims, strides);
} else if (device.type() == at::DeviceType::CUDA) {
- result = at::empty({0}, at::CUDA(type).options()).set_(
- storage_it->second, tensor_proto.offset(), dims, strides);
+ result =
+ at::empty({0}, at::CUDA(type).options())
+ .set_(storage_it->second, tensor_proto.offset(), dims, strides);
}
AT_ASSERT(result.defined());
for (int i = 0; i < module_def.parameters_size(); ++i) {
const torch::ParameterDef& param_def = module_def.parameters(i);
at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
- module->register_parameter(
- param_def.name(), tensor, param_def.is_buffer());
+ module->register_parameter(param_def.name(), tensor, param_def.is_buffer());
}
if (module_def.has_torchscript_arena()) {
at::DataPtr data;
size_t size;
- std::tie(data, size) = reader_.getRecord(module_def.torchscript_arena().key());
+ std::tie(data, size) =
+ reader_.getRecord(module_def.torchscript_arena().key());
std::string data_str(static_cast<const char*>(data.get()), size);
import_methods(module, data_str, tensor_table_);
}
}
-} // namespace
+} // namespace
void import_ir_module(
ModuleLookup module_lookup,
deserializer.deserialize(module_lookup, device);
}
-std::shared_ptr<script::Module> load(std::istream& in,
+std::shared_ptr<script::Module> load(
+ std::istream& in,
c10::optional<at::Device> device) {
auto module = std::make_shared<script::Module>();
return module;
}
-std::shared_ptr<script::Module> load(const std::string& filename,
+std::shared_ptr<script::Module> load(
+ const std::string& filename,
c10::optional<at::Device> device) {
std::ifstream in(filename, std::ios_base::binary);
- AT_CHECK(! in.fail(), "load: could not open file ", filename);
+ AT_CHECK(!in.fail(), "load: could not open file ", filename);
auto module = load(in, device);
return module;
}
-}}
+} // namespace jit
+} // namespace torch
///
/// The istream must contain a serialized `script::Module`, exported via
/// `torch::jit::ExportModule` in C++.
-TORCH_API std::shared_ptr<script::Module> load(std::istream& in,
+TORCH_API std::shared_ptr<script::Module> load(
+ std::istream& in,
c10::optional<c10::Device> device = c10::nullopt);
/// Loads a serialized `script::Module` from the given `filename`.
/// The file stored at the location given in `filename` must contain a
/// serialized `script::Module`, exported either via `ScriptModule.save()` in
/// Python or `torch::jit::ExportModule` in C++.
-TORCH_API std::shared_ptr<script::Module> load(const std::string& filename,
+TORCH_API std::shared_ptr<script::Module> load(
+ const std::string& filename,
c10::optional<c10::Device> device = c10::nullopt);
} // namespace jit
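Typical call sites for the two overloads, as a brief sketch ("model.pt" is an illustrative path):

    auto module = torch::jit::load("model.pt");

    std::ifstream in("model.pt", std::ios_base::binary);
    auto module_from_stream = torch::jit::load(in);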
#include <torch/csrc/jit/import_method.h>
#include <torch/csrc/jit/script/parser.h>
-namespace torch { namespace jit {
-
+namespace torch {
+namespace jit {
// this is a much simpler accessor that only handles modules, parameters,
// and methods. It does not depend on python to work.
struct ModuleAccessorValue : public script::SugaredValue {
ModuleAccessorValue(std::shared_ptr<script::Module> module)
- : module(std::move(module)) {}
+ : module(std::move(module)) {}
std::string kind() const override {
return "module";
}
// select an attribute on it, e.g. `this.field`
- std::shared_ptr<SugaredValue> attr(const SourceRange& loc, script::Method & m, const std::string& field) override {
- if(script::NamedModule* v = module->find_module(field)) {
+ std::shared_ptr<SugaredValue> attr(
+ const SourceRange& loc,
+ script::Method& m,
+ const std::string& field) override {
+ if (script::NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleAccessorValue>(v->module);
- } else if(script::NamedParameter* v = module->find_parameter(field)) {
- return std::make_shared<script::SimpleValue>(m.get_or_add_parameter(v->slot()));
- } else if(script::Method* m = module->find_method(field)) {
+ } else if (script::NamedParameter* v = module->find_parameter(field)) {
+ return std::make_shared<script::SimpleValue>(
+ m.get_or_add_parameter(v->slot()));
+ } else if (script::Method* m = module->find_method(field)) {
return std::make_shared<script::MethodValue>(module, *m);
} else {
throw script::ErrorReport(loc) << "unknown attr: " << field;
}
}
-private:
+
+ private:
std::shared_ptr<script::Module> module;
};
struct OpsValue : public script::SugaredValue {
- OpsValue(size_t version)
- : version_(version) {}
+ OpsValue(size_t version) : version_(version) {}
std::string kind() const override {
return "ops";
}
- std::shared_ptr<SugaredValue> attr(const SourceRange& loc, script::Method & m, const std::string& field) override {
+ std::shared_ptr<SugaredValue> attr(
+ const SourceRange& loc,
+ script::Method& m,
+ const std::string& field) override {
return std::make_shared<script::BuiltinModule>(field, version_);
}
size_t version_;
};
struct ConstantValue : public script::SugaredValue {
- ConstantValue(IValue value)
- : value_(std::move(value)) {}
+ ConstantValue(IValue value) : value_(std::move(value)) {}
IValue value_;
- std::string kind() const override { return "constant"; }
- Value * asValue(const SourceRange& loc, script::Method & m) override {
+ std::string kind() const override {
+ return "constant";
+ }
+ Value* asValue(const SourceRange& loc, script::Method& m) override {
return m.graph()->insertConstant(value_);
}
};
// in the 'constants' vector. This table will be stored in a container format
// and given to the import_method when restoring the code.
struct ConstantTableValue : public script::SugaredValue {
- ConstantTableValue(ArrayRef<at::Tensor> constants)
- : constants_(constants) {}
+ ConstantTableValue(ArrayRef<at::Tensor> constants) : constants_(constants) {}
std::string kind() const override {
return "CONSTANTS";
}
// select an attribute on it, e.g. `this.field`
- std::shared_ptr<SugaredValue> attr(const SourceRange& loc, script::Method & m, const std::string& field) override {
+ std::shared_ptr<SugaredValue> attr(
+ const SourceRange& loc,
+ script::Method& m,
+ const std::string& field) override {
const char* field_s = field.c_str();
char* end;
int64_t offset = std::strtoll(field_s + 1, &end, 10);
- if(field.size() < 2 || *end != 0)
+ if (field.size() < 2 || *end != 0)
throw script::ErrorReport(loc) << "invalid constant specifier: " << field;
if (offset < 0 || size_t(offset) >= constants_.size()) {
throw script::ErrorReport(loc) << "constant index " << offset
}
private:
- ArrayRef<at::Tensor> constants_;
+ ArrayRef<at::Tensor> constants_;
};
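For example, serialized source refers to entries of this table as CONSTANTS.c0, CONSTANTS.c1, and so on; attr() skips the leading character and parses the remainder as the table index, which is exactly what the strtoll call above implements.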
static size_t parseVersionNumber(script::Lexer& L) {
L.expect(script::TK_NEWLINE);
auto version = script::Const::create(L.cur().range, version_text);
if (name != "op_version_set")
- throw script::ErrorReport(range) << "expected an assignment to op_version_set";
+ throw script::ErrorReport(range)
+ << "expected an assignment to op_version_set";
if (!version.isIntegral())
- throw script::ErrorReport(range) << "expected an integral version but found " << version.text();
- return size_t(version.asIntegral());
+ throw script::ErrorReport(range)
+ << "expected an integral version but found " << version.text();
+ return size_t(version.asIntegral());
}
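In other words, a serialized method body is expected to begin with a line of the form op_version_set = <integer>; any other name or a non-integral value triggers the errors above, and the parsed number selects the operator version used by the resolver environment below.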
-void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table) {
+void import_methods(
+ const std::shared_ptr<script::Module>& mod,
+ const std::string& src,
+ const std::vector<at::Tensor>& constant_table) {
script::Parser p(src);
size_t version = parseVersionNumber(p.lexer());
std::unordered_map<std::string, std::shared_ptr<script::SugaredValue>> env = {
- {"torch", std::make_shared<script::BuiltinModule>("aten", version)},
- {"ops", std::make_shared<OpsValue>(version)},
- {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
- {"fork", std::make_shared<script::ForkValue>()},
- {"annotate", std::make_shared<script::AnnotateValue>()},
- {"inf", std::make_shared<ConstantValue>(std::numeric_limits<double>::infinity())},
- {"nan", std::make_shared<ConstantValue>(std::numeric_limits<double>::quiet_NaN())},
+ {"torch", std::make_shared<script::BuiltinModule>("aten", version)},
+ {"ops", std::make_shared<OpsValue>(version)},
+ {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
+ {"fork", std::make_shared<script::ForkValue>()},
+ {"annotate", std::make_shared<script::AnnotateValue>()},
+ {"inf",
+ std::make_shared<ConstantValue>(
+ std::numeric_limits<double>::infinity())},
+ {"nan",
+ std::make_shared<ConstantValue>(
+ std::numeric_limits<double>::quiet_NaN())},
};
- auto resolver = [&](const std::string& name, script::Method& m, const SourceRange& loc)
- -> std::shared_ptr<script::SugaredValue> {
+ auto resolver =
+ [&](const std::string& name,
+ script::Method& m,
+ const SourceRange& loc) -> std::shared_ptr<script::SugaredValue> {
auto it = env.find(name);
if (it == env.end())
return nullptr;
script::defineMethodsInModule(mod, definitions, resolvers, self);
}
-}}
+} // namespace jit
+} // namespace torch
#pragma once
#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/module.h>
namespace torch {
namespace jit {
-TORCH_API void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table);
+TORCH_API void import_methods(
+ const std::shared_ptr<script::Module>& mod,
+ const std::string& src,
+ const std::vector<at::Tensor>& constant_table);
} // namespace jit
} // namespace torch
-#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/auto_gil.h>
+#include <torch/csrc/utils/pybind.h>
-#include <torch/csrc/jit/python_tracer.h>
-#include <torch/csrc/jit/tracer.h>
-#include <torch/csrc/jit/python_ir.h>
-#include <torch/csrc/jit/python_arg_flatten.h>
-#include <torch/csrc/jit/export.h>
-#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/batched/BatchTensor.h>
+#include <torch/csrc/jit/export.h>
+#include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
-#include <torch/csrc/jit/passes/remove_expands.h>
-#include <torch/csrc/jit/passes/graph_fuser.h>
-#include <torch/csrc/jit/passes/onnx.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
-#include <torch/csrc/jit/passes/erase_number_types.h>
-#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
-#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
-#include <torch/csrc/jit/passes/constant_pooling.h>
-#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
-#include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/graph_executor.h>
+#include <torch/csrc/jit/import.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/canonicalize.h>
-#include <torch/csrc/jit/passes/onnx/peephole.h>
-#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/canonicalize_ops.h>
-#include <torch/csrc/jit/passes/remove_inplace_ops.h>
+#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/erase_number_types.h>
+#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
-#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/passes/onnx.h>
+#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
+#include <torch/csrc/jit/passes/onnx/peephole.h>
+#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
+#include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/passes/remove_expands.h>
+#include <torch/csrc/jit/passes/remove_inplace_ops.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_undef.h>
+#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
-#include <torch/csrc/jit/graph_executor.h>
-#include <torch/csrc/jit/script/init.h>
-#include <torch/csrc/jit/script/python_tree_views.h>
-#include <torch/csrc/jit/batched/BatchTensor.h>
#include <torch/csrc/jit/pybind_utils.h>
-#include <torch/csrc/jit/function_schema.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/fuser/interface.h>
-#include <torch/csrc/jit/script/jit_exception.h>
+#include <torch/csrc/jit/python_arg_flatten.h>
+#include <torch/csrc/jit/python_ir.h>
+#include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/script/init.h>
#include <torch/csrc/jit/script/jit_exception.h>
+#include <torch/csrc/jit/script/python_tree_views.h>
+#include <torch/csrc/jit/tracer.h>
#include <caffe2/serialize/inline_container.h>
#include <tuple>
#include <utility>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// TODO: make a fake future for python
namespace detail {
-class Future {
-
-};
-}
+class Future {};
+} // namespace detail
namespace {
bool loadPythonClasses() {
// Leaving this code here, because it will likely be useful at some point
- //PyObject *jit_module = PyImport_ImportModule("torch.jit");
- //THPUtils_assert(jit_module, "class loader couldn't access "
- //"torch.jit module");
- //PyObject *jit_dict = PyModule_GetDict(jit_module);
+ // PyObject *jit_module = PyImport_ImportModule("torch.jit");
+ // THPUtils_assert(jit_module, "class loader couldn't access "
+ // "torch.jit module");
+ // PyObject *jit_dict = PyModule_GetDict(jit_module);
return true;
}
std::string runJITCPPTests();
#endif
-void initJITBindings(PyObject *module) {
+void initJITBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::register_exception<JITException>(m, "JITException");
- py::class_<python::IODescriptor>(m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
+ py::class_<python::IODescriptor>(
+ m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
m.def("_jit_init", loadPythonClasses)
#if USE_CUDA_FUSER || USE_CPU_FUSER
- .def("_jit_debug_fuser_num_cached_kernel_specs",
- torch::jit::fuser::debugNumCachedKernelSpecs)
+ .def(
+ "_jit_debug_fuser_num_cached_kernel_specs",
+ torch::jit::fuser::debugNumCachedKernelSpecs)
#endif
- .def("_jit_pass_onnx", ToONNX)
- .def("_jit_pass_lower_all_tuples", LowerAllTuples)
- .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
- .def("_jit_pass_fuse", FuseGraph)
- .def("_jit_pass_dce", [](std::shared_ptr<Graph>& g) {
- return EliminateDeadCode(g->block()); // overload resolution
- })
- .def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
- return EliminateCommonSubexpression(g); // overload resolution
- })
- .def("_jit_pass_remove_inplace_ops", [](std::shared_ptr<Graph> g) {
- return RemoveInplaceOps(g);
- })
- .def("_jit_pass_constant_pooling", ConstantPooling)
- .def("_jit_pass_peephole", [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
- return PeepholeOptimize(g, addmm_fusion_enabled);
- }, py::arg("graph"), py::arg("addmm_fusion_enabled") = false)
- .def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
- return Canonicalize(g);
- })
- .def("_jit_pass_lint", LintGraph)
- .def("_jit_pass_shape_analysis", [](std::shared_ptr<Graph> graph, std::vector<at::Tensor> inputs, bool with_grad) {
- setInputTypes(*graph, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
- PropagateInputShapes(graph);
- })
- .def("_jit_pass_complete_shape_analysis", [](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
- CompleteArgumentSpec spec(with_grad, evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs()));
- auto graph_inputs = graph->inputs();
- JIT_ASSERT(spec.size() == graph_inputs.size());
- for (size_t i = 0; i < graph_inputs.size(); ++i) {
- graph_inputs[i]->setType(spec.at(i));
- }
- PropagateInputShapes(graph);
- })
- .def("_jit_pass_remove_expands", RemoveExpands)
- .def("_jit_pass_erase_number_types", EraseNumberTypes)
- .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX)
- .def("_jit_pass_loop_unrolling", UnrollLoops)
- .def("_jit_pass_constant_propagation", [](std::shared_ptr<Graph>& g) {
- return ConstantPropagation(g);
- })
- .def("_jit_pass_erase_shape_information", EraseShapeInformation)
- .def("_jit_pass_create_autodiff_subgraphs", [](std::shared_ptr<Graph> graph) {
- CreateAutodiffSubgraphs(graph);
- })
- .def("_jit_run_cpp_tests", [] {
- // We have to release the GIL inside this method, because if we happen to
- // initialize the autograd engine in these tests, the newly spawned worker threads will
- // try to initialize their PyThreadState*, and they need the GIL for this.
- AutoNoGIL _no_gil;
- return runJITCPPTests();
- })
- .def("_jit_flatten", [](py::handle& obj) {
- auto res = python::flatten(obj);
- return std::make_pair(res.vars, res.desc);
- })
- .def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) {
- return py::reinterpret_steal<py::object>(python::unflatten(vars, desc));
- })
- .def("_jit_pass_onnx_block", BlockToONNX)
- .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
- .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
- .def("_jit_pass_specialize_undef", specializeUndef)
- .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
- .def("_jit_differentiate", [](Graph &g) {
- // the python binding slightly differs in semantics
- // it makes a copy of the input Graph, and works on that
- // jit::differentiate mutates the input Graph
- auto g_clone = g.copy();
- return differentiate(g_clone);
- })
- .def("_jit_check_alias_annotation", [](
- std::shared_ptr<Graph> g,
- py::tuple args,
- const std::string& unqualified_op_name) {
- auto stack = toStack(args);
- checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
- });
+ .def("_jit_pass_onnx", ToONNX)
+ .def("_jit_pass_lower_all_tuples", LowerAllTuples)
+ .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
+ .def("_jit_pass_fuse", FuseGraph)
+ .def(
+ "_jit_pass_dce",
+ [](std::shared_ptr<Graph>& g) {
+ return EliminateDeadCode(g->block()); // overload resolution
+ })
+ .def(
+ "_jit_pass_cse",
+ [](std::shared_ptr<Graph>& g) {
+ return EliminateCommonSubexpression(g); // overload resolution
+ })
+ .def(
+ "_jit_pass_remove_inplace_ops",
+ [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
+ .def("_jit_pass_constant_pooling", ConstantPooling)
+ .def(
+ "_jit_pass_peephole",
+ [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
+ return PeepholeOptimize(g, addmm_fusion_enabled);
+ },
+ py::arg("graph"),
+ py::arg("addmm_fusion_enabled") = false)
+ .def(
+ "_jit_pass_canonicalize",
+ [](const std::shared_ptr<Graph>& g) { return Canonicalize(g); })
+ .def("_jit_pass_lint", LintGraph)
+ .def(
+ "_jit_pass_shape_analysis",
+ [](std::shared_ptr<Graph> graph,
+ std::vector<at::Tensor> inputs,
+ bool with_grad) {
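+ // seed the graph's input types from the example inputs, then
+ // propagate shapes through the rest of the graph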
+ setInputTypes(
+ *graph,
+ ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
+ PropagateInputShapes(graph);
+ })
+ .def(
+ "_jit_pass_complete_shape_analysis",
+ [](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
+ CompleteArgumentSpec spec(
+ with_grad,
+ evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs()));
+ auto graph_inputs = graph->inputs();
+ JIT_ASSERT(spec.size() == graph_inputs.size());
+ for (size_t i = 0; i < graph_inputs.size(); ++i) {
+ graph_inputs[i]->setType(spec.at(i));
+ }
+ PropagateInputShapes(graph);
+ })
+ .def("_jit_pass_remove_expands", RemoveExpands)
+ .def("_jit_pass_erase_number_types", EraseNumberTypes)
+ .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX)
+ .def("_jit_pass_loop_unrolling", UnrollLoops)
+ .def(
+ "_jit_pass_constant_propagation",
+ [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); })
+ .def("_jit_pass_erase_shape_information", EraseShapeInformation)
+ .def(
+ "_jit_pass_create_autodiff_subgraphs",
+ [](std::shared_ptr<Graph> graph) { CreateAutodiffSubgraphs(graph); })
+ .def(
+ "_jit_run_cpp_tests",
+ [] {
+ // We have to release the GIL inside this method, because if we
+ // happen to initialize the autograd engine in these tests, the
+ // newly spawned worker threads will try to initialize their
+ // PyThreadState*, and they need the GIL for this.
+ AutoNoGIL _no_gil;
+ return runJITCPPTests();
+ })
+ .def(
+ "_jit_flatten",
+ [](py::handle& obj) {
+ auto res = python::flatten(obj);
+ return std::make_pair(res.vars, res.desc);
+ })
+ .def(
+ "_jit_unflatten",
+ [](autograd::variable_list vars, python::IODescriptor& desc) {
+ return py::reinterpret_steal<py::object>(
+ python::unflatten(vars, desc));
+ })
+ .def("_jit_pass_onnx_block", BlockToONNX)
+ .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
+ .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
+ .def("_jit_pass_specialize_undef", specializeUndef)
+ .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
+ .def(
+ "_jit_differentiate",
+ [](Graph& g) {
+ // the python binding slightly differs in semantics
+ // it makes a copy of the input Graph, and works on that
+ // jit::differentiate mutates the input Graph
+ auto g_clone = g.copy();
+ return differentiate(g_clone);
+ })
+ .def(
+ "_jit_check_alias_annotation",
+ [](std::shared_ptr<Graph> g,
+ py::tuple args,
+ const std::string& unqualified_op_name) {
+ auto stack = toStack(args);
+ checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
+ });
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
});
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<ArgumentSpec>(m, "ArgumentSpec");
- py::class_<Code>(m, "Code")
- .def("grad_executors", [](Code& c) {
- return py::make_iterator(c.grad_executors().begin(), c.grad_executors().end());
- });
+ py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
+ return py::make_iterator(
+ c.grad_executors().begin(), c.grad_executors().end());
+ });
py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
- .def_property_readonly("graph", [](ExecutionPlanState& s) {
- return s.graph;
- })
- .def_property_readonly("code", [](ExecutionPlanState& s) {
- return s.code;
- });
+ .def_property_readonly(
+ "graph", [](ExecutionPlanState& s) { return s.graph; })
+ .def_property_readonly(
+ "code", [](ExecutionPlanState& s) { return s.code; });
py::class_<Gradient>(m, "Gradient")
- .def_property_readonly("f", [](Gradient& m) {
- return m.f;
- })
- .def_property_readonly("df", [](Gradient& m) {
- return m.df;
- })
- .def_property_readonly("f_real_outputs", [](Gradient& m) {
- return m.f_real_outputs;
- })
- .def_property_readonly("df_input_vjps", [](Gradient& m) {
- return m.df_input_vjps;
- })
- .def_property_readonly("df_input_captured_inputs", [](Gradient& m) {
- return m.df_input_captured_inputs;
- })
- .def_property_readonly("df_input_captured_outputs", [](Gradient& m) {
- return m.df_input_captured_outputs;
- })
- .def_property_readonly("df_output_vjps", [](Gradient& m) {
- return m.df_output_vjps;
- });
+ .def_property_readonly("f", [](Gradient& m) { return m.f; })
+ .def_property_readonly("df", [](Gradient& m) { return m.df; })
+ .def_property_readonly(
+ "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
+ .def_property_readonly(
+ "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
+ .def_property_readonly(
+ "df_input_captured_inputs",
+ [](Gradient& m) { return m.df_input_captured_inputs; })
+ .def_property_readonly(
+ "df_input_captured_outputs",
+ [](Gradient& m) { return m.df_input_captured_outputs; })
+ .def_property_readonly(
+ "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });
py::class_<GraphExecutorState>(m, "GraphExecutorState")
- .def_property_readonly("graph", [](GraphExecutorState& s) {
- return s.graph;
- })
- .def_property_readonly("execution_plans", [](GraphExecutorState& s) {
- return s.execution_plans;
- })
- .def_property_readonly("fallback", [](GraphExecutorState& s) {
- return s.fallback;
- });
+ .def_property_readonly(
+ "graph", [](GraphExecutorState& s) { return s.graph; })
+ .def_property_readonly(
+ "execution_plans",
+ [](GraphExecutorState& s) { return s.execution_plans; })
+ .def_property_readonly(
+ "fallback", [](GraphExecutorState& s) { return s.fallback; });
py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
.def(
"get_debug_state",
[](GraphExecutor& ge) { return ge.getDebugState(); })
.def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
- const auto & graph = ge.graph();
- auto stack = evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
+ const auto& graph = ge.graph();
+ auto stack =
+ evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
{
AutoNoGIL no_gil_guard;
ge.run(stack);
.def(py::init<std::string>())
.def(
"write_record",
- [](PyTorchStreamWriter& self, const std::string& name, const char* data, size_t size) {
- return self.writeRecord(name, data, size);
- })
+ [](PyTorchStreamWriter& self,
+ const std::string& name,
+ const char* data,
+ size_t size) { return self.writeRecord(name, data, size); })
.def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile);
py::class_<PyTorchStreamReader>(m, "PyTorchFileReader")
.def(py::init<std::string>())
- .def(
- "get_record",
- [](PyTorchStreamReader& self, const std::string& key) {
- at::DataPtr data;
- size_t size;
- std::tie(data, size) = self.getRecord(key);
- return py::bytes(reinterpret_cast<const char*>(data.get()), size);
- });
-
+ .def("get_record", [](PyTorchStreamReader& self, const std::string& key) {
+ at::DataPtr data;
+ size_t size;
+ std::tie(data, size) = self.getRecord(key);
+ return py::bytes(reinterpret_cast<const char*>(data.get()), size);
+ });
- m.def("_jit_get_operation", [](const std::string& qualified_name) {
- try {
- auto symbol = Symbol::fromQualString(qualified_name);
- auto operations = getAllOperatorsFor(symbol);
- AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
- AT_CHECK(
- operations.size() == 1,
- "Found ", operations.size(), " overloads for operator ",
- qualified_name, "! Overloads are not supported from Python.");
- std::shared_ptr<Operator> op = operations[0];
- AT_ASSERT(op != nullptr);
- std::ostringstream docstring;
- docstring << "Automatically bound operator '" << qualified_name
- << "' with schema: " << op->schema();
- return py::cpp_function([op](py::args args, py::kwargs kwargs) {
- return invokeOperatorFromPython(
- *op, std::move(args), std::move(kwargs));
- }, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str()));
- } catch (const c10::Error& error) {
- throw std::runtime_error(error.what_without_backtrace());
- }
- }, py::arg("qualified_name"));
+ m.def(
+ "_jit_get_operation",
+ [](const std::string& qualified_name) {
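+ // look up the operator by qualified name and bind its single
+ // overload as a Python callable (multiple overloads are rejected)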
+ try {
+ auto symbol = Symbol::fromQualString(qualified_name);
+ auto operations = getAllOperatorsFor(symbol);
+ AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
+ AT_CHECK(
+ operations.size() == 1,
+ "Found ",
+ operations.size(),
+ " overloads for operator ",
+ qualified_name,
+ "! Overloads are not supported from Python.");
+ std::shared_ptr<Operator> op = operations[0];
+ AT_ASSERT(op != nullptr);
+ std::ostringstream docstring;
+ docstring << "Automatically bound operator '" << qualified_name
+ << "' with schema: " << op->schema();
+ return py::cpp_function(
+ [op](py::args args, py::kwargs kwargs) {
+ return invokeOperatorFromPython(
+ *op, std::move(args), std::move(kwargs));
+ },
+ py::name(qualified_name.c_str()),
+ py::doc(docstring.str().c_str()));
+ } catch (const c10::Error& error) {
+ throw std::runtime_error(error.what_without_backtrace());
+ }
+ },
+ py::arg("qualified_name"));
py::class_<FunctionSchema>(m, "FunctionSchema")
- .def_property_readonly("name", [](FunctionSchema& self) { return self.name(); })
- .def_property_readonly("arguments", [](FunctionSchema& self) { return self.arguments(); })
- .def_property_readonly("returns", [](FunctionSchema& self) { return self.returns(); });
+ .def_property_readonly(
+ "name", [](FunctionSchema& self) { return self.name(); })
+ .def_property_readonly(
+ "arguments", [](FunctionSchema& self) { return self.arguments(); })
+ .def_property_readonly(
+ "returns", [](FunctionSchema& self) { return self.returns(); });
py::class_<Argument>(m, "Argument")
- .def_property_readonly("name", [](Argument& self) { return self.name(); })
- .def_property_readonly("type", [](Argument& self) { return self.type(); })
- .def_property_readonly("N", [](Argument& self) -> py::object {
- return (self.N()) ? py::cast(*self.N()) : py::none();
- })
- .def_property_readonly("default_value", [](Argument& self) -> py::object {
- if(!self.default_value())
- return py::none();
- IValue v = *self.default_value();
- return toPyObject(std::move(v));
- });
+ .def_property_readonly("name", [](Argument& self) { return self.name(); })
+ .def_property_readonly("type", [](Argument& self) { return self.type(); })
+ .def_property_readonly(
+ "N",
+ [](Argument& self) -> py::object {
+ return (self.N()) ? py::cast(*self.N()) : py::none();
+ })
+ .def_property_readonly("default_value", [](Argument& self) -> py::object {
+ if (!self.default_value())
+ return py::none();
+ IValue v = *self.default_value();
+ return toPyObject(std::move(v));
+ });
m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
auto symbol = Symbol::fromQualString(qualified_name);
auto operations = getAllOperatorsFor(symbol);
return fmap(operations, [](const std::shared_ptr<Operator>& op) {
- return op->schema();
- });
+ return op->schema();
+ });
});
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<detail::Future>(m, "Future");
- m.def("fork", [](script::Module &sm, py::args args) {
+ m.def("fork", [](script::Module& sm, py::args args) {
// TODO: this is a fake stub
return detail::Future();
});
- m.def("wait", [](detail::Future &fut) {
+ m.def("wait", [](detail::Future& fut) {
// TODO: this is a fake stub
});
initRegisterBatchOpsBindings(module);
}
-}}
+} // namespace jit
+} // namespace torch
#pragma once
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-void initJITBindings(PyObject *module);
+void initJITBindings(PyObject* module);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/ivalue.h>
-#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <utility>
#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// Before we translate to interpreter instructions, we do
// some preprocessing of the graph to turn it into a form that is closer
// *. computes move_flags (see Outputs), and inserts
// * Drop nodes are inserted for any node that is unused to create a dummy use
// that will cause the interpreter to free the node.
-// A drop node is just a node with no outputs that just pops its inputs off the stack,
-// to ensure the interpreter release references to nodes that are never used.
-// Drop nodes are also inserted when the last use of a node is in some conditionally
-// run control flow (e.g. one side of an If) and the interpreter must free
-// the node only after the control flow has reconverged
+// A drop node is a node with no outputs that simply pops its inputs off the
+// stack, ensuring the interpreter releases references to values that are
+// never used. Drop nodes are also inserted when the last use of a node is in
+// some conditionally run control flow (e.g. one side of an If) and the
+// interpreter must free the node only after the control flow has reconverged.
// Outputs are:
// * graph - the post processed copy of g
// * move_flags[n] - a list of booleans, one for each input,
// Emit initial comparison -- initial_trip_count < max_trip_count
Value* initial_comparison_value =
g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1))
- ->output()->setType(BoolType::get());
+ ->output()
+ ->setType(BoolType::get());
// Replace initial condition with logical `and` of trip count and
// initial condition
Value* new_cond =
g->insertNode(
g->create(aten::__and__, {initial_comparison_value, cond}, 1))
- ->output()->setType(BoolType::get());
+ ->output()
+ ->setType(BoolType::get());
return new_cond;
}
// this currently just _removes_ the trip count inputs and checks they are
// unused. In the future they will be desugared into normal arithmetic to
// provide a loop counter
-void desugarTripCounts(Block * b) {
- for(auto n : b->nodes()) {
-
- if(n->kind() == prim::Loop) {
+void desugarTripCounts(Block* b) {
+ for (auto n : b->nodes()) {
+ if (n->kind() == prim::Loop) {
auto g = n->owningGraph();
auto body_block = n->blocks()[0];
Value* const_one = g->insertConstant(1);
Value* inc_trip_count =
- g->insertNode(g->create(
- aten::add, {block_trip_count_input, const_one}, 1))
- ->output()->setType(IntType::get());
+ g->insertNode(
+ g->create(aten::add, {block_trip_count_input, const_one}, 1))
+ ->output()
+ ->setType(IntType::get());
body_block->insertOutput(1, inc_trip_count);
Value* body_cond = createTripCountConjunctiveCondition(
body_block->insertOutput(0, body_cond);
}
}
- for(auto sb : n->blocks()) {
+ for (auto sb : n->blocks()) {
desugarTripCounts(sb);
}
}
}
-// removes all inputs and outputs to a graph, replacing them with Load Store nodes
-static void flattenIO(Graph & graph) {
+// removes all inputs and outputs of a graph, replacing them with Load/Store
+// nodes
+static void flattenIO(Graph& graph) {
auto load = graph.prependNode(graph.create(prim::Load, 0));
- for(auto old_input : graph.inputs()) {
+ for (auto old_input : graph.inputs()) {
auto nv = load->addOutput();
nv->setType(old_input->type());
old_input->replaceAllUsesWith(nv);
graph.eraseOutput(graph.outputs().size() - 1);
}
-
// insert Drop nodes to kill references for anything unused:
// this can happen in a few places, e.g. when a node returns
// many values but only one is used
// a, b = foo()
// return a
-void dropUnused(Block *b) {
+void dropUnused(Block* b) {
auto createDropIfUnused = [&](ArrayRef<Value*> values) -> Node* {
std::vector<Value*> to_drop;
- for(auto v : values) {
- if(v->uses().size() == 0)
+ for (auto v : values) {
+ if (v->uses().size() == 0)
to_drop.push_back(v);
}
- if(to_drop.size() == 0)
+ if (to_drop.size() == 0)
return nullptr;
return b->owningGraph()->create(prim::Drop, to_drop, 0);
};
- if(auto d = createDropIfUnused(b->inputs())) {
+ if (auto d = createDropIfUnused(b->inputs())) {
b->prependNode(d);
}
- for(auto n : b->nodes()) {
- if(auto d = createDropIfUnused(n->outputs())) {
+ for (auto n : b->nodes()) {
+ if (auto d = createDropIfUnused(n->outputs())) {
d->insertAfter(n);
}
- for(auto b : n->blocks())
+ for (auto b : n->blocks())
dropUnused(b);
}
}
-
// for each input, should we move rather than copy the inputs
-std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph & g) {
+std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph& g) {
// struct to share common data structures
struct FindLastUses {
- Graph & graph;
+ Graph& graph;
// have we seen this value yet? if not, this is the last use of the value
std::unordered_set<Value*> seen;
// when the If/Loop exits. These are created and inserted on demand.
std::unordered_map<Node*, Node*> drop_for_node;
- FindLastUses(Graph & g)
- : graph(g) {
+ FindLastUses(Graph& g) : graph(g) {
scanBlock(graph.block());
}
- void scanBlock(Block * b) {
+ void scanBlock(Block* b) {
scanNode(b->return_node());
- for(auto n : b->nodes().reverse()) {
+ for (auto n : b->nodes().reverse()) {
scanNode(n);
}
}
- void scanNode(Node * n) {
- for(auto b : n->blocks()) {
+ void scanNode(Node* n) {
+ for (auto b : n->blocks()) {
scanBlock(b);
}
move_flags[n].resize(n->inputs().size());
- // scan backwards so if a value is used twice in the list then it is a move
- for(size_t i = n->inputs().size(); i > 0; --i) {
- scanUse(n, i-1);
+ // scan backwards so if a value is used twice in the list then it is a
+ // move
+ for (size_t i = n->inputs().size(); i > 0; --i) {
+ scanUse(n, i - 1);
}
}
- void scanUse(Node * n, size_t i) {
- auto & move_flags_n = move_flags[n];
+ void scanUse(Node* n, size_t i) {
+ auto& move_flags_n = move_flags[n];
auto v = n->inputs()[i];
auto inserted = seen.insert(v).second;
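+ // uses are scanned in reverse, so the first time we see a value is its
+ // last use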
- if(!inserted) {
+ if (!inserted) {
move_flags_n[i] = false;
return;
}
// the last use of v may be in a nested block of an If or Loop statement
- // find the node 'same_depth_node' at the same depth as the definition of v,
- // and consider that node to be the last use of v.
- // This ensures we do not delete nodes in nested scopes
- // that may be executed multiple times
+ // find the node 'same_depth_node' at the same depth as the definition of
+ // v, and consider that node to be the last use of v. This ensures we do
+ // not delete nodes in nested scopes that may be executed multiple times
// and that nodes used on one side of an if
// but not the other get deleted regardless of the branch
// e.g.
// drop(a)
// In other words, we find the first program point for v that
// _reverse_ dominates the definition of v, and add a drop point there.
- Node * same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
- JIT_ASSERT(same_depth_node); // failure means v is not in scope for n, use lint!
+ Node* same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
+ JIT_ASSERT(
+ same_depth_node); // failure means v is not in scope for n, use lint!
// In the case where v and n are in the same block, just mark
// its move_flags to be true
- if(same_depth_node == n) {
+ if (same_depth_node == n) {
move_flags_n[i] = true;
return;
}
// in the case where the use is nested in a block
// add a Drop node after that block which will drop 'v'.
move_flags_n[i] = false;
- addToDropIfNotExists(findOrCreateDropInstructionForNode(same_depth_node), v);
+ addToDropIfNotExists(
+ findOrCreateDropInstructionForNode(same_depth_node), v);
}
// finds the node in block 'block' that contains in 'n'
// n1: if <cond>:
// n2: b = a + a
// findOwnerInBlock(n2, n0.block()) == n1
- Node * findOwnerInBlock(Node * n, Block * block) {
- while(n != nullptr && block != n->owningBlock()) {
+ Node* findOwnerInBlock(Node* n, Block* block) {
+ while (n != nullptr && block != n->owningBlock()) {
n = n->owningBlock()->owningNode();
}
return n;
}
- Node * findOrCreateDropInstructionForNode(Node * n) {
+ Node* findOrCreateDropInstructionForNode(Node* n) {
auto it = drop_for_node.find(n);
- if(it == drop_for_node.end()) {
+ if (it == drop_for_node.end()) {
auto drop_node = graph.create(prim::Drop, 0);
drop_node->insertAfter(n);
it = drop_for_node.emplace(n, drop_node).first;
return it->second;
}
- void addToDropIfNotExists(Node * drop, Value * v) {
- for(auto i : drop->inputs()) {
+ void addToDropIfNotExists(Node* drop, Value* v) {
+ for (auto i : drop->inputs()) {
// we already accounted for this use
- if(i == v)
+ if (i == v)
return;
}
drop->addInput(v);
return FindLastUses(g).move_flags;
}
-} //namespace
+} // namespace
// pre-processing that happens once per graph
struct PreprocessGraph {
- PreprocessGraph(Graph & g)
- : graph(g.copy()) {
+ PreprocessGraph(Graph& g) : graph(g.copy()) {
n_outputs = graph->outputs().size();
desugarTripCounts(graph->block());
flattenIO(*graph);
dropUnused(graph->block());
// fill in move_flags by scanning blocks;
move_flags = findLastUses(*graph);
- //TODO: desugar Loop trip counts, for now we drop trip counts
+ // TODO: desugar Loop trip counts, for now we drop trip counts
}
// Outputs of the preprocessing:
std::shared_ptr<Graph> graph;
// Note: this is currently unused but will probably be useful in the future,
// so we keep it around
struct ContainerTensor : public at::TensorImpl {
-public:
+ public:
ContainerTensor()
- : TensorImpl(at::UndefinedTensorId(), caffe2::TypeMeta(), nullptr, /* is_variable */ false) {}
+ : TensorImpl(
+ at::UndefinedTensorId(),
+ caffe2::TypeMeta(),
+ nullptr,
+ /* is_variable */ false) {}
~ContainerTensor() override = default;
at::IntList sizes() const override {
// which are stored in the ListHandle struct
// start is an offset into int_data of Code for ListHandle<int>
// and bool_data of Code for ListHandle<bool>
-template<typename T>
+template <typename T>
struct ListHandle {
int start;
int size;
std::shared_ptr<SourceLocation> debug_location; // for error reporting
};
-
int relativeJump(int from_inst, int to_inst) {
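+ // the returned offset is relative to the instruction *after* the jump,
+ // because the interpreter computes new_pc as pc + 1 + callback(stack)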
return to_inst - (from_inst + 1);
}
struct CodeImpl {
- CodeImpl(const std::shared_ptr<Graph>& graph_)
- : preprocess(*graph_) {
+ CodeImpl(const std::shared_ptr<Graph>& graph_) : preprocess(*graph_) {
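+ // lower the preprocessed graph into a flat instruction list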
graph = preprocess.graph;
insertNodesFromBlock(graph->block());
}
// jump when input is false
void createJumpFalse(int from_inst, int to_inst) {
- auto & inst = instructions[from_inst];
+ auto& inst = instructions[from_inst];
JIT_ASSERT(inst.debug_name == prim::Placeholder);
auto offset = relativeJump(from_inst, to_inst);
- inst.callback = [offset](Stack & stack) {
+ inst.callback = [offset](Stack& stack) {
auto t = pop(stack).toBool();
return t ? 0 : offset;
};
// jump when input is true
void createJumpTrue(int from_inst, int to_inst) {
- auto & inst = instructions[from_inst];
+ auto& inst = instructions[from_inst];
JIT_ASSERT(inst.debug_name == prim::Placeholder);
auto offset = relativeJump(from_inst, to_inst);
- inst.callback = [offset](Stack & stack) {
+ inst.callback = [offset](Stack& stack) {
auto t = pop(stack).toBool();
return t ? offset : 0;
};
}
void createJump(int from_inst, int to_inst) {
- auto & inst = instructions[from_inst];
+ auto& inst = instructions[from_inst];
JIT_ASSERT(inst.debug_name == prim::Placeholder);
auto offset = relativeJump(from_inst, to_inst);
- inst.callback = [=](Stack & stack) {
- return offset;
- };
+ inst.callback = [=](Stack& stack) { return offset; };
inst.debug_name = prim::Jump;
}
void insertNodesFromBlock(Block* block) {
- for(auto node : block->nodes()) {
- const auto & source_location = node->getSourceLocation();
- switch(node->kind()) {
+ for (auto node : block->nodes()) {
+ const auto& source_location = node->getSourceLocation();
+ switch (node->kind()) {
case prim::If: {
// x = if c:
// <then_block>
// x = vt
// end:
- // prim::Placeholder instructions are replaced with branch instructions
- // when the branch target locations are known
- auto cond_branch = insertInstruction(prim::Placeholder, source_location, node->inputs(), moveFlags(node), {});
+ // prim::Placeholder instructions are replaced with branch
+ // instructions when the branch target locations are known
+ auto cond_branch = insertInstruction(
+ prim::Placeholder,
+ source_location,
+ node->inputs(),
+ moveFlags(node),
+ {});
auto then_block = node->blocks()[0];
auto else_block = node->blocks()[1];
insertNodesFromBlock(else_block);
- insertAssign(source_location,else_block->outputs(), moveFlags(else_block), node->outputs());
- auto jump = insertInstruction(prim::Placeholder, source_location, {}, {}, {});
+ insertAssign(
+ source_location,
+ else_block->outputs(),
+ moveFlags(else_block),
+ node->outputs());
+ auto jump =
+ insertInstruction(prim::Placeholder, source_location, {}, {}, {});
auto then_block_start = instructions.size();
insertNodesFromBlock(then_block);
- insertAssign(source_location, then_block->outputs(), moveFlags(then_block), node->outputs());
+ insertAssign(
+ source_location,
+ then_block->outputs(),
+ moveFlags(then_block),
+ node->outputs());
createJump(jump, instructions.size());
createJumpTrue(cond_branch, then_block_start);
} break;
auto body_block = node->blocks()[0];
// before assign op: stack: ... <cond> <loop-carried dependencies>
- insertAssign(source_location, node->inputs(), moveFlags(node), body_block->inputs());
+ insertAssign(
+ source_location,
+ node->inputs(),
+ moveFlags(node),
+ body_block->inputs());
// after assign op: stack: ... <cond>
// cond_branch consumes <cond> from top of the stack
- auto cond_branch = insertInstruction(prim::Placeholder, source_location,{}, {}, {});
+ auto cond_branch =
+ insertInstruction(prim::Placeholder, source_location, {}, {}, {});
// after branch: stack: ...
auto entry = instructions.size();
insertNodesFromBlock(body_block);
// before assign op: stack: ... <cond> <loop-carried dependencies>
- insertAssign(source_location, body_block->outputs(), moveFlags(body_block), body_block->inputs());
+ insertAssign(
+ source_location,
+ body_block->outputs(),
+ moveFlags(body_block),
+ body_block->inputs());
// after assign op: stack: ... <cond>
- auto cond_branch_end = insertInstruction(prim::Placeholder, source_location, {}, {}, {});
+ auto cond_branch_end =
+ insertInstruction(prim::Placeholder, source_location, {}, {}, {});
// after branch: stack: ...
aliasRegistersTo(node->outputs(), body_block->inputs());
createJumpFalse(cond_branch, instructions.size());
createJumpTrue(cond_branch_end, entry);
} break;
- default: {
- insertInstruction(node);
- } break;
+ default: { insertInstruction(node); } break;
}
}
}
- size_t insertInstruction(Node * n) {
- auto inst = insertInstruction(n->kind(), n->getSourceLocation(), n->inputs(), moveFlags(n) , n->outputs());
+ size_t insertInstruction(Node* n) {
+ auto inst = insertInstruction(
+ n->kind(),
+ n->getSourceLocation(),
+ n->inputs(),
+ moveFlags(n),
+ n->outputs());
instructions[inst].callback = getOperation(n);
return inst;
}
- size_t insertInstruction(Symbol sym,
- std::shared_ptr<SourceLocation> debug_location,
- ArrayRef<Value*> inputs,
- ArrayRef<uint8_t> move_flags,
- ArrayRef<Value*> outputs) {
+ size_t insertInstruction(
+ Symbol sym,
+ std::shared_ptr<SourceLocation> debug_location,
+ ArrayRef<Value*> inputs,
+ ArrayRef<uint8_t> move_flags,
+ ArrayRef<Value*> outputs) {
instructions.emplace_back();
- auto & inst = instructions.back();
+ auto& inst = instructions.back();
inst.debug_name = sym;
inst.debug_location = std::move(debug_location);
listBegin(inst.inputs.values);
- for(auto input : inputs) {
+ for (auto input : inputs) {
listInsert(inst.inputs.values, getOrAllocateRegister(input, true));
}
listBegin(inst.inputs.free_flags);
- for(auto flag : move_flags) {
+ for (auto flag : move_flags) {
listInsert(inst.inputs.free_flags, flag);
}
listBegin(inst.outputs);
- for(auto output : outputs) {
+ for (auto output : outputs) {
listInsert(inst.outputs, getOrAllocateRegister(output));
}
return instructions.size() - 1;
}
- ArrayRef<uint8_t> moveFlags(Node * n) {
+ ArrayRef<uint8_t> moveFlags(Node* n) {
return preprocess.move_flags.at(n);
}
- ArrayRef<uint8_t> moveFlags(Block *b) {
+ ArrayRef<uint8_t> moveFlags(Block* b) {
return moveFlags(b->return_node());
}
- size_t insertAssign(std::shared_ptr<SourceLocation> debug_location, ArrayRef<Value*> inputs, ArrayRef<uint8_t> move_flags, ArrayRef<Value*> outputs) {
- auto inst = insertInstruction(prim::Assign, std::move(debug_location),inputs, move_flags, outputs);
- // This node effectively forwards its inputs into different places in a register list.
- // We don't need to manipulate the stack in any way, because all inputs are also outputs,
- // and the interpreter will take care of putting them in correct places.
+ size_t insertAssign(
+ std::shared_ptr<SourceLocation> debug_location,
+ ArrayRef<Value*> inputs,
+ ArrayRef<uint8_t> move_flags,
+ ArrayRef<Value*> outputs) {
+ auto inst = insertInstruction(
+ prim::Assign, std::move(debug_location), inputs, move_flags, outputs);
+ // This node effectively forwards its inputs into different places in a
+ // register list. We don't need to manipulate the stack in any way, because
+ // all inputs are also outputs, and the interpreter will take care of
+ // putting them in correct places.
instructions[inst].callback = [](Stack& stack) { return 0; };
return inst;
}
// helpers to build/access RegList objects
- int get(const ListHandle<int> & list, int i) const {
+ int get(const ListHandle<int>& list, int i) const {
return int_data[list.start + i];
}
- bool get(const ListHandle<bool> & list, int i) const {
+ bool get(const ListHandle<bool>& list, int i) const {
return bool_data[list.start + i];
}
- void listBegin(ListHandle<int> & list) {
+ void listBegin(ListHandle<int>& list) {
list.start = int_data.size();
list.size = 0;
}
- void listInsert(ListHandle<int> & list, int value) {
- JIT_ASSERTM(list.start + list.size == (int)int_data.size(), "another list already started");
+ void listInsert(ListHandle<int>& list, int value) {
+ JIT_ASSERTM(
+ list.start + list.size == (int)int_data.size(),
+ "another list already started");
int_data.push_back(value);
list.size++;
}
- void listBegin(ListHandle<bool> & list) {
+ void listBegin(ListHandle<bool>& list) {
list.start = bool_data.size();
list.size = 0;
}
- void listInsert(ListHandle<bool> & list, int value) {
- JIT_ASSERTM(list.start + list.size == (int)bool_data.size(), "another list already started");
+ void listInsert(ListHandle<bool>& list, int value) {
+ JIT_ASSERTM(
+ list.start + list.size == (int)bool_data.size(),
+ "another list already started");
bool_data.push_back(value);
list.size++;
}
// must be called before any new_allocations are used, otherwise they will
// already have registers assigned
- void aliasRegistersTo(ArrayRef<Value*> new_allocations, ArrayRef<Value*> existing_allocations) {
+ void aliasRegistersTo(
+ ArrayRef<Value*> new_allocations,
+ ArrayRef<Value*> existing_allocations) {
JIT_ASSERT(new_allocations.size() == existing_allocations.size());
- for(size_t i = 0; i < new_allocations.size(); ++i) {
+ for (size_t i = 0; i < new_allocations.size(); ++i) {
auto n = new_allocations[i]->unique();
auto e = existing_allocations[i]->unique();
JIT_ASSERT(unique_to_reg.count(e) > 0 && unique_to_reg.count(n) == 0);
unique_to_reg[n] = unique_to_reg[e];
}
}
- int getOrAllocateRegister(Value * n, bool required = false) {
+ int getOrAllocateRegister(Value* n, bool required = false) {
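+ // when 'required' is set, the value must already have a register assigned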
size_t u = n->unique();
- if(unique_to_reg.count(u) > 0)
+ if (unique_to_reg.count(u) > 0)
return unique_to_reg[u];
JIT_ASSERT(!required);
int r = register_size++;
const std::vector<GraphExecutor*>& grad_executors() {
if (!grad_executors_) {
grad_executors_.emplace();
- for (Instruction & instr : instructions) {
+ for (Instruction& instr : instructions) {
if (auto executor = detail::getGradExecutor(instr.callback)) {
grad_executors_->push_back(executor);
}
return *grad_executors_;
}
- void dumpInstruction(std::ostream & out, size_t pc) const {
- auto writeList = [&](const ListHandle<int> & list) {
- for(int i = 0; i < list.size; i++) {
- if(i > 0)
+ void dumpInstruction(std::ostream& out, size_t pc) const {
+ auto writeList = [&](const ListHandle<int>& list) {
+ for (int i = 0; i < list.size; i++) {
+ if (i > 0)
out << ", ";
out << get(list, i);
}
};
- auto writeUseList = [&](const UseList & list) {
- for(int i = 0; i < list.values.size; i++) {
- if(i > 0)
+ auto writeUseList = [&](const UseList& list) {
+ for (int i = 0; i < list.values.size; i++) {
+ if (i > 0)
out << ", ";
- if(get(list.free_flags, i))
+ if (get(list.free_flags, i))
out << "move(" << get(list.values, i) << ")";
else
out << get(list.values, i);
}
};
- auto & inst = instructions.at(pc);
+ auto& inst = instructions.at(pc);
writeList(inst.outputs);
// NB: the debug name is the operator kind that was used to select
// dispatch
out << " = " << inst.debug_name.toUnqualString() << " ";
writeUseList(inst.inputs);
}
- void dump(std::ostream & out) const {
- for(size_t i = 0; i < instructions.size(); ++i) {
+ void dump(std::ostream& out) const {
+ for (size_t i = 0; i < instructions.size(); ++i) {
dumpInstruction(out, i);
out << "\n";
}
c10::optional<std::vector<GraphExecutor*>> grad_executors_;
PreprocessGraph preprocess;
- std::unordered_map<size_t, int> unique_to_reg; // map from unique of nodes to register in register table
+ std::unordered_map<size_t, int>
+ unique_to_reg; // maps a value's unique id to its register in the register table
friend struct InterpreterState;
std::vector<Instruction> instructions;
// InterpreterStateImpl: the mutable runtime state used to execute a Code
struct InterpreterStateImpl : c10::intrusive_ptr_target {
- InterpreterStateImpl(const Code & code)
- : function(code.pImpl),
- int_data(function->int_data.data()),
- bool_data(function->bool_data),
- registers(function->register_size) {
- }
+ InterpreterStateImpl(const Code& code)
+ : function(code.pImpl),
+ int_data(function->int_data.data()),
+ bool_data(function->bool_data),
+ registers(function->register_size) {}
private:
c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
}
bool runImpl(Stack& stack) {
- auto & instructions = function->instructions;
+ auto& instructions = function->instructions;
size_t last = instructions.size();
while (pc < last) {
- // std::cout << "executing " << pc << ": ";
- // function->dumpInstruction(std::cout, pc);
- // std::cout << "\n";
- auto & inst = instructions[pc];
- try {
- loadTensorsFromRegisters(inst.inputs, stack);
- size_t new_pc = pc + 1 + inst.callback(stack);
- for (int i = inst.outputs.size - 1; i >= 0; --i) {
- int reg = get(inst.outputs, i);
- registers[reg] = pop(stack);
- // std::cout << "pop reg[" << reg << "];\n" << registers[reg] << "\n";
- }
- pc = new_pc;
- } catch (Suspend& e) {
- // wait() expects a single input
- JIT_ASSERT(inst.inputs.values.size == 1);
-
- getOrCreateFuture();
-
- if (get(inst.inputs.free_flags, 0)) {
- // make sure the register is not freed once we are waked up
- registers[get(inst.inputs.values, 0)] = e.future;
- }
-
- // Make sure adding callback is the last step.
- // Otherwise if e.future has completed,
- // the current thread will continue running before it suspends.
- InterpreterState state(intrusive_from_this());
- e.future->addCallback([state]() {
- c10::global_work_queue().run(
- InterpreterContinuation(state, Stack()));
- });
-
- return true;
- } catch (Future::FutureError& e) {
- // Error from the forked thread.
- auto msg = e.error_msg; // copy the error for each callback
- handleError(std::move(msg), false);
- return false;
- } catch (std::exception& e) {
- // Error from the current thread
- bool is_jit_exception = dynamic_cast<JITException*>(&e);
- if (instructions[pc].debug_location) {
- handleError(instructions[pc].debug_location->wrapException(
- e, "operation failed in interpreter"), is_jit_exception);
- } else {
- handleError(e.what(), is_jit_exception);
- }
- return false;
+ // std::cout << "executing " << pc << ": ";
+ // function->dumpInstruction(std::cout, pc);
+ // std::cout << "\n";
+ auto& inst = instructions[pc];
+ try {
+ loadTensorsFromRegisters(inst.inputs, stack);
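+ // the callback returns a relative jump offset; 0 means fall through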
+ size_t new_pc = pc + 1 + inst.callback(stack);
+ for (int i = inst.outputs.size - 1; i >= 0; --i) {
+ int reg = get(inst.outputs, i);
+ registers[reg] = pop(stack);
+ // std::cout << "pop reg[" << reg << "];\n" << registers[reg] << "\n";
}
+ pc = new_pc;
+ } catch (Suspend& e) {
+ // wait() expects a single input
+ JIT_ASSERT(inst.inputs.values.size == 1);
+
+ getOrCreateFuture();
+
+ if (get(inst.inputs.free_flags, 0)) {
+ // make sure the register is not freed once we are woken up
+ registers[get(inst.inputs.values, 0)] = e.future;
+ }
+
+ // Make sure adding callback is the last step.
+ // Otherwise if e.future has completed,
+ // the current thread will continue running before it suspends.
+ InterpreterState state(intrusive_from_this());
+ e.future->addCallback([state]() {
+ c10::global_work_queue().run(InterpreterContinuation(state, Stack()));
+ });
+
+ return true;
+ } catch (Future::FutureError& e) {
+ // Error from the forked thread.
+ auto msg = e.error_msg; // copy the error for each callback
+ handleError(std::move(msg), false);
+ return false;
+ } catch (std::exception& e) {
+ // Error from the current thread
+ bool is_jit_exception = dynamic_cast<JITException*>(&e);
+ if (instructions[pc].debug_location) {
+ handleError(
+ instructions[pc].debug_location->wrapException(
+ e, "operation failed in interpreter"),
+ is_jit_exception);
+ } else {
+ handleError(e.what(), is_jit_exception);
+ }
+ return false;
+ }
}
if (future) {
auto num_outputs = function->preprocess.n_outputs;
}
}
- int get(const ListHandle<int> & list, int i) {
+ int get(const ListHandle<int>& list, int i) {
return int_data[list.start + i];
};
- bool get(const ListHandle<bool> & list, int i) {
+ bool get(const ListHandle<bool>& list, int i) {
return bool_data[list.start + i];
}
- void loadTensorsFromRegisters(const UseList & uses, Stack & stack) {
- for(int i = 0; i < uses.values.size; i++) {
- int reg = get(uses.values,i);
+ void loadTensorsFromRegisters(const UseList& uses, Stack& stack) {
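+ // push each input register onto the stack, moving rather than copying
+ // when the free_flags say this is the register's last use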
+ for (int i = 0; i < uses.values.size; i++) {
+ int reg = get(uses.values, i);
// std::cout << "push reg[" << reg << "];\n" << registers[reg] << "\n\n";
- if(get(uses.free_flags,i)) {
+ if (get(uses.free_flags, i)) {
stack.push_back(std::move(registers[reg]));
} else {
stack.push_back(registers[reg]);
}
-
}
}
c10::intrusive_ptr<Future> future;
std::shared_ptr<CodeImpl> function; // keep function alive
// these are just copies of function to prevent indirections in interpreter
- int * int_data;
- const std::vector<bool> & bool_data;
-
+ int* int_data;
+ const std::vector<bool>& bool_data;
// this holds all the tensors for this interpreter run
// we don't bother minimizing the size of this vector, since the extra
// to make sure memory management happens efficiently.
// We optimize for the case where derivatives are run with retain_graph=False
- // in the case where it is true, then the interpreter and this array get copied
- // if this every becomes a bottleneck then we _should_ consider minimizing the
- // total number or register
+ // in the case where it is true, the interpreter and this array get copied.
+ // If this ever becomes a bottleneck then we _should_ consider minimizing
+ // the total number of registers used.
std::vector<IValue> registers;
- // single buffer for input/output calls to ATen functions, so that we do not reallocate
+ // single buffer for input/output calls to ATen functions, so that we do not
+ // reallocate
Stack stack;
};
-std::ostream & operator<<(std::ostream & out, const Code & code) {
+std::ostream& operator<<(std::ostream& out, const Code& code) {
out << *code.pImpl->graph << "\n";
code.pImpl->dump(out);
return out;
}
-Code::Code(const std::shared_ptr<Graph>& graph)
- : pImpl(new CodeImpl(graph)) {}
+Code::Code(const std::shared_ptr<Graph>& graph) : pImpl(new CodeImpl(graph)) {}
Code::~Code() = default;
const std::vector<GraphExecutor*>& Code::grad_executors() {
return pImpl->grad_executors();
}
-InterpreterState::InterpreterState(const Code & code)
- : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
+InterpreterState::InterpreterState(const Code& code)
+ : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
InterpreterState::~InterpreterState() = default;
void InterpreterState::run(Stack& stack) {
return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
}
-InterpreterState::InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
+InterpreterState::InterpreterState(
+ c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
: pImpl(std::move(pImpl_)) {}
-}}
+} // namespace jit
+} // namespace torch
#pragma once
+#include <c10/util/Optional.h>
#include <memory>
#include <vector>
-#include <c10/util/Optional.h>
-#include <torch/csrc/jit/ivalue.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ivalue.h>
namespace at {
- class Tensor;
+class Tensor;
}
namespace c10 {
struct IValue;
}
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// The interpreter runs Graphs with Tensor inputs and Tensor outputs;
// a separate component in the autograd handles unwrapping and wrapping
using Stack = std::vector<c10::IValue>;
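+// Typical usage (sketch): compile a Graph into a Code once, then execute it
+// with an InterpreterState per invocation:
+//   Code code(graph);
+//   InterpreterState state(code);
+//   Stack stack(inputs.begin(), inputs.end());
+//   state.run(stack); // outputs are left on the stack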
struct TORCH_API Code {
- Code()
- : pImpl(nullptr) {}
+ Code() : pImpl(nullptr) {}
explicit Code(const std::shared_ptr<Graph>& graph);
~Code();
return pImpl != nullptr;
}
-private:
+ private:
std::shared_ptr<CodeImpl> pImpl;
friend struct InterpreterStateImpl;
- friend std::ostream & operator<<(std::ostream & out, const Code & code);
+ friend std::ostream& operator<<(std::ostream& out, const Code& code);
};
struct InterpreterState {
- InterpreterState(const Code & code);
+ InterpreterState(const Code& code);
void run(Stack& stack);
c10::intrusive_ptr<Future> runAsync(Stack& stack);
c10::intrusive_ptr<Future> getFuture();
~InterpreterState();
-private:
+
+ private:
InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
// Ideally we should use c10::intrusive_ptr<InterpreterStateImpl> for pImpl;
// but intrusive_ptr requires full definition of InterpreterStateImpl,
InterpreterState state;
Stack stack;
};
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-
-#include <torch/csrc/jit/operator.h>
#include <torch/csrc/autograd/function.h>
-#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
+#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/script/schema_matching.h>
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include <utility>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// Constants relating to maintaining the topological index of nodes.
//
// Lower and upper bounds of the index. Inclusive range.
// - 2^(64-n) is the maximum number of appends to the end without reindex
static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;
-// Sigh, see https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
+// Sigh, see
+// https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
constexpr Symbol PythonOp::Kind;
-void printValueRef(std::ostream & out, const Value * n) {
+void printValueRef(std::ostream& out, const Value* n) {
out << "%" << n->uniqueName();
}
// NB: This overload will become ambiguous with the one Caffe2 provides in its
// logging, if they ever intersect.
template <typename T>
-std::ostream& operator<<(std::ostream & out, const std::vector<T> & nodes) {
+std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
out << at::ArrayRef<T>{nodes};
return out;
}
template <typename T>
-std::ostream& printValueRefs(std::ostream & out, const at::ArrayRef<T> & nodes) {
+std::ostream& printValueRefs(std::ostream& out, const at::ArrayRef<T>& nodes) {
size_t i = 0;
- for(auto n : nodes) {
- if(i++ > 0)
+ for (auto n : nodes) {
+ if (i++ > 0)
out << ", ";
printValueRef(out, n);
}
// Can't make these two overloads directly a template, it'll be ambiguous with
// the global printer for operator<<.
-std::ostream& operator<<(std::ostream & out, const at::ArrayRef<const Value*> & nodes) {
+std::ostream& operator<<(
+ std::ostream& out,
+ const at::ArrayRef<const Value*>& nodes) {
return printValueRefs(out, nodes);
}
-std::ostream& operator<<(std::ostream & out, const at::ArrayRef<Value*> & nodes) {
+std::ostream& operator<<(std::ostream& out, const at::ArrayRef<Value*>& nodes) {
return printValueRefs(out, nodes);
}
struct const_value_list_with_types {
const ArrayRef<const Value*> values;
bool use_newlines;
- const_value_list_with_types(ArrayRef<const Value*> values, bool use_newlines = false)
- : values(values), use_newlines(use_newlines) {}
+ const_value_list_with_types(
+ ArrayRef<const Value*> values,
+ bool use_newlines = false)
+ : values(values), use_newlines(use_newlines) {}
};
-std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) {
+std::ostream& operator<<(std::ostream& out, const_value_list_with_types l) {
size_t i = 0;
- for(auto n : l.values) {
- if(i++ > 0) {
+ for (auto n : l.values) {
+ if (i++ > 0) {
if (l.use_newlines) {
// TODO: Indent here is hard-coded for "graph(": un-hard-code it
out << "\n ";
return out;
}
-void printAttributes(std::ostream & out, const Node * n, bool ignore_subgraph=false) {
+void printAttributes(
+ std::ostream& out,
+ const Node* n,
+ bool ignore_subgraph = false) {
out << "[";
auto names = n->attributeNames();
int i = 0;
- for(auto name : names) {
+ for (auto name : names) {
if (ignore_subgraph && name == attr::Subgraph)
continue;
- if(i++ > 0)
+ if (i++ > 0)
out << ", ";
// TODO: debugging mode to see the qualifier. We definitely
// don't want to print the qualifier since it should always
out << "]";
}
-static std::ostream & indent(std::ostream & out, size_t level) {
- for(size_t i = 0; i < level; ++i)
+static std::ostream& indent(std::ostream& out, size_t level) {
+ for (size_t i = 0; i < level; ++i)
out << " ";
return out;
}
-std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::vector<const Node*> * groups) {
+std::ostream& printNode(
+ std::ostream& out,
+ size_t level,
+ const Node* n,
+ std::vector<const Node*>* groups) {
auto outputs = n->outputs();
indent(out, level) << const_value_list_with_types(outputs);
out << " = ";
- IR_IFM_CONST(n,PythonOp)
- out << "^" << value->name();
- value->writeScalars(out);
+ IR_IFM_CONST(n, PythonOp)
+ out << "^" << value->name();
+ value->writeScalars(out);
IR_ELSE()
- if(n->hasAttribute(attr::Subgraph) && groups) {
- out << n->kind().toQualString() << "_" << groups->size();
- if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) {
- printAttributes(out, n, /*ignore_subgraph=*/true);
- }
- groups->push_back(n);
- } else {
- out << n->kind().toQualString();
- if(n->hasAttributes()) {
- printAttributes(out,n);
- }
+ if (n->hasAttribute(attr::Subgraph) && groups) {
+ out << n->kind().toQualString() << "_" << groups->size();
+ if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) {
+ printAttributes(out, n, /*ignore_subgraph=*/true);
+ }
+ groups->push_back(n);
+ } else {
+ out << n->kind().toQualString();
+ if (n->hasAttributes()) {
+ printAttributes(out, n);
}
+ }
IR_END()
out << "(" << n->inputs() << ")";
std::string scopeName = n->scopeName();
if (scopeName.empty()) {
out << "\n";
- }
- else {
+ } else {
out << ", ";
out << "scope: " << scopeName << "\n";
}
- for(size_t i = 0; i < n->blocks().size(); ++i) {
+ for (size_t i = 0; i < n->blocks().size(); ++i) {
auto b = n->blocks()[i];
- indent(out, level + 1) << "block" << i << "(" << const_value_list_with_types(b->inputs(), false) << ") {\n";
- for(auto n : b->nodes()) {
+ indent(out, level + 1) << "block" << i << "("
+ << const_value_list_with_types(b->inputs(), false)
+ << ") {\n";
+ for (auto n : b->nodes()) {
printNode(out, level + 2, n, groups);
}
indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
return out;
}
-std::ostream& operator<<(std::ostream & out, const Node & n) {
+std::ostream& operator<<(std::ostream& out, const Node& n) {
return printNode(out, 0, &n, nullptr);
}
-std::ostream& operator<<(std::ostream & out, const Graph & g) {
+std::ostream& operator<<(std::ostream& out, const Graph& g) {
out << "graph(" << const_value_list_with_types(g.inputs(), true) << ") {\n";
std::vector<const Node*> groups;
- for(auto n : g.nodes()) {
+ for (auto n : g.nodes()) {
printNode(out, 1, n, &groups);
}
out << " return (" << g.outputs() << ");\n}\n";
size_t i = 0;
- for(auto fg : groups) {
- out << "with " << fg->kind().toQualString() << "_" <<i++ << " = " << *fg->g(attr::Subgraph);
+ for (auto fg : groups) {
+ out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
+ << *fg->g(attr::Subgraph);
}
/*
// Uncomment this to debug all_nodes issues
return out;
}
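For orientation, a dump produced by this printer has roughly the following shape (an illustrative, hand-written example; value names, types, and the fused op are made up):

graph(%x : Dynamic
      %y : Dynamic) {
  %2 : Dynamic = prim::FusionGroup_0(%x, %y)
  return (%2);
}
with prim::FusionGroup_0 = graph(%0 : Dynamic
      %1 : Dynamic) {
  %2 : Dynamic = aten::mul(%0, %1)
  return (%2);
}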
-std::ostream& Graph::prettyPrint(std::ostream & out) {
+std::ostream& Graph::prettyPrint(std::ostream& out) {
std::vector<at::Tensor> tensor_table;
PythonPrint(out, *this, tensor_table);
return out;
bool has_device = false;
c10::optional<at::Device> device = c10::nullopt;
auto checkValue = [&](const Value* v) {
- if(CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) {
- if(!has_device) {
+ if (CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) {
+ if (!has_device) {
has_device = true;
device = type->device();
} else {
}
}
};
- for(auto input : node->inputs()) {
+ for (auto input : node->inputs()) {
checkValue(input);
}
- for(auto output : node->outputs()) {
+ for (auto output : node->outputs()) {
checkValue(output);
}
}
for (auto input : inputs_) {
// WARNING: O(n^2)
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- JIT_ASSERT(std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) != input->uses_.end());
+ JIT_ASSERT(
+ std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
+ input->uses_.end());
JIT_ASSERT(graph_->all_nodes.count(this) == 1);
i++;
}
}
- for(auto o : outputs()) {
+ for (auto o : outputs()) {
size_t i = 0;
for (auto use : o->uses()) {
// Use invariants
}
// Node subclass invariants
- IR_IF(this,Constant)
- JIT_ASSERT(inputs_.size() == 0);
+ IR_IF(this, Constant)
+ JIT_ASSERT(inputs_.size() == 0);
IR_ELSEIF(Return)
- // Return uses is zero
- JIT_ASSERT(outputs().size() == 0);
+  // Return nodes have no outputs
+ JIT_ASSERT(outputs().size() == 0);
IR_ELSEIF(Param)
- // Param inputs is zero
- JIT_ASSERT(inputs_.size() == 0);
+  // Param nodes have no inputs
+ JIT_ASSERT(inputs_.size() == 0);
IR_ELSEIFM_CONST(PythonOp)
- // Python operator cconv is correct
- size_t n_scalars = 0, n_tensors = 0;
- for (auto c : value->cconv) {
- if (c == 'c') {
- n_scalars++;
- } else if (c == 'd') {
- n_tensors++;
- } else {
- JIT_ASSERT(0);
- }
- JIT_ASSERT(static_cast<bool>(value->pyobj));
+ // Python operator cconv is correct
+ size_t n_scalars = 0, n_tensors = 0;
+ for (auto c : value->cconv) {
+ if (c == 'c') {
+ n_scalars++;
+ } else if (c == 'd') {
+ n_tensors++;
+ } else {
+ JIT_ASSERT(0);
}
- JIT_ASSERT(n_scalars == value->scalar_args.size());
- JIT_ASSERT(n_tensors == inputs_.size());
+ JIT_ASSERT(static_cast<bool>(value->pyobj));
+ }
+ JIT_ASSERT(n_scalars == value->scalar_args.size());
+ JIT_ASSERT(n_tensors == inputs_.size());
IR_ELSEIF(Eval)
- // TODO: add invariants
+ // TODO: add invariants
// TODO: It's not good for these ops to be top-level, it makes cases longer.
IR_ELSEIF(FusionGroup)
- checkSameDevice(value);
- // TODO: Typecheck the parameters
- value->g(attr::Subgraph)->lint();
+ checkSameDevice(value);
+ // TODO: Typecheck the parameters
+ value->g(attr::Subgraph)->lint();
IR_END()
-
}
// TODO: When lint fails, give better indication about which
struct LintScope {
LintScope() = default;
- LintScope(std::unique_ptr<LintScope> parent)
- : parent(std::move(parent)) {}
- bool contains(const Value * v) {
+ LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
+ bool contains(const Value* v) {
return values.count(v) > 0 || (parent && parent->contains(v));
}
- bool contains(const Node * n) {
+ bool contains(const Node* n) {
return nodes.count(n) > 0 || (parent && parent->contains(n));
}
- void insert(const Value * v) {
+ void insert(const Value* v) {
JIT_ASSERT(!contains(v));
values.insert(v);
}
- void insert(const Node * n) {
+ void insert(const Node* n) {
JIT_ASSERT(!contains(n));
nodes.insert(n);
}
std::unique_ptr<LintScope> parent;
- private:
+
+ private:
std::unordered_set<const Value*> values;
std::unordered_set<const Node*> nodes;
};
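A minimal sketch of the push/pop discipline check_node applies with this scope, mirroring the std::move dance used below (LintScope itself is local to this translation unit):

std::unique_ptr<LintScope> scope(new LintScope());
// push: entering a nested block; the child scope owns its parent
std::unique_ptr<LintScope> child(new LintScope(std::move(scope)));
scope = std::move(child);
// ... scope->insert(...) the block's values; contains() consults parents ...
// pop: hand ownership back to the enclosing scope
scope = std::move(scope->parent);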
// Struct enables mutual recursion in linting methods.
// Putting it inside Graph::lint enables access to private Graph members
struct LintImpl {
- LintImpl(const Graph & g)
- : g(g)
- , scope(new LintScope())
- , all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
- const Graph & g;
+ LintImpl(const Graph& g)
+ : g(g),
+ scope(new LintScope()),
+ all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
+ const Graph& g;
std::unique_ptr<LintScope> scope;
std::unordered_set<size_t> seen_uniques;
std::unordered_map<const Node*, int64_t> anticipated_uses;
void check_value(const Value* v) {
scope->insert(v);
auto b2 = seen_uniques.insert(v->unique());
- JIT_ASSERT(b2.second); // insertion took place
+ JIT_ASSERT(b2.second); // insertion took place
JIT_ASSERT(v->unique() < g.next_unique_);
for (auto use : v->uses()) {
JIT_ASSERT(!scope->contains(use.user));
JIT_ASSERT(g.all_nodes.count(use.user) == 1);
- anticipated_uses[use.user]++; // int default constructs to 0
+ anticipated_uses[use.user]++; // int default constructs to 0
}
}
void check_node(const Node* n) {
JIT_ASSERTM(0, input->unique(), " not in scope");
}
}
- JIT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
- anticipated_uses[n] = -1; // we saw the anticipated user!
+ JIT_ASSERT(
+ anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
+ anticipated_uses[n] = -1; // we saw the anticipated user!
scope->insert(n);
- for(auto block : n->blocks()) {
+ for (auto block : n->blocks()) {
std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope)));
scope = std::move(new_scope);
check_block(block);
scope = std::move(scope->parent);
}
size_t i = 0;
- for(auto o : n->outputs()) {
+ for (auto o : n->outputs()) {
JIT_ASSERT(o->node() == n);
JIT_ASSERT(i++ == o->offset_);
check_value(o);
}
n->lint();
}
- void check_block(const Block *b) {
+ void check_block(const Block* b) {
// Check topological ordering
JIT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
auto curNode = *b->nodes().begin();
// - only one return node???
node_set nodes_set(ALL_OF(b->nodes()));
- node_set inputs_set {b->input_};
- node_set output_set {b->output_};
- // TODO: Make a more type safe std::includes wrapper which disallows use on
- // non-ordered containers
+ node_set inputs_set{b->input_};
+ node_set output_set{b->output_};
+ // TODO: Make a more type safe std::includes wrapper which disallows use
+ // on non-ordered containers
JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));
sum_set.insert(ALL_OF(output_set));
}
void check_graph() {
- node_set all_nodes_set(ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*
+ node_set all_nodes_set(
+ ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*
check_block(g.block_);
for (auto kv : anticipated_uses) {
}
}
-void Block::cloneFrom(Block * src, std::function<Value*(Value*)> value_map) {
+void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
std::unordered_map<Value*, Value*> local_map;
- auto env = [&](Value * v) {
+ auto env = [&](Value* v) {
auto it = local_map.find(v);
- if(it != local_map.end())
+ if (it != local_map.end())
return it->second;
return value_map(v);
};
auto graph = owningGraph();
- for(auto input : src->inputs()) {
+ for (auto input : src->inputs()) {
local_map[input] = this->addInput()->copyMetadata(input);
}
- for(auto node : src->nodes()) {
+ for (auto node : src->nodes()) {
auto new_node = this->appendNode(graph->createClone(node, env));
- for(size_t i = 0; i < node->outputs().size(); ++i) {
+ for (size_t i = 0; i < node->outputs().size(); ++i) {
auto oo = node->outputs()[i];
auto no = new_node->outputs()[i];
local_map[oo] = no;
no->copyMetadata(oo);
}
}
- for(auto output : src->outputs()) {
+ for (auto output : src->outputs()) {
this->registerOutput(env(output));
}
}
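A hedged usage sketch of cloneFrom, assuming two blocks in scopes where src's free values remain valid (the identity map passes such references through unchanged):

void copyBlockSketch(Block* dst, Block* src) {
  // Appends clones of src's inputs, nodes, and outputs to dst; values
  // defined outside src are resolved through the callback.
  dst->cloneFrom(src, [](Value* v) { return v; });
}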
// we cannot destroy the output because it is used as the sentinel
// for the nodes() list and has to remain valid for the loop
output_->removeAllInputs();
- for(auto it = this->nodes().reverse().begin(),
- end = this->nodes().reverse().end();
- it != end; ++it) {
+ for (auto it = this->nodes().reverse().begin(),
+ end = this->nodes().reverse().end();
+ it != end;
+ ++it) {
it.destroyCurrent();
}
output_->destroy();
std::string name_base = name;
auto last_dot_pos = name.find_last_of('.');
if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
- if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) {
+ if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
+ std::string::npos) {
name_base = name.substr(0, last_dot_pos);
}
}
return name_base;
}
-Value* Value::setUniqueName(const std::string & name) {
- if (name.size() > 0 && name.find_first_not_of("0123456789") == std::string::npos) {
+Value* Value::setUniqueName(const std::string& name) {
+ if (name.size() > 0 &&
+ name.find_first_not_of("0123456789") == std::string::npos) {
throw std::runtime_error("names may not be integers: " + name);
}
- auto & names = node()->owningGraph()->unique_names_;
+ auto& names = node()->owningGraph()->unique_names_;
// clear any old name from the map
- if(hasUniqueName()) {
+ if (hasUniqueName()) {
names.erase(unique_name_);
unique_name_ = "";
}
// allow "" to clear the uniquename
- if(name == "")
+ if (name == "")
return this;
// if someone else has this name, then rename the other value
auto old_owner_of_name = names.find(name);
- if(old_owner_of_name != names.end()) {
+ if (old_owner_of_name != names.end()) {
size_t suffix = 1;
std::string name_base = name;
auto last_dot_pos = name.find_last_of('.');
if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
- if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) {
+ if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
+ std::string::npos) {
suffix = std::stoll(name.substr(last_dot_pos + 1));
name_base = name.substr(0, last_dot_pos);
}
std::stringstream ss;
ss << name_base << "." << suffix++;
replacement_name = ss.str();
- } while(names.count(replacement_name) > 0);
+ } while (names.count(replacement_name) > 0);
old_owner_of_name->second->setUniqueName(replacement_name);
}
return this;
}
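The net effect is that a name's previous owner is bumped to the next free dotted suffix rather than silently losing it. A minimal sketch, assuming only the Graph/Value APIs declared in ir.h (the names are illustrative):

void uniqueNameSketch() {
  Graph g;
  Value* a = g.addInput("x"); // a owns the name "x"
  Value* b = g.addInput();
  b->setUniqueName("x");      // b takes "x"; a is renamed to "x.1"
  b->setUniqueName("");       // "" clears b's unique name again
}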
-Value* Value::copyMetadata(Value * from) {
+Value* Value::copyMetadata(Value* from) {
setType(from->type());
if (from->hasUniqueName())
setUniqueName(from->uniqueName());
return this;
}
-void Value::replaceFirstUseWith(Value * newValue) {
+void Value::replaceFirstUseWith(Value* newValue) {
JIT_ASSERT(owningGraph() == newValue->owningGraph());
auto u = uses()[0];
u.user->inputs_[u.offset] = newValue;
uses_.erase(uses_.begin());
}
-void Value::replaceAllUsesWith(Value * newValue) {
+void Value::replaceAllUsesWith(Value* newValue) {
while (!uses().empty()) {
replaceFirstUseWith(newValue);
}
return i;
}
}
- throw std::runtime_error(std::string("Couldn't find an argument called ") + name.toQualString());
+ throw std::runtime_error(
+ std::string("Couldn't find an argument called ") + name.toQualString());
}
c10::optional<IValue> Node::get(Symbol name) const {
return input(findArgument(schema(), name));
}
-bool Node::matches(const char *signature_literal, at::ArrayRef<Symbol> const_inputs) const {
- if (!sig(signature_literal).matches(this)) return false;
+bool Node::matches(
+ const char* signature_literal,
+ at::ArrayRef<Symbol> const_inputs) const {
+ if (!sig(signature_literal).matches(this))
+ return false;
for (Symbol s : const_inputs) {
- if (!is_constant(s)) return false;
+ if (!is_constant(s))
+ return false;
}
return true;
}
}
const FunctionSchema* Node::maybeSchema() const {
- if(!schema_) {
- if(auto op = findOperatorFor(this)) {
+ if (!schema_) {
+ if (auto op = findOperatorFor(this)) {
schema_ = &op->schema();
}
}
bool Node::isNondeterministic() const {
static const OperatorSet nondeterministic_ops = {
- "aten::dropout(Tensor input, float p, bool train) -> Tensor",
- "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
- "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
- "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
- "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
- "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
- "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
- "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
- "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
- "aten::poisson(Tensor self, Generator? generator) -> Tensor",
- "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
- "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
- "aten::rand(int[] size, *, int dtype, int layout, Device device) -> Tensor",
- "aten::rand_like(Tensor self) -> Tensor",
- "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
- "aten::randint(int high, int[] size, *, int dtype, int layout, Device device) -> Tensor",
- "aten::randint(int low, int high, int[] size, *, int dtype, int layout, Device device) -> Tensor",
- "aten::randint_like(Tensor self, int high) -> Tensor",
- "aten::randint_like(Tensor self, int low, int high) -> Tensor",
- "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
- "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
- "aten::randn(int[] size, *, int dtype, int layout, Device device) -> Tensor",
- "aten::randn_like(Tensor self) -> Tensor",
- "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
- "aten::randperm(int n, *, int dtype, int layout, Device device) -> Tensor"
- };
+ "aten::dropout(Tensor input, float p, bool train) -> Tensor",
+ "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
+ "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
+ "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
+ "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
+ "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
+ "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
+ "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
+ "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
+ "aten::poisson(Tensor self, Generator? generator) -> Tensor",
+ "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
+ "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
+ "aten::rand(int[] size, *, int dtype, int layout, Device device) -> Tensor",
+ "aten::rand_like(Tensor self) -> Tensor",
+ "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
+ "aten::randint(int high, int[] size, *, int dtype, int layout, Device device) -> Tensor",
+ "aten::randint(int low, int high, int[] size, *, int dtype, int layout, Device device) -> Tensor",
+ "aten::randint_like(Tensor self, int high) -> Tensor",
+ "aten::randint_like(Tensor self, int low, int high) -> Tensor",
+ "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
+ "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
+ "aten::randn(int[] size, *, int dtype, int layout, Device device) -> Tensor",
+ "aten::randn_like(Tensor self) -> Tensor",
+ "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
+ "aten::randperm(int n, *, int dtype, int layout, Device device) -> Tensor"};
if (nondeterministic_ops.find(this) == nullptr) {
return false;
}
// Dropout with train = False is deterministic
- if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) {
+ if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") &&
+ is_constant(attr::train) && !get<bool>(attr::train).value()) {
return false;
}
return true;
}
}
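To make the dropout special case concrete, a hedged sketch (n is an illustrative node; the comments state the expected results):

void dropoutSketch(Node* n) {
  // If n is aten::dropout(%x, 0.5, False) with a constant train flag,
  // this returns false despite dropout appearing in the op set above;
  // with train=True, or a non-constant flag, it returns true.
  bool nd = n->isNondeterministic();
  (void)nd;
}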
-Node::Node(Graph * graph_, NodeKind kind_) :
- kind_(kind_),
- graph_(graph_),
- owning_block_(nullptr),
- scope_(graph_->current_scope_),
- schema_(nullptr),
- topo_position_(0) {
+Node::Node(Graph* graph_, NodeKind kind_)
+ : kind_(kind_),
+ graph_(graph_),
+ owning_block_(nullptr),
+ scope_(graph_->current_scope_),
+ schema_(nullptr),
+ topo_position_(0) {
graph_->all_nodes.emplace(this);
}
JIT_ASSERT(i < outputs_.size());
JIT_ASSERT(outputs_[i]->uses().empty());
schema_ = nullptr;
- Value * n = outputs_[i];
+ Value* n = outputs_[i];
outputs_.erase(outputs_.begin() + i);
owningGraph()->freeValue(n);
- for(size_t j = i; j < outputs_.size(); j++) {
+ for (size_t j = i; j < outputs_.size(); j++) {
outputs_[j]->offset_--;
}
}
-Block * Node::addBlock() {
+Block* Node::addBlock() {
schema_ = nullptr;
blocks_.push_back(new Block(owningGraph(), this));
return blocks_.back();
void Node::eraseBlock(size_t i) {
JIT_ASSERT(i < blocks_.size());
schema_ = nullptr;
- Block * n = blocks_[i];
+ Block* n = blocks_[i];
blocks_.erase(blocks_.begin() + i);
n->destroy();
}
void Node::destroy() {
- while(!outputs().empty())
+ while (!outputs().empty())
eraseOutput(outputs().size() - 1);
- while(!blocks().empty())
+ while (!blocks().empty())
eraseBlock(blocks().size() - 1);
removeAllInputs();
- if(inBlockList())
+ if (inBlockList())
removeFromList();
graph_->freeNode(this);
}
-void Node::cloneFrom(Node * s) {
+void Node::cloneFrom(Node* s) {
setSourceLocation(s->getSourceLocation());
- if(s->scope_ && !s->scope_->isBlank())
+ if (s->scope_ && !s->scope_->isBlank())
scope_ = s->scope_;
copyAttributes(*s);
}
-void Node::replaceAllUsesWith(Node * n) {
+void Node::replaceAllUsesWith(Node* n) {
JIT_ASSERT(outputs().size() == n->outputs().size());
size_t nOutputs = outputs().size();
- for(size_t i = 0; i < nOutputs; i++) {
+ for (size_t i = 0; i < nOutputs; i++) {
outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
}
}
return value;
}
-Value* Node::addInput(Value * value) {
+Value* Node::addInput(Value* value) {
JIT_ASSERT(graph_ == value->owningGraph());
schema_ = nullptr;
value->uses_.emplace_back(this, inputs_.size());
return value;
}
-Value* Node::replaceInput(size_t i, Value * newValue) {
+Value* Node::replaceInput(size_t i, Value* newValue) {
JIT_ASSERT(newValue->owningGraph() == graph_);
schema_ = nullptr;
- Value * old = dropInput(i);
+ Value* old = dropInput(i);
inputs_[i] = newValue;
newValue->uses_.emplace_back(this, i);
return old;
}
-void Node::replaceInputWith(Value * from, Value * to) {
+void Node::replaceInputWith(Value* from, Value* to) {
JIT_ASSERT(from->owningGraph() == graph_);
JIT_ASSERT(to->owningGraph() == graph_);
schema_ = nullptr;
size_t i = 0;
- for(auto input : inputs()) {
- if(input == from)
+ for (auto input : inputs()) {
+ if (input == from)
replaceInput(i, to);
i++;
}
}
// should never reach here, since both nodes are ultimately in the same graph
JIT_ASSERT(false);
-
}
-bool Node::isBefore(const Node * n) const {
+bool Node::isBefore(const Node* n) const {
return isBeforeOrAfter(n, MoveSide::BEFORE);
}
-bool Node::isAfter(const Node * n) const {
+bool Node::isAfter(const Node* n) const {
return isBeforeOrAfter(n, MoveSide::AFTER);
}
-Node* Node::insertBefore(Node * n) {
+Node* Node::insertBefore(Node* n) {
JIT_ASSERT(n->inBlockList());
insertAfter(n->prev());
return this;
}
-Node* Node::insertAfter(Node * n) {
+Node* Node::insertAfter(Node* n) {
JIT_ASSERT(!inBlockList() && n->inBlockList());
JIT_ASSERT(n->owningBlock());
this->owning_block_ = n->owningBlock();
- Node * next = n->next();
+ Node* next = n->next();
n->next() = this;
this->prev() = n;
this->next() = next;
namespace {
struct WorkingSet {
public:
- explicit WorkingSet(Node* mover, const AliasDb& aliasDb)
- : aliasDb_(aliasDb) {
+ explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
add(mover);
}
// node at a time. When we can't move past a node (because it depends on the
// working set), then add it to the working set and keep moving until we hit
// `moveAfter`.
-bool Node::tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun) {
+bool Node::tryMove(
+ Node* movePoint,
+ MoveSide moveSide,
+ const AliasDb& aliasDb,
+ bool dryRun) {
JIT_ASSERT(this->inBlockList() && movePoint->inBlockList());
JIT_ASSERT(this->owningBlock() == movePoint->owningBlock());
if (this == movePoint) {
}
}
-void Node::moveAfter(Node * n) {
+void Node::moveAfter(Node* n) {
removeFromList();
insertAfter(n);
}
-void Node::moveBefore(Node * n) {
+void Node::moveBefore(Node* n) {
removeFromList();
insertBefore(n);
}
dropInput(i);
// everything after this input shifts left,
// so we need to update their use offsets to match
- for(size_t j = i+1; j < inputs_.size(); j++) {
+ for (size_t j = i + 1; j < inputs_.size(); j++) {
auto it = findUseForInput(j);
it->offset--;
}
void Node::removeAllInputs() {
schema_ = nullptr;
- for(size_t i = 0; i < inputs().size(); ++i)
+ for (size_t i = 0; i < inputs().size(); ++i)
dropInput(i);
inputs_.clear();
}
use_list::iterator Node::findUseForInput(size_t i) {
- auto & input_uses = inputs_[i]->uses_;
+ auto& input_uses = inputs_[i]->uses_;
// O(N) on the use list, but unless we get nodes with +100 uses
// vector traversal still is probably faster than linked list
auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
void Node::removeFromList() {
JIT_ASSERT(inBlockList());
this->owning_block_ = nullptr;
- Node * next = this->next();
- Node * prev = this->prev();
+ Node* next = this->next();
+ Node* prev = this->prev();
prev->next() = next;
next->prev() = prev;
this->next() = nullptr;
}
inline const SourceRange& fakeRange() {
- static SourceRange range(std::make_shared<std::string>("<internally-created-node>"), 0, 1);
+ static SourceRange range(
+ std::make_shared<std::string>("<internally-created-node>"), 0, 1);
return range;
}
Node* Graph::create(NodeKind kind, size_t num_outputs) {
// NB: Node constructor adds node to all_nodes
auto n = new Node(this, kind);
- for(size_t i = 0; i < num_outputs; i++)
+ for (size_t i = 0; i < num_outputs; i++)
n->addOutput();
return n;
}
-Node* Graph::create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs) {
+Node* Graph::create(
+ NodeKind kind,
+ ArrayRef<Value*> inputs,
+ size_t num_outputs) {
auto n = create(kind, num_outputs);
- for(auto i : inputs)
+ for (auto i : inputs)
n->addInput(i);
return n;
}
}
Node* Graph::createNone(TypePtr typ) {
- Node * n = create(prim::None);
+ Node* n = create(prim::None);
n->output()->setType(OptionalType::create(std::move(typ)));
return n;
}
-Node * Graph::createFusionGroup() {
+Node* Graph::createFusionGroup() {
auto n = create(prim::FusionGroup, 0);
- n->g_(attr::Subgraph,std::make_shared<Graph>(current_scope()));
+ n->g_(attr::Subgraph, std::make_shared<Graph>(current_scope()));
return n;
}
return n;
}
-Node* Graph::createTupleUnpack(Value * v) {
+Node* Graph::createTupleUnpack(Value* v) {
TupleTypePtr tt = v->type()->expect<TupleType>();
auto n = create(prim::TupleUnpack, {v}, 0);
- for(auto & element : tt->elements()) {
+ for (auto& element : tt->elements()) {
n->addOutput()->setType(element);
}
return n;
}
-Node* Graph::createTupleIndex(Value * tup, int64_t index) {
+Node* Graph::createTupleIndex(Value* tup, int64_t index) {
auto n = create(prim::TupleIndex, {tup});
n->i_(attr::index, index);
auto tuple_type = tup->type()->expect<TupleType>();
return n;
}
-Node* Graph::createTupleSlice(Value * tup, int64_t beg, int64_t end) {
+Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) {
auto n = create(prim::TupleSlice, {tup});
auto tuple_type = tup->type()->expect<TupleType>();
n->i_(attr::beg, beg);
Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
auto n = create(prim::ListConstruct, values);
- for(const auto & v : values) {
+ for (const auto& v : values) {
JIT_ASSERT(v->type()->isSubtypeOf(elem_type));
}
n->output()->setType(ListType::create(elem_type));
return n;
}
-Node* Graph::createListUnpack(Value *v, size_t size) {
+Node* Graph::createListUnpack(Value* v, size_t size) {
ListTypePtr list_type = v->type()->expect<ListType>();
TypePtr elem_type = list_type->getElementType();
auto n = create(prim::ListUnpack, {v}, 0);
Node* Graph::createNumToTensor(Value* value) {
auto typ = value->type();
- Node * result = create(prim::NumToTensor, {value});
+ Node* result = create(prim::NumToTensor, {value});
result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ)));
return result;
}
return result;
}
-Node* Graph::createClone(Node * n, const std::function<Value*(Value*)>& value_map, bool copy_blocks) {
- //n can be from a different graph
- Node * r = n->allocNewInstance(this);
- for(auto o : n->outputs()) {
+Node* Graph::createClone(
+ Node* n,
+ const std::function<Value*(Value*)>& value_map,
+ bool copy_blocks) {
+ // n can be from a different graph
+ Node* r = n->allocNewInstance(this);
+ for (auto o : n->outputs()) {
r->addOutput()->copyMetadata(o);
}
r->cloneFrom(n);
- for(auto i : n->inputs()) {
+ for (auto i : n->inputs()) {
r->addInput(value_map(i));
}
- if(copy_blocks) {
- for(auto b : n->blocks()) {
+ if (copy_blocks) {
+ for (auto b : n->blocks()) {
r->addBlock()->cloneFrom(b, value_map);
}
}
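A hedged sketch of the intended cross-graph usage, with an environment map threading cloned outputs to later inputs (the same pattern inlineCallTo uses below):

void cloneSketch(
    Graph& dst,
    Node* node,
    std::unordered_map<Value*, Value*>& env) {
  // Clone `node` (possibly from another graph) into dst, resolving its
  // inputs through env, then record the fresh outputs for later nodes.
  Node* copy = dst.insertNode(
      dst.createClone(node, [&](Value* v) { return env.at(v); }));
  for (size_t i = 0; i < node->outputs().size(); ++i) {
    env[node->outputs()[i]] = copy->outputs()[i];
  }
}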
IValue val,
c10::optional<SourceRange> loc,
c10::optional<ScopePtr> scope) {
- return jit::insertConstant(*this, std::move(val), std::move(loc), std::move(scope));
+ return jit::insertConstant(
+ *this, std::move(val), std::move(loc), std::move(scope));
}
std::string Graph::toString() const {
}
Graph::~Graph() {
- for (const Node * n : all_nodes)
+ for (const Node* n : all_nodes)
delete n;
- for (const Value * v : all_values)
+ for (const Value* v : all_values)
delete v;
- for (const Block * b : all_blocks)
+ for (const Block* b : all_blocks)
delete b;
}
-void Graph::freeNode(Node * n) {
+void Graph::freeNode(Node* n) {
auto it = all_nodes.find(n);
JIT_ASSERT(it != all_nodes.end());
delete *it;
all_nodes.erase(it);
}
-void Graph::freeValue(Value * v) {
+void Graph::freeValue(Value* v) {
v->setUniqueName("");
auto it = all_values.find(v);
JIT_ASSERT(it != all_values.end());
delete *it;
all_values.erase(it);
}
-void Graph::freeBlock(Block * b) {
+void Graph::freeBlock(Block* b) {
auto it = all_blocks.find(b);
JIT_ASSERT(it != all_blocks.end());
delete *it;
at::ArrayRef<Value*> createTupleUnpack(Value* v) {
// small peephole optimization to ensure IntList attributes can still turn
// into constants e.g. in x.expand([3, 4])
- if(v->node()->kind() == prim::TupleConstruct)
+ if (v->node()->kind() == prim::TupleConstruct)
return v->node()->inputs();
- auto & g = *v->owningGraph();
+ auto& g = *v->owningGraph();
return g.insertNode(g.createTupleUnpack(v))->outputs();
}
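A hedged sketch of the peephole: when the tuple was just constructed, the unpack is free (tup is an illustrative value):

void unpackSketch(Value* tup) {
  // If tup comes from prim::TupleConstruct(%a, %b), this returns {%a, %b}
  // without inserting anything; otherwise a TupleUnpack node is created.
  at::ArrayRef<Value*> elems = createTupleUnpack(tup);
  (void)elems;
}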
-std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs) {
+std::vector<Value*> inlineCallTo(
+ Graph& g,
+ Graph& callee,
+ ArrayRef<Value*> inputs,
+ bool unpack_outputs) {
std::unordered_map<Value*, Value*> value_map;
auto value_map_func = [&](Value* v) { return value_map.at(v); };
JIT_ASSERT(callee.inputs().size() == inputs.size());
value_map[callee.inputs()[i]] = inputs[i];
}
for (auto* node : callee.nodes()) {
- auto* new_node =
- g.insertNode(g.createClone(node, value_map_func));
+ auto* new_node = g.insertNode(g.createClone(node, value_map_func));
for (size_t i = 0; i < node->outputs().size(); ++i) {
value_map[node->outputs()[i]] = new_node->outputs()[i];
}
if (unpack_outputs && outputs.size() == 1 &&
callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
- auto tup = outputs[0];
- outputs.clear();
- for(Value* v : createTupleUnpack(tup)) {
- outputs.emplace_back(v);
- }
- // if this was a peephole tuple unpack we can just get rid of
- // the tuple construct here and prevent needing DCE
- if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
- tup->node()->destroy();
- }
+ auto tup = outputs[0];
+ outputs.clear();
+ for (Value* v : createTupleUnpack(tup)) {
+ outputs.emplace_back(v);
+ }
+ // if this was a peephole tuple unpack we can just get rid of
+ // the tuple construct here and prevent needing DCE
+ if (tup->node()->kind() == prim::TupleConstruct &&
+ !tup->node()->hasUses()) {
+ tup->node()->destroy();
+ }
}
return outputs;
}
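A hedged call-site sketch (g, callee, a, b are illustrative; callee must take exactly two inputs):

void inlineSketch(Graph& g, Graph& callee, Value* a, Value* b) {
  // Splices callee's nodes into g at the current insert point; with
  // unpack_outputs=true a single tuple return is flattened into values.
  std::vector<Value*> outs =
      inlineCallTo(g, callee, {a, b}, /*unpack_outputs=*/true);
  (void)outs;
}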
-PythonOp* defaultAllocPythonOp(Graph*g) {
- throw std::runtime_error("Trying to allocate a Python object without python bindings loaded");
+PythonOp* defaultAllocPythonOp(Graph* g) {
+ throw std::runtime_error(
+ "Trying to allocate a Python object without python bindings loaded");
}
std::atomic<decltype(&defaultAllocPythonOp)> alloc_python_op;
alloc_python_op.store(v);
}
-
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/attributes.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/attributes.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/function_schema.h>
#include <torch/csrc/jit/generic_if.h>
#include <torch/csrc/jit/graph_node_list.h>
#include <torch/csrc/jit/interned_strings.h>
+#include <torch/csrc/jit/ivalue.h>
+#include <torch/csrc/jit/named_value.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/scope.h>
#include <torch/csrc/jit/source_location.h>
#include <torch/csrc/jit/source_range.h>
-#include <torch/csrc/jit/constants.h>
-#include <torch/csrc/jit/function_schema.h>
-#include <torch/csrc/jit/ivalue.h>
#include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/named_value.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/utils/disallow_copy.h>
#include <torch/csrc/utils/functional.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_stub.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/ATen.h>
#include <c10/util/ArrayRef.h>
#include <unordered_set>
#include <vector>
-namespace torch { namespace autograd {
+namespace torch {
+namespace autograd {
struct Function;
-}} // namespace torch::autograd
+}
+} // namespace torch
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// Graph represents one "function" of computation.
-// It uses a simple ownership model where the graph owns all the nodes inside it.
-// All references inside the graph are raw pointers.
-// Destroying the Graph will invalidate any pointers to nodes in the graph.
+// It uses a simple ownership model where the graph owns all the nodes inside
+// it. All references inside the graph are raw pointers. Destroying the Graph
+// will invalidate any pointers to nodes in the graph.
struct Graph;
// Node is the base class of the IR graph. It represents one computation
// Tensor or an opaque Handle object, as determined by type().
struct Value;
-TORCH_API std::ostream& operator<<(std::ostream & out, const Graph & g);
-TORCH_API std::ostream& operator<<(std::ostream & out, const Node & n);
+TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
+TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
// A list of nodes, with inputs and outputs
struct Block;
// 'user' is the consumer of the value; offset is the index into
// 'user's inputs where the produced value will be found.
struct Use {
- Use(Node * user, size_t offset)
- : user(user), offset(offset) {}
- Node * user;
+ Use(Node* user, size_t offset) : user(user), offset(offset) {}
+ Node* user;
size_t offset;
- bool operator==(const Use & b) {
+ bool operator==(const Use& b) {
return user == b.user && offset == b.offset;
}
};
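So Use{n, i} records that n's i-th input is this value. A minimal sketch of walking a def-use chain (the invariant below is the one Node::lint checks):

void useSketch(const Value* v) {
  for (const Use& u : v->uses()) {
    // u.user consumes v at input index u.offset.
    JIT_ASSERT(u.user->input(u.offset) == v);
  }
}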
using value_list = std::vector<Value*>;
using use_list = std::vector<Use>;
using pyobj_list = std::vector<THPObjectPtr>;
-template<typename T>
+template <typename T>
using ArrayRef = at::ArrayRef<T>;
using NodeKind = Symbol;
using topo_position_t = int64_t;
struct Value {
TH_DISALLOW_COPY_AND_ASSIGN(Value);
- Value(Node * node_, size_t offset_);
-private:
+ Value(Node* node_, size_t offset_);
+
+ private:
friend struct Node;
friend struct Graph;
- Node * node_;
+ Node* node_;
size_t offset_;
- size_t unique_ = 0; // unique id
+ size_t unique_ = 0; // unique id
use_list uses_;
std::string unique_name_;
TypePtr type_;
-public:
+
+ public:
Value* setType(TypePtr type);
void inferTypeFrom(const at::Tensor& output) {
setType(CompleteTensorType::create(output));
}
- const TypePtr & type() const {
+ const TypePtr& type() const {
JIT_ASSERT(type_ != nullptr);
return type_;
}
bool hasUniqueName() const {
return !unique_name_.empty();
}
- TORCH_API Value* setUniqueName(const std::string & name);
+ TORCH_API Value* setUniqueName(const std::string& name);
std::string uniqueName() const {
if (hasUniqueName())
return unique_name_;
void setOffset(size_t offset) {
offset_ = offset;
}
- const Node * node() const {
+ const Node* node() const {
return node_;
}
- Graph * owningGraph();
- const Graph * owningGraph() const;
+ Graph* owningGraph();
+ const Graph* owningGraph() const;
// TODO: make this more const correct
- const use_list & uses() const {
+ const use_list& uses() const {
return uses_;
}
return !uses().empty();
}
- TORCH_API void replaceFirstUseWith(Value * newValue);
+ TORCH_API void replaceFirstUseWith(Value* newValue);
// Replaces all uses of this value with 'newValue'.
//
// Result: %3 = f(%1, %2)
// %4 = g(%6)
// %5 = h(%6, %6)
- TORCH_API void replaceAllUsesWith(Value * newValue);
+ TORCH_API void replaceAllUsesWith(Value* newValue);
- TORCH_API Value* copyMetadata(Value * from);
+ TORCH_API Value* copyMetadata(Value* from);
};
-
struct Node : public Attributes<Node> {
TH_DISALLOW_COPY_AND_ASSIGN(Node);
friend struct Graph;
friend const_graph_node_list;
friend graph_node_list_iterator;
friend const_graph_node_list_iterator;
-private:
+
+ private:
// each node but Return/Param
// is associated with exactly one place in the node list...
// of the graph_
- // this circular is a doubly-linked list, the Return node is used as the sentinel for the beginning and end of the list
- // such that the list never has null pointers
- // next_in_graph[0] is next pointer
- // next_in_graph[1] is prev pointer
- // using an array to allow the same iterator class for forward and reverse node lists
+  // this circular list is doubly linked. The Return node is used as the
+ // sentinel for the beginning and end of the list such that the list never has
+ // null pointers.
+ // - next_in_graph[0] is next pointer
+ // - next_in_graph[1] is prev pointer
+ //
+ // Using an array to allow the same iterator class for forward and
+ // reverse node lists
+ //
// This list represents a topological sort
- Node* next_in_graph[2] = { nullptr, nullptr };
+ Node* next_in_graph[2] = {nullptr, nullptr};
const NodeKind kind_;
std::vector<Value*> inputs_;
ScopePtr scope_;
// Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
  // This field is effectively a cache that's populated on attribute lookups and
- // invalidated every time we perform an operation that could potentially change
- // the schema.
- // note: mutable because schema_ is effectively a cache
+  // invalidated every time we perform an operation that could potentially
+  // change the schema. Note: mutable because schema_ is effectively a cache.
mutable const FunctionSchema* schema_;
topo_position_t topo_position_ = 0;
-protected:
- TORCH_API Node(Graph * graph_, NodeKind kind_); //defined after graph
-public:
- Node* & next() { return next_in_graph[kNextDirection]; }
- Node* & prev() { return next_in_graph[kPrevDirection]; }
- Node* const & next() const { return next_in_graph[kNextDirection]; }
- Node* const & prev() const { return next_in_graph[kPrevDirection]; }
+
+ protected:
+ TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph
+ public:
+ Node*& next() {
+ return next_in_graph[kNextDirection];
+ }
+ Node*& prev() {
+ return next_in_graph[kPrevDirection];
+ }
+ Node* const& next() const {
+ return next_in_graph[kNextDirection];
+ }
+ Node* const& prev() const {
+ return next_in_graph[kPrevDirection];
+ }
NodeKind kind() const {
return kind_;
std::shared_ptr<SourceLocation> getSourceLocation() const {
return source_location_;
}
- Graph * owningGraph() {
+ Graph* owningGraph() {
return graph_;
}
- const Graph * owningGraph() const {
+ const Graph* owningGraph() const {
return graph_;
}
- Block * owningBlock() {
+ Block* owningBlock() {
return owning_block_;
}
- const Block * owningBlock() const {
+ const Block* owningBlock() const {
return owning_block_;
}
ScopePtr scope() {
// raw pointers are.
return {outputs_.data(), outputs_.size()};
}
- Value * output(size_t i) const {
+ Value* output(size_t i) const {
return outputs_.at(i);
}
bool hasUses() const {
- for(auto o : outputs()) {
- if(!o->uses().empty())
+ for (auto o : outputs()) {
+ if (!o->uses().empty())
return true;
}
return false;
}
- TORCH_API void replaceAllUsesWith(Node * n);
+ TORCH_API void replaceAllUsesWith(Node* n);
- // lots of things like chunk have a single input or single output, so we have a
- // helper to make accessing it easier
- Value * input() {
+ // lots of things like chunk have a single input or single output, so we have
+ // a helper to make accessing it easier
+ Value* input() {
JIT_ASSERT(inputs_.size() == 1);
return inputs_.at(0);
}
- Value * output() {
+ Value* output() {
JIT_ASSERT(outputs_.size() == 1);
return outputs_.at(0);
}
JIT_ASSERT(outputs_.size() == 1);
return outputs_.at(0);
}
- const Value * input() const {
+ const Value* input() const {
JIT_ASSERT(inputs_.size() == 1);
return inputs_.at(0);
}
// Access a particular input. This is a checked index.
- Value * input(size_t i) const {
+ Value* input(size_t i) const {
return inputs_.at(i);
}
template <typename T>
c10::optional<T> get(Symbol name) const {
- if(auto v = get(name))
+ if (auto v = get(name))
return v->template to<T>();
return c10::nullopt;
}
}
TORCH_API bool isNondeterministic() const;
- TORCH_API bool hasSideEffects () const;
+ TORCH_API bool hasSideEffects() const;
// Graphs
// Given: %3 = f(%1, %2)
// Execute: %3.addInput(%4)
// Result: %3 = f(%1, %2, %4)
- TORCH_API Value* addInput(Value * value);
+ TORCH_API Value* addInput(Value* value);
// Add 'value' as an input to 'this' at the specified position in the
// arguments. Returns the added value for ease of chaining.
// Given: %3 = f(%1, %2)
// Execute: %3.replaceInput(1, %4)
// Result: %3 = f(%1, %4)
- TORCH_API Value * replaceInput(size_t i, Value * newValue);
+ TORCH_API Value* replaceInput(size_t i, Value* newValue);
// Replace all occurrences of 'from' in the inputs of this
// node with 'to'. Corresponds to llvm's replaceUsesOfWith.
// Given: %3 = f(%1, %2, %1)
// Execute: %3.replaceInputWith(%1, %4)
// Result: %3 = f(%4, %2, %4)
- TORCH_API void replaceInputWith(Value * from, Value * to);
+ TORCH_API void replaceInputWith(Value* from, Value* to);
TORCH_API Value* addOutput();
TORCH_API void eraseOutput(size_t i);
- TORCH_API Block * addBlock();
+ TORCH_API Block* addBlock();
TORCH_API void eraseBlock(size_t i);
// Each Node can have a list of subblocks. These are used to define structured
// nested control flow operators such as If and Loop.
// The meaning of a block is specific to the kind of node it is in, but
// all blocks share these semantics:
- // * Nested lexical scoping: If a node 'Parent' has a subblock which contains a
- // node 'Child', Child can use any value that was in scope for the Parent
+ // * Nested lexical scoping: If a node 'Parent' has a subblock which contains
+ // a node 'Child', Child can use any value that was in scope for the Parent
// node in addition to any values defined before 'Child' in the subblock.
- // * The list of inputs to the block are in scope for the duration of the block
+ // * The list of inputs to the block are in scope for the duration of the
+ // block
// * the outputs of the Parent node are not in scope for the subblocks
  // Typically the inputs to a block that represents control flow act as
  // the equivalent of phi-nodes in standard SSA form,
}
// Is 'this' before 'n' in the topological order?
- TORCH_API bool isBefore(const Node * n) const;
+ TORCH_API bool isBefore(const Node* n) const;
// Is 'this' after 'n' in the topological order?
- TORCH_API bool isAfter(const Node * n) const;
+ TORCH_API bool isAfter(const Node* n) const;
// Insert unattached 'this' node before 'n' in the topological order.
// Returns this (for chaining).
// Result: %3 = f(%1, %2)
// %5 = h(%1)
// %4 = g(%3)
- TORCH_API Node* insertBefore(Node * n);
+ TORCH_API Node* insertBefore(Node* n);
// Insert unattached 'this' node after 'n' in the topological order.
// Returns this (for chaining).
// Result: %3 = f(%1, %2)
// %4 = g(%3)
// %5 = h(%1)
- TORCH_API Node* insertAfter(Node * n);
+ TORCH_API Node* insertAfter(Node* n);
// Move 'this' (already in the graph) after 'n' in the topological order.
//
// Result: %3 = g(%1)
// %2 = f(%1)
//
- TORCH_API void moveAfter(Node * n);
+ TORCH_API void moveAfter(Node* n);
  // Move 'this' (already in the graph) before 'n' in the topological order.
//
// Execute: %3.moveBefore(%2)
// Result: %3 = g(%1)
// %2 = f(%1)
- TORCH_API void moveBefore(Node * n);
+ TORCH_API void moveBefore(Node* n);
// Move 'this' (already in the graph) before 'n' in the topological order.
//
// Example usage: if(auto s = n.cast<Select>()) { ... }
//
// TODO: Make this const correct
- template<typename T>
+ template <typename T>
T* cast() {
- if(T::Kind == kind())
+ if (T::Kind == kind())
return static_cast<T*>(this);
return nullptr;
}
- template<typename T>
+ template <typename T>
T* expect() {
JIT_ASSERTM(
T::Kind == kind(),
- "expected a ", T::Kind.toDisplayString(),
- " but found a ", kind().toDisplayString());
+ "expected a ",
+ T::Kind.toDisplayString(),
+ " but found a ",
+ kind().toDisplayString());
return static_cast<T*>(this);
}
// XXX: this function is meant to be used with string literals only!
- TORCH_API bool matches(const char *signature_literal, at::ArrayRef<Symbol> const_inputs={}) const;
+ TORCH_API bool matches(
+ const char* signature_literal,
+ at::ArrayRef<Symbol> const_inputs = {}) const;
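A hedged example of the intended call pattern; const_inputs additionally requires those arguments to be constants, which makes the get<>() below safe:

void matchSketch(Node* n) {
  if (n->matches(
          "aten::dropout(Tensor input, float p, bool train) -> Tensor",
          /*const_inputs=*/{attr::train})) {
    bool train = n->get<bool>(attr::train).value();
    (void)train; // schema shape and the constant flag are now guaranteed
  }
}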
const FunctionSchema& schema() const {
if (!schema_)
private:
enum class MoveSide { BEFORE, AFTER };
- bool tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun);
+ bool tryMove(
+ Node* movePoint,
+ MoveSide moveSide,
+ const AliasDb& aliasDb,
+ bool dryRun);
void move(Node* movePoint, MoveSide moveSide);
bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
std::pair<Value*, const Argument&> findInput(Symbol name);
void findSchema() const;
- // Lookup iterator in use list of _input i_ that corresponds to its use of _this_
+ // Lookup iterator in use list of _input i_ that corresponds to its use of
+ // _this_
TORCH_API use_list::iterator findUseForInput(size_t i);
// remove the use of input i, this sets input i to nullptr, but
TORCH_API Value* dropInput(size_t i);
bool inBlockList() const {
- if(next() == nullptr) {
+ if (next() == nullptr) {
JIT_ASSERT(prev() == nullptr);
}
return next() != nullptr;
void assignTopoPosition();
-protected:
+ protected:
// subclasses must override
// this function is used by createClone to initialize a new version
// of a node in another graph. It should allocate a new instance of the same
// concrete type as 'this', but in graph 'g' which might be different
// than graph_
- virtual Node * allocNewInstance(Graph * g) {
+ virtual Node* allocNewInstance(Graph* g) {
return new Node(g, kind());
}
// create a copy of all properties of Node s into this.
// 'this' will be allocated with s->allocNewInstance(g) so it should have
// the same concrete type as 's'
//
- TORCH_API virtual void cloneFrom(Node * s);
+ TORCH_API virtual void cloneFrom(Node* s);
};
struct Block {
friend struct Node;
friend struct Graph;
TH_DISALLOW_COPY_AND_ASSIGN(Block);
- TORCH_API Block(Graph * graph_, Node * node_);
+ TORCH_API Block(Graph* graph_, Node* node_);
at::ArrayRef<Value*> inputs() {
return input_->outputs();
}
at::ArrayRef<const Value*> inputs() const {
- const auto & inputs = input_->outputs();
+ const auto& inputs = input_->outputs();
return {inputs.data(), inputs.size()};
}
at::ArrayRef<Value*> outputs() {
const_graph_node_list nodes() const {
return {output_, kNextDirection};
}
- Node * return_node() {
+ Node* return_node() {
return output_;
}
- const Node * return_node() const {
+ const Node* return_node() const {
return output_;
}
- Node * param_node() {
+ Node* param_node() {
return input_;
}
- const Node * param_node() const {
+ const Node* param_node() const {
return input_;
}
- Value * addInput(std::string name="") {
- Value * v = input_->addOutput();
+ Value* addInput(std::string name = "") {
+ Value* v = input_->addOutput();
v->setUniqueName(std::move(name));
return v;
}
void eraseInput(size_t i) {
input_->eraseOutput(i);
}
- size_t registerOutput(Value * v) {
+ size_t registerOutput(Value* v) {
output_->addInput(v);
return outputs().size() - 1;
}
void eraseOutput(size_t i) {
output_->removeInput(i);
}
- Node * appendNode(Node * n) {
+ Node* appendNode(Node* n) {
JIT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
n->insertBefore(output_);
return n;
}
- Node * prependNode(Node * n) {
+ Node* prependNode(Node* n) {
JIT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
n->insertAfter(output_);
return n;
}
- Graph * owningGraph() {
+ Graph* owningGraph() {
return graph_;
}
- const Graph * owningGraph() const {
+ const Graph* owningGraph() const {
return graph_;
}
- Node * owningNode() {
+ Node* owningNode() {
return owning_node_;
}
- const Node * owningNode() const {
+ const Node* owningNode() const {
return owning_node_;
}
// clone all inputs, nodes, and outputs from src and append them
// to the inputs, nodes, and outputs of this block
// value_map is used whenever a node in src references a free variable
// in src to look up its corresponding value
- TORCH_API void cloneFrom(Block * src, std::function<Value*(Value*)> value_map);
-private:
+ TORCH_API void cloneFrom(Block* src, std::function<Value*(Value*)> value_map);
+
+ private:
void reIndexTopology();
// should only be called in the constructor
// do not have to be removed before you can destroy the block
void destroy();
- Graph * const graph_;
+ Graph* const graph_;
// holds outputs in a way that can be reflected
// as a Use object
// also used as the beginning/end of the circular node list to avoid
// having corner cases where the list is empty.
- Node * const output_;
- Node * const input_;
- Node * const owning_node_; // either the node that has this block or nullptr for root
+ Node* const output_;
+ Node* const input_;
+ Node* const
+ owning_node_; // either the node that has this block or nullptr for root
};
struct Graph {
-TH_DISALLOW_COPY_AND_ASSIGN(Graph);
-friend struct Node;
-friend struct Value;
-friend struct Block;
-private:
+ TH_DISALLOW_COPY_AND_ASSIGN(Graph);
+ friend struct Node;
+ friend struct Value;
+ friend struct Block;
+ private:
// only used to keep track of allocated nodes
// actual representation of Graph is done with
// inputs, outputs, nodes
// by default this is set to append to the top level block
Node* insert_before_;
-public:
-
+ public:
Graph(ScopePtr scope_root)
- : next_unique_(0)
- , current_scope_(std::move(scope_root))
- , block_(new Block(this, nullptr))
- , insert_before_(return_node()) {}
+ : next_unique_(0),
+ current_scope_(std::move(scope_root)),
+ block_(new Block(this, nullptr)),
+ insert_before_(return_node()) {}
Graph() : Graph(c10::make_intrusive<Scope>()) {}
return block_->inputs();
}
at::ArrayRef<const Value*> inputs() const {
- const auto & block = *block_;
+ const auto& block = *block_;
return block.inputs();
}
at::ArrayRef<Value*> outputs() {
return block_->outputs();
}
at::ArrayRef<const Value*> outputs() const {
- const auto & block = *block_;
+ const auto& block = *block_;
return block.outputs();
}
graph_node_list nodes() {
return block_->nodes();
}
const_graph_node_list nodes() const {
- const auto & block = *block_;
+ const auto& block = *block_;
return block.nodes();
}
- Node * param_node() {
+ Node* param_node() {
return block_->param_node();
}
- const Node * param_node() const {
+ const Node* param_node() const {
return block_->param_node();
}
- Node * return_node() {
+ Node* return_node() {
return block_->return_node();
}
- const Node * return_node() const {
+ const Node* return_node() const {
return block_->return_node();
}
void push_scope(const std::string& scope_name) {
void set_current_scope(ScopePtr scope) {
current_scope_ = std::move(scope);
}
- Value * addInput(std::string name="") {
+ Value* addInput(std::string name = "") {
return block_->addInput(std::move(name));
}
Value* insertInput(size_t i, std::string name = "") {
return unique_names_;
}
- size_t registerOutput(Value * n) {
+ size_t registerOutput(Value* n) {
return block_->registerOutput(n);
}
- TORCH_API Node * create(NodeKind kind, size_t num_outputs=1);
- TORCH_API Node * create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs=1);
+ TORCH_API Node* create(NodeKind kind, size_t num_outputs = 1);
+ TORCH_API Node* create(
+ NodeKind kind,
+ ArrayRef<Value*> inputs,
+ size_t num_outputs = 1);
-
- TORCH_API Node* createNone(TypePtr typ); // value of None with type Optional[typ]
+ TORCH_API Node* createNone(
+ TypePtr typ); // value of None with type Optional[typ]
TORCH_API Node* createUndefined();
TORCH_API Node* createFusionGroup();
TORCH_API Node* createDifferentiableSubgraph();
TORCH_API Node* createTuple(at::ArrayRef<Value*> values);
- TORCH_API Node* createTupleUnpack(Value * v);
- TORCH_API Node* createTupleIndex(Value * tup, int64_t index);
- TORCH_API Node* createTupleSlice(Value * tup, int64_t beg, int64_t end);
- TORCH_API Node* createList(const TypePtr& elem_type, at::ArrayRef<Value*> values);
- TORCH_API Node* createListUnpack(Value *v, size_t size);
+ TORCH_API Node* createTupleUnpack(Value* v);
+ TORCH_API Node* createTupleIndex(Value* tup, int64_t index);
+ TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end);
+ TORCH_API Node* createList(
+ const TypePtr& elem_type,
+ at::ArrayRef<Value*> values);
+ TORCH_API Node* createListUnpack(Value* v, size_t size);
TORCH_API Node* createNumToTensor(Value* value);
TORCH_API Node* createImplicitTensorToNum(const TypePtr& type, Value* value);
Node* createPythonOp(
// use node_map to translate inputs of n to inputs of the cloned node
// if copy_blocks is false, it will not recursively clone the nested blocks
// this node contains.
- TORCH_API Node * createClone(Node * n, const std::function<Value*(Value*)>& value_map, bool copy_blocks=true);
+ TORCH_API Node* createClone(
+ Node* n,
+ const std::function<Value*(Value*)>& value_map,
+ bool copy_blocks = true);
TORCH_API Value* insertConstant(
IValue val,
c10::optional<SourceRange> loc = c10::nullopt,
c10::optional<ScopePtr> scope = c10::nullopt);
-
- // schema-driven insert
- // this inserts a node into the graph with inputs determined from args and kwargs using Python
- // argument matching rules, and checks that the op matches a known schema
- // if this node successfully completes, it guarentees the node is a correctly-formed invocation
- // of opname
+ // Schema-driven insert:
+ // This inserts a node into the graph with inputs determined from args and
+ // kwargs using Python argument matching rules, and checks that the op matches
+ // a known schema.
+ //
+  // If this call successfully completes, it guarantees the node
+  // is a correctly-formed invocation of opname.
TORCH_API Value* insert(
Symbol opname,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs = {},
const c10::optional<SourceRange>& range = {});
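A hedged sketch of schema-driven insertion (op, inputs, and the kwarg are illustrative; kwargs are passed as NamedValues):

void insertSketch(Graph& g, Value* x, Value* y) {
  // Matches {x, y} plus alpha against the aten::add schema, inserts the
  // node at the current insert point, and returns its output Value.
  Value* sum = g.insert(aten::add, {x, y}, {NamedValue("alpha", 1)});
  (void)sum;
}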
- Node * appendNode(Node * n) {
+ Node* appendNode(Node* n) {
return block_->appendNode(n);
}
- Node * prependNode(Node * n) {
+ Node* prependNode(Node* n) {
return block_->prependNode(n);
}
// insert before insert_before_ node
// initialized to insert at the end of the top level block
// can be changed with setInsertPoint()
- Node * insertNode(Node * n) {
- JIT_ASSERT(insert_before_->inBlockList() && "insert point node is no longer in a block list");
+ Node* insertNode(Node* n) {
+ JIT_ASSERT(
+ insert_before_->inBlockList() &&
+ "insert point node is no longer in a block list");
return n->insertBefore(insert_before_);
}
// set where nodes are inserted to append to the end of this block
- void setInsertPoint(Block * b) {
+ void setInsertPoint(Block* b) {
JIT_ASSERT(b->owningGraph() == this);
insert_before_ = b->return_node();
}
// set where nodes are inserted to insert _before_ this node
- // for implementation simplicity we only support inserting before a node for now
- void setInsertPoint(Node * n) {
+ // for implementation simplicity we only support inserting before a node for
+ // now
+ void setInsertPoint(Node* n) {
JIT_ASSERT(n->owningGraph() == this && n->inBlockList());
insert_before_ = n;
}
- Node * insertPoint() {
+ Node* insertPoint() {
return insert_before_;
}
// the top level block
- Block * block() {
+ Block* block() {
return block_;
}
- const Block * block() const {
+ const Block* block() const {
return block_;
}
TORCH_API std::string toString() const;
- friend TORCH_API std::ostream& operator<<(std::ostream & out, const Graph & g);
+ friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
- TORCH_API std::ostream& prettyPrint(std::ostream & out);
+ TORCH_API std::ostream& prettyPrint(std::ostream& out);
TORCH_API void dumpPretty();
TORCH_API std::shared_ptr<Graph> copy();
-private:
-
- TORCH_API void freeNode(Node * n);
- TORCH_API void freeValue(Value * v);
- TORCH_API void freeBlock(Block * b);
+ private:
+ TORCH_API void freeNode(Node* n);
+ TORCH_API void freeValue(Value* v);
+ TORCH_API void freeBlock(Block* b);
};
struct WithInsertPoint : public ResourceGuard {
- WithInsertPoint(Node * n)
- : ResourceGuard([this] {
- prev->owningGraph()->setInsertPoint(prev);
- })
- , prev(n->owningGraph()->insertPoint()) {
+ WithInsertPoint(Node* n)
+ : ResourceGuard([this] { prev->owningGraph()->setInsertPoint(prev); }),
+ prev(n->owningGraph()->insertPoint()) {
n->owningGraph()->setInsertPoint(n);
}
- WithInsertPoint(Block * b)
- : WithInsertPoint(b->return_node()) {}
-private:
- Node * prev;
+ WithInsertPoint(Block* b) : WithInsertPoint(b->return_node()) {}
+
+ private:
+ Node* prev;
};
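A hedged usage sketch: temporarily redirect insertion into a node's first subblock, with RAII restoring the previous insert point (ifNode and x are illustrative):

void blockInsertSketch(Graph& g, Node* ifNode, Value* x) {
  WithInsertPoint guard(ifNode->blocks()[0]); // insert into the subblock
  g.insert(aten::relu, {x}); // lands just before the block's return node
} // guard's destructor restores the previous insert point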
struct WithCurrentScope : public ResourceGuard {
- WithCurrentScope(Graph & g, ScopePtr scope)
- : ResourceGuard([&g, this]() {
- g.set_current_scope(prev_scope);
- })
- , prev_scope(g.current_scope()) {
+ WithCurrentScope(Graph& g, ScopePtr scope)
+ : ResourceGuard([&g, this]() { g.set_current_scope(prev_scope); }),
+ prev_scope(g.current_scope()) {
g.set_current_scope(std::move(scope));
}
-private:
+
+ private:
ScopePtr prev_scope;
};
-inline Value::Value(Node * node_, size_t offset_)
-: node_(node_),
- offset_(offset_),
- unique_(node_->graph_->next_unique_++),
- type_(DynamicType::get()) {
+inline Value::Value(Node* node_, size_t offset_)
+ : node_(node_),
+ offset_(offset_),
+ unique_(node_->graph_->next_unique_++),
+ type_(DynamicType::get()) {
node_->graph_->all_values.emplace(this);
}
inline Value* Value::setType(TypePtr type) {
JIT_ASSERT(type);
type_ = std::move(type);
- for (Use & use : uses_) {
+ for (Use& use : uses_) {
use.user->schema_ = nullptr;
}
return this;
}
-inline Graph * Value::owningGraph() {
+inline Graph* Value::owningGraph() {
return node()->owningGraph();
}
-inline const Graph * Value::owningGraph() const {
+inline const Graph* Value::owningGraph() const {
return node()->owningGraph();
}
// Mutable case
// The IFM/ELSEIFM indicate that subclass *refinement* occurs.
// This is only valid for node types for which we have subclasses.
-#define IR_IFM(x,Kind) GENERIC_IF(,prim::Kind,x,Kind)
-#define IR_ELSEIFM(Kind) GENERIC_ELSEIF(,prim::Kind,Kind)
+#define IR_IFM(x, Kind) GENERIC_IF(, prim::Kind, x, Kind)
+#define IR_ELSEIFM(Kind) GENERIC_ELSEIF(, prim::Kind, Kind)
-#define IR_IFM_CONST(x,Kind) GENERIC_IF(const,prim::Kind,x,Kind)
-#define IR_ELSEIFM_CONST(Kind) GENERIC_ELSEIF(const,prim::Kind,Kind)
+#define IR_IFM_CONST(x, Kind) GENERIC_IF(const, prim::Kind, x, Kind)
+#define IR_ELSEIFM_CONST(Kind) GENERIC_ELSEIF(const, prim::Kind, Kind)
#define IR_IF(x, Kind) \
auto&& __match_key = x; \
/************* All nodes not required to be defined before Graph **************/
- // execute a Python function, used for Ops we can't optimize but that we want to optimize around
+// execute a Python function, used for Ops we can't optimize but that we want to
+// optimize around
struct PythonOp : public Node {
static constexpr Symbol Kind = prim::PythonOp;
- PythonOp(Graph * graph)
- : Node(graph,prim::PythonOp) {}
+ PythonOp(Graph* graph) : Node(graph, prim::PythonOp) {}
PythonOp* init(
THPObjectPtr&& pyobj,
const std::string& cconv,
std::vector<THPObjectPtr> scalar_args;
virtual std::string name() const = 0;
virtual void writeScalars(std::ostream& out) const = 0;
- void cloneFrom(Node * other_) override = 0;
- Node * allocNewInstance(Graph * g) override = 0;
+ void cloneFrom(Node* other_) override = 0;
+ Node* allocNewInstance(Graph* g) override = 0;
// recover the autograd.Function instance, if this PythonOp's function
// was originally SomeFunction.apply
// used in ONNX for discovering symbolics
const std::string& cconv,
pyobj_list&& scalar_args) {
auto op = allocPythonOp(this);
- return op->init(
- std::move(pyobj),
- cconv,
- std::move(scalar_args));
+ return op->init(std::move(pyobj), cconv, std::move(scalar_args));
}
TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
-// unpack_outputs - if true, and the callee returns a single tuple value, then insert a tuple unpack node
+// unpack_outputs - if true, and the callee returns a single tuple value, then
+// insert a tuple unpack node
// and return the resulting values
-TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs=false);
-
-
-}} // namespace torch::jit
+TORCH_API std::vector<Value*> inlineCallTo(
+ Graph& g,
+ Graph& callee,
+ ArrayRef<Value*> inputs,
+ bool unpack_outputs = false);
+
+} // namespace jit
+} // namespace torch
using ::c10::ivalue::Shared;
using ::c10::IValue;
-using ::c10::ivalue::Tuple;
using ::c10::ivalue::Future;
+using ::c10::ivalue::Tuple;
using ::c10::ivalue::BoolList;
using ::c10::ivalue::DoubleList;
#pragma once
#include <ATen/ATen.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/ivalue.h>
#include <torch/csrc/jit/source_range.h>
#include <torch/csrc/utils/variadic.h>
-#include <torch/csrc/jit/ivalue.h>
-#include <torch/csrc/jit/constants.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct Value;
struct NamedValue {
NamedValue(const SourceRange& loc, const std::string& name, Value* value)
- : loc_(loc), name_(name), value_(value) {}
- NamedValue(const SourceRange& loc, Value* value)
- : loc_(loc), value_(value) {}
+ : loc_(loc), name_(name), value_(value) {}
+ NamedValue(const SourceRange& loc, Value* value) : loc_(loc), value_(value) {}
- /* implicit */ NamedValue(Value* value)
- : value_(value) {}
+ /* implicit */ NamedValue(Value* value) : value_(value) {}
NamedValue(const std::string& name, Value* value)
- : name_(name), value_(value) {}
+ : name_(name), value_(value) {}
/* implicit */ NamedValue(IValue value)
- : value_(nullptr), ivalue_(std::move(value)) {}
+ : value_(nullptr), ivalue_(std::move(value)) {}
NamedValue(const std::string& name, IValue value)
- : name_(name), ivalue_(std::move(value)) {}
+ : name_(name), ivalue_(std::move(value)) {}
template <
typename T,
typename = enable_if_t<
(!std::is_same<decay_t<T>, NamedValue>::value &&
- !std::is_same<decay_t<T>, Value*>::value && !std::is_same<decay_t<T>, IValue>::value)>>
+ !std::is_same<decay_t<T>, Value*>::value &&
+ !std::is_same<decay_t<T>, IValue>::value)>>
NamedValue(T&& t) : NamedValue(IValue(std::forward<T>(t))) {}
template <
(!std::is_same<decay_t<T>, Value*>::value &&
!std::is_same<decay_t<T>, IValue>::value)>>
NamedValue(const std::string& name, T&& t)
- : NamedValue(name, IValue(std::forward<T>(t))) {}
+ : NamedValue(name, IValue(std::forward<T>(t))) {}
SourceRange locOr(const SourceRange& backup_location) const {
- if(!loc_)
+ if (!loc_)
return backup_location;
return loc();
}
// note: this will insert a constant node into the graph at the current
// insert point if this NamedValue is actually a constant
Value* value(Graph& g) const {
- if(!value_)
- return insertConstant(g, ivalue_); // use insertConstant to remove need to include ir.h here
+ if (!value_)
+ return insertConstant(
+ g, ivalue_); // use insertConstant to avoid the need to include ir.h here
return value_;
}
return *loc_;
}
-private:
- c10::optional<SourceRange> loc_;
- c10::optional<std::string> name_;
- Value* value_{nullptr};
- // only valid if value_ == nullptr;
- IValue ivalue_;
+ private:
+ c10::optional<SourceRange> loc_;
+ c10::optional<std::string> name_;
+ Value* value_{nullptr};
+ // only valid if value_ == nullptr
+ IValue ivalue_;
};
-}}
+} // namespace jit
+} // namespace torch
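// A minimal sketch of the conversions NamedValue accepts (illustrative; `g`
// and `v` are hypothetical): the template constructor boxes anything
// IValue-convertible, and value(g) materializes a Constant node on demand.
//
//   Value* v = ...;              // an existing graph value
//   NamedValue a(v);             // positional, wraps the Value*
//   NamedValue b("alpha", 2.0);  // named, stored as an IValue
//   Value* c = b.value(g);       // inserts a prim::Constant into g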
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/node_hashing.h>
+#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/utils/functional.h>
#include <torch/csrc/utils/hash.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
return &lhs.type() == &rhs.type() && lhs.equal(rhs);
}
-bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
- if (lhs.size() != rhs.size()) return false;
+bool tensorListEqual(
+ const std::vector<at::Tensor>& lhs,
+ const std::vector<at::Tensor>& rhs) {
+ if (lhs.size() != rhs.size())
+ return false;
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
}
-
// Check whether two nodes have the same attributes in CSE.
// This function may be too conservative for general use.
// It does NOT support g/gs attributes.
JIT_ASSERT(lhs != nullptr);
JIT_ASSERT(rhs != nullptr);
// One has attributes, the other does not.
- if (lhs->hasAttributes() != rhs->hasAttributes()) return false;
+ if (lhs->hasAttributes() != rhs->hasAttributes())
+ return false;
// Neither has attributes.
- if (!lhs->hasAttributes() && !rhs->hasAttributes()) return true;
+ if (!lhs->hasAttributes() && !rhs->hasAttributes())
+ return true;
auto lnames = lhs->attributeNames();
auto rnames = rhs->attributeNames();
std::sort(lnames.begin(), lnames.end());
std::sort(rnames.begin(), rnames.end());
- if (lnames != rnames) return false;
+ if (lnames != rnames)
+ return false;
for (auto name : lnames) {
- if (lhs->kindOf(name) != rhs->kindOf(name)) return false;
+ if (lhs->kindOf(name) != rhs->kindOf(name))
+ return false;
- #define COMPARE_ATTRIBUTEVALUE(type) \
- case AttributeKind::type: \
- { if (lhs->type(name) != rhs->type(name)) return false; } break;
+#define COMPARE_ATTRIBUTEVALUE(type) \
+ case AttributeKind::type: { \
+ if (lhs->type(name) != rhs->type(name)) \
+ return false; \
+ } break;
- switch(lhs->kindOf(name)) {
+ switch (lhs->kindOf(name)) {
COMPARE_ATTRIBUTEVALUE(f)
COMPARE_ATTRIBUTEVALUE(fs)
COMPARE_ATTRIBUTEVALUE(i)
COMPARE_ATTRIBUTEVALUE(s)
COMPARE_ATTRIBUTEVALUE(ss)
case AttributeKind::t: {
- if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
+ if (!tensorEqual(lhs->t(name), rhs->t(name)))
+ return false;
break;
}
case AttributeKind::ts: {
- if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
+ if (!tensorListEqual(lhs->ts(name), rhs->ts(name)))
+ return false;
break;
}
case AttributeKind::g:
return false;
}
- #undef COMPARE_ATTRIBUTEVALUE
+#undef COMPARE_ATTRIBUTEVALUE
}
return true;
} // anonymous namespace
-
size_t HashNode::operator()(const Node* k) const {
JIT_ASSERT(k != nullptr);
- return get_hash(k->kind(),
- fmap(k->outputs(), [](const Value *v) { return v->type()->kind(); }),
- fmap(k->inputs(), [](const Value *v) { return v->unique(); }));
+ return get_hash(
+ k->kind(),
+ fmap(k->outputs(), [](const Value* v) { return v->type()->kind(); }),
+ fmap(k->inputs(), [](const Value* v) { return v->unique(); }));
};
bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
- if (lhs == nullptr && rhs == nullptr) return true;
- if (lhs == nullptr || rhs == nullptr) return false;
+ if (lhs == nullptr && rhs == nullptr)
+ return true;
+ if (lhs == nullptr || rhs == nullptr)
+ return false;
- if (lhs->kind() != rhs->kind()) return false;
+ if (lhs->kind() != rhs->kind())
+ return false;
// Check whether the output types are the same.
auto lhs_outputs = lhs->outputs();
auto rhs_outputs = rhs->outputs();
- if (lhs_outputs.size() != rhs_outputs.size()) return false;
+ if (lhs_outputs.size() != rhs_outputs.size())
+ return false;
for (size_t i = 0; i < lhs_outputs.size(); ++i) {
if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
return false;
// Check whether the inputs are the same.
auto lhs_inputs = lhs->inputs();
auto rhs_inputs = rhs->inputs();
- if (lhs_inputs.size() != rhs_inputs.size()) return false;
- if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) return false;
+ if (lhs_inputs.size() != rhs_inputs.size())
+ return false;
+ if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin()))
+ return false;
- if (!attributesEqualCSE(lhs, rhs)) return false;
+ if (!attributesEqualCSE(lhs, rhs))
+ return false;
return true;
};
-}}
+} // namespace jit
+} // namespace torch
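// A sketch of how these functors are used (illustrative): CSE-style passes
// key an unordered_set by node kind, output type kinds, and input uniques, so
// a structural duplicate hashes to the same bucket and compares equal.
//
//   std::unordered_set<Node*, HashNode, EqualNode> seen;
//   if (!seen.insert(node).second) {
//     // structural duplicate: reuse the previously seen node's outputs
//   }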
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct HashNode {
size_t operator()(const Node* k) const;
bool operator()(const Node* lhs, const Node* rhs) const;
};
-}}
+} // namespace jit
+} // namespace torch
#include <ATen/ATen.h>
#include <torch/csrc/jit/alias_info.h>
-#include <torch/csrc/jit/script/lexer.h>
-#include <torch/csrc/jit/script/parse_string_literal.h>
-#include <torch/csrc/jit/script/tree.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/lexer.h>
+#include <torch/csrc/jit/script/parse_string_literal.h>
+#include <torch/csrc/jit/script/tree.h>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace script {
struct SchemaParser {
- SchemaParser(const std::string& str)
- : L(str) {}
+ SchemaParser(const std::string& str) : L(str) {}
FunctionSchema parseDeclaration() {
auto name = L.expect(TK_IDENT).text();
- if(L.nextIf(':')) {
+ if (L.nextIf(':')) {
L.expect(':');
name = name + "::" + L.expect(TK_IDENT).text();
}
bool is_vararg = false;
size_t idx = 0;
parseList('(', ',', ')', [&] {
- if(is_vararg)
- throw ErrorReport(L.cur()) << "... must be the last element of the argument list";
+ if (is_vararg)
+ throw ErrorReport(L.cur())
+ << "... must be the last element of the argument list";
if (L.nextIf('*')) {
kwarg_only = true;
- } else if(L.nextIf(TK_DOTS)) {
+ } else if (L.nextIf(TK_DOTS)) {
is_vararg = true;
} else {
arguments.push_back(parseArgument(
std::vector<FunctionSchema> results;
do {
results.push_back(parseDeclaration());
- } while(L.nextIf(TK_NEWLINE));
+ } while (L.nextIf(TK_NEWLINE));
L.expect(TK_EOF);
return results;
}
}
TypePtr parseBaseType() {
static std::unordered_map<std::string, TypePtr> type_map = {
- {"Generator", GeneratorType::get() },
- {"ScalarType", IntType::get() },
- {"Layout", IntType::get() },
- {"Device", DeviceObjType::get() },
- {"Scalar", NumberType::get() },
- {"str", StringType::get() },
- {"float", FloatType::get() },
- {"int", IntType::get() },
- {"bool", BoolType::get() },
+ {"Generator", GeneratorType::get()},
+ {"ScalarType", IntType::get()},
+ {"Layout", IntType::get()},
+ {"Device", DeviceObjType::get()},
+ {"Scalar", NumberType::get()},
+ {"str", StringType::get()},
+ {"float", FloatType::get()},
+ {"int", IntType::get()},
+ {"bool", BoolType::get()},
};
auto tok = L.expect(TK_IDENT);
auto text = tok.text();
auto it = type_map.find(text);
- if(it == type_map.end()) {
- if(text.size() > 0 && islower(text[0])) {
+ if (it == type_map.end()) {
+ if (text.size() > 0 && islower(text[0])) {
// lower case identifiers that are not otherwise valid types
// are treated as type variables
return VarType::create(text);
alias_info.addSet(
Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
alias_info.setIsWrite(true);
- } else{
+ } else {
return c10::nullopt;
}
} else {
value = parseBaseType();
}
- while(true) {
- if(L.cur().kind == '[' && L.lookahead().kind == ']') {
+ while (true) {
+ if (L.cur().kind == '[' && L.lookahead().kind == ']') {
L.next(); // [
L.next(); // ]
value = ListType::create(value);
container->addContainedType(std::move(*alias_info));
}
alias_info = std::move(container);
- } else if(L.nextIf('?')) {
+ } else if (L.nextIf('?')) {
value = OptionalType::create(value);
} else {
break;
c10::optional<IValue> default_value;
c10::optional<std::string> alias_set;
std::string name;
- if(L.nextIf('[')) {
+ if (L.nextIf('[')) {
// note: an array with a size hint can only occur at the Argument level
type = ListType::create(type);
N = std::stoll(L.expect(TK_NUMBER).text());
}
alias_info = std::move(container);
}
- if(is_return) {
+ if (is_return) {
// optionally named return values
- if(L.cur().kind == TK_IDENT) {
+ if (L.cur().kind == TK_IDENT) {
name = L.next().text();
} else {
name = "ret" + std::to_string(idx);
}
} else {
name = L.expect(TK_IDENT).text();
- if(L.nextIf('=')) {
+ if (L.nextIf('=')) {
default_value = parseDefaultValue(type, N);
}
}
std::move(alias_info));
}
IValue parseSingleConstant(TypeKind kind) {
- switch(L.cur().kind) {
+ switch (L.cur().kind) {
case TK_TRUE:
L.next();
return true;
case TK_IDENT: {
auto tok = L.next();
auto text = tok.text();
- if("float" == text) {
+ if ("float" == text) {
return static_cast<int64_t>(at::kFloat);
- } else if("strided" == text) {
+ } else if ("strided" == text) {
return static_cast<int64_t>(at::kStrided);
- } else if("Mean" == text) {
+ } else if ("Mean" == text) {
return static_cast<int64_t>(Reduction::Mean);
} else {
throw ErrorReport(L.cur().range) << "invalid numeric default value";
}
default:
std::string n;
- if(L.nextIf('-'))
+ if (L.nextIf('-'))
n = "-" + L.expect(TK_NUMBER).text();
else
n = L.expect(TK_NUMBER).text();
- if(kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) {
+ if (kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
+ n.find('e') != std::string::npos) {
return std::stod(n);
} else {
int64_t v = std::stoll(n);
}
}
}
- IValue convertToList(TypeKind kind, const SourceRange& range, std::vector<IValue> vs) {
- switch(kind) {
- case TypeKind::FloatType:
- return fmap(vs, [](IValue v) {
- return v.toDouble();
- });
- case TypeKind::IntType:
- return fmap(vs, [](IValue v) {
- return v.toInt();
- });
- case TypeKind::BoolType:
- return fmap(vs, [](IValue v) {
- return v.toBool();
- });
- default:
- throw ErrorReport(range) << "lists are only supported for float or int types.";
- }
+ IValue convertToList(
+ TypeKind kind,
+ const SourceRange& range,
+ std::vector<IValue> vs) {
+ switch (kind) {
+ case TypeKind::FloatType:
+ return fmap(vs, [](IValue v) { return v.toDouble(); });
+ case TypeKind::IntType:
+ return fmap(vs, [](IValue v) { return v.toInt(); });
+ case TypeKind::BoolType:
+ return fmap(vs, [](IValue v) { return v.toBool(); });
+ default:
+ throw ErrorReport(range)
+ << "lists are only supported for float or int types.";
+ }
}
IValue parseConstantList(TypeKind kind) {
auto tok = L.expect('[');
std::vector<IValue> vs;
- if(L.cur().kind != ']') {
+ if (L.cur().kind != ']') {
do {
vs.push_back(parseSingleConstant(kind));
- } while(L.nextIf(','));
+ } while (L.nextIf(','));
}
L.expect(']');
return convertToList(kind, tok.range, std::move(vs));
L.expect(TK_NONE);
return IValue();
}
- IValue parseDefaultValue(const TypePtr& arg_type, c10::optional<int32_t> arg_N) {
+ IValue parseDefaultValue(
+ const TypePtr& arg_type,
+ c10::optional<int32_t> arg_N) {
auto range = L.cur().range;
- switch(arg_type->kind()) {
+ switch (arg_type->kind()) {
case TypeKind::DynamicType:
case TypeKind::GeneratorType: {
return parseTensorDefault(range);
- } break;
+ } break;
case TypeKind::StringType:
case TypeKind::OptionalType:
case TypeKind::NumberType:
return parseSingleConstant(arg_type->kind());
break;
case TypeKind::DeviceObjType: {
- auto device_text = parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
+ auto device_text =
+ parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
return c10::Device(device_text);
break;
}
case TypeKind::ListType: {
auto elem_kind = arg_type->cast<ListType>()->getElementType();
- if(L.cur().kind == TK_IDENT) {
+ if (L.cur().kind == TK_IDENT) {
return parseTensorDefault(range);
- } else if(arg_N && L.cur().kind != '[') {
+ } else if (arg_N && L.cur().kind != '[') {
IValue v = parseSingleConstant(elem_kind->kind());
std::vector<IValue> repeated(*arg_N, v);
return convertToList(elem_kind->kind(), range, repeated);
return IValue(); // silence warnings
}
- void parseList(int begin, int sep, int end, const std::function<void()>& callback) {
+ void parseList(
+ int begin,
+ int sep,
+ int end,
+ const std::function<void()>& callback) {
auto r = L.cur().range;
if (begin != TK_NOTHING)
L.expect(begin);
} // namespace script
namespace {
-using OperatorMap = std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
-struct OperatorRegistry {
-private:
+using OperatorMap =
+ std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
+struct OperatorRegistry {
+ private:
std::mutex lock;
OperatorMap operators;
// list of operators whose schemas have not yet been parsed, and which must
// be registered before any call to look up an operator
std::vector<std::shared_ptr<Operator>> to_register;
- // Those two maps are used to implement lookupByLiteral, which is needed for the n->match(...) calls.
- // Basically, every function schema is assigned a unique string you can use to match it. However,
- // parsing those strings or comparing and hashing them character by character would be very slow, so
- // we use a trick here! Every string literal in your program is guaranteed to have static storage
- // duration and so its address won't change at runtime. This allows us to memoize answers for every
- // pointer, which is done by the operators_by_sig_literal map. Still, this map is initially
- // empty, and so we still need to do the complete string matching at the first time, which is implemented
- // by performing a lookup in the operators_by_sig map.
+ // These two maps are used to implement lookupByLiteral, which is needed for
+ // the n->match(...) calls. Basically, every function schema is assigned a
+ // unique string you can use to match it. However, parsing those strings or
+ // comparing and hashing them character by character would be very slow, so we
+ // use a trick here! Every string literal in your program is guaranteed to
+ // have static storage duration and so its address won't change at runtime.
+ // This allows us to memoize answers for every pointer, which is done by the
+ // operators_by_sig_literal map. Still, this map is initially empty, so the
+ // first lookup must still do the complete string matching, which is
+ // implemented by performing a lookup in the operators_by_sig map.
std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
- std::unordered_map<const char *, std::shared_ptr<Operator>> operators_by_sig_literal;
+ std::unordered_map<const char*, std::shared_ptr<Operator>>
+ operators_by_sig_literal;
// XXX - caller must be holding lock
void registerPendingOperators() {
- for(const auto& op : to_register) {
+ for (const auto& op : to_register) {
Symbol sym = Symbol::fromQualString(op->schema().name());
operators[sym].push_back(op);
operators_by_sig[canonicalSchemaString(op->schema())] = op;
to_register.clear();
}
-public:
+ public:
void registerOperator(Operator&& op) {
std::lock_guard<std::mutex> guard(lock);
to_register.push_back(std::make_shared<Operator>(std::move(op)));
}
- const std::shared_ptr<Operator>& lookupByLiteral(const char * name) {
+ const std::shared_ptr<Operator>& lookupByLiteral(const char* name) {
std::lock_guard<std::mutex> guard(lock);
registerPendingOperators();
auto it = operators_by_sig_literal.find(name);
if (it == operators_by_sig_literal.end()) {
- auto op_ptr_it = operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
+ auto op_ptr_it =
+ operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
// Handy debugging code that dumps all operators we know about on mismatch
#if 0
if (op_ptr_it == operators_by_sig.end()) {
}
}
#endif
- JIT_ASSERTM(op_ptr_it != operators_by_sig.end(), "Couldn't find an operator for ", name);
+ JIT_ASSERTM(
+ op_ptr_it != operators_by_sig.end(),
+ "Couldn't find an operator for ",
+ name);
it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
}
return it->second;
}
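// A sketch of the memoization effect (illustrative; assumes the schema below
// was registered and `registry` is the hypothetical registry instance): the
// first call with a given literal pays for parsing and a string-keyed lookup;
// later calls with the same literal hit the pointer-keyed map.
//
//   const char* sig = "aten::mm(Tensor self, Tensor mat2) -> Tensor";
//   auto& op1 = registry.lookupByLiteral(sig); // slow path: string matching
//   auto& op2 = registry.lookupByLiteral(sig); // fast path: memoized pointer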
-
const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
std::lock_guard<std::mutex> guard(lock);
registerPendingOperators();
static std::vector<std::shared_ptr<Operator>> empty;
auto it = operators.find(name);
- if(it != operators.end())
+ if (it != operators.end())
return it->second;
return empty;
}
} // anonymous namespace
void registerOperator(Operator&& op) {
- if(op.schema().is_varret()) {
+ if (op.schema().is_varret()) {
Symbol s = Symbol::fromQualString(op.schema().name());
if (!printerHasSpecialCaseFor(s)) {
std::cout << c10::str(
return getRegistry().getOperators(name);
}
-Operator& sig(const char *signature) {
+Operator& sig(const char* signature) {
return *getRegistry().lookupByLiteral(signature);
}
out << "(";
bool seen_kwarg_only = false;
- for(size_t i = 0; i < schema.arguments().size(); ++i) {
- if (i > 0) out << ", ";
+ for (size_t i = 0; i < schema.arguments().size(); ++i) {
+ if (i > 0)
+ out << ", ";
if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
out << "*, ";
seen_kwarg_only = true;
}
- const auto & arg = schema.arguments()[i];
+ const auto& arg = schema.arguments()[i];
out << arg.type()->str() << " " << arg.name();
}
} else if (schema.returns().size() > 1) {
out << "(";
for (size_t i = 0; i < schema.returns().size(); ++i) {
- if (i > 0) out << ", ";
+ if (i > 0)
+ out << ", ";
out << schema.returns()[i].type()->str();
}
out << ")";
const auto& formals = schema().arguments();
// not enough inputs
- if(actuals.size() < formals.size())
+ if (actuals.size() < formals.size())
return false;
-
TypeEnv type_env;
- for(size_t i = 0; i < formals.size(); ++i) {
+ for (size_t i = 0; i < formals.size(); ++i) {
const MatchTypeReturn matched_type =
matchTypeVariables(formals[i].type(), actuals[i]->type(), type_env);
if (!matched_type.type) {
}
// too many inputs
- if(!schema().is_vararg() && actuals.size() != formals.size()) {
- // std::cout << "not all inputs used\n" << input_i << " " << inputs_size << "\n";
+ if (!schema().is_vararg() && actuals.size() != formals.size()) {
+ // std::cout << "not all inputs used\n" << input_i << " " << inputs_size <<
+ // "\n";
return false;
}
std::shared_ptr<Operator> findOperatorFor(const Node* node) {
const auto& candidates = getAllOperatorsFor(node->kind());
- for(const auto& candidate : candidates) {
- if(candidate->matches(node)) {
+ for (const auto& candidate : candidates) {
+ if (candidate->matches(node)) {
return candidate;
}
}
const Operator& getOperatorFor(const Node* node) {
auto op = findOperatorFor(node);
- if(op)
+ if (op)
return *op;
auto er = script::ErrorReport(node->getSourceLocation());
er << "Schema not found for node. File a bug report.\n";
er << "Node: " << *node << "\n";
er << "Input types:";
- for(size_t i = 0; i < node->inputs().size(); ++i) {
- if(i > 0)
+ for (size_t i = 0; i < node->inputs().size(); ++i) {
+ if (i > 0)
er << ", ";
er << *node->inputs()[i]->type();
}
er << "\ncandidates were:\n";
const auto& candidates = getAllOperatorsFor(node->kind());
- for(auto & candidate : candidates) {
+ for (auto& candidate : candidates) {
er << " " << candidate->schema() << "\n";
}
er << *node->owningGraph() << "\n";
throw er;
}
-
-OperatorSet::OperatorSet(std::initializer_list<const char *> sig_literals) {
- auto & registry = getRegistry();
- for (const char * sig : sig_literals) {
+OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
+ auto& registry = getRegistry();
+ for (const char* sig : sig_literals) {
auto op = registry.lookupByLiteral(sig);
ops[Symbol::fromQualString(op->schema().name())].push_back(op);
}
}
-Operator* OperatorSet::find(const Node *n) const {
+Operator* OperatorSet::find(const Node* n) const {
auto it = ops.find(n->kind());
if (it == ops.end()) {
return nullptr;
}
- for (auto & op : it->second) {
+ for (auto& op : it->second) {
if (op->matches(n)) {
return op.get();
}
return nullptr;
}
-}}
+} // namespace jit
+} // namespace torch
#pragma once
#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/stack.h>
#include <ATen/ATen.h>
#include <utility>
#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API FunctionSchema parseSchema(const std::string& schema);
// arguments. This is used for things like prim::While or prim::If that can
// take a number of different valid input types and lengths.
Operator(Symbol name, OperationCreator op_creator)
- : Operator(FunctionSchema(name, {}, {}, /*is_vararg*/true, /*is_varret*/true), std::move(op_creator)) {}
+ : Operator(
+ FunctionSchema(
+ name,
+ {},
+ {},
+ /*is_vararg*/ true,
+ /*is_varret*/ true),
+ std::move(op_creator)) {}
Operator(FunctionSchema schema, Operation op)
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
return op_creator_(node);
}
- const FunctionSchema & schema() const {
+ const FunctionSchema& schema() const {
// we lazily parse schema initialized from strings so that
// we do less work during static operator registration
- if(!schema_) {
- schema_ = std::make_shared<FunctionSchema>(parseSchema(schema_string_.value()));
+ if (!schema_) {
+ schema_ =
+ std::make_shared<FunctionSchema>(parseSchema(schema_string_.value()));
schema_string_ = c10::nullopt;
}
return *schema_;
}
-private:
- mutable c10::optional<std::string> schema_string_;
- // cannot use c10::optional because windows has issues that require an
- // assignment operator to be generated cannot use std::unique_ptr because
- // initializer lists of Operators end up copying the Operator
- mutable std::shared_ptr<FunctionSchema> schema_;
-
- // Essentially a variant<Operation, OperationCreator>.
- // NB: std::function has a default state (where it == nullptr).
- std::shared_ptr<Operation> op_;
- OperationCreator op_creator_;
+
+ private:
+ mutable c10::optional<std::string> schema_string_;
+ // cannot use c10::optional because windows has issues that require an
+ // assignment operator to be generated; cannot use std::unique_ptr because
+ // initializer lists of Operators end up copying the Operator
+ mutable std::shared_ptr<FunctionSchema> schema_;
+
+ // Essentially a variant<Operation, OperationCreator>.
+ // NB: std::function has a default state (where it == nullptr).
+ std::shared_ptr<Operation> op_;
+ OperationCreator op_creator_;
};
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
-TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name);
+TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
+ Symbol name);
std::shared_ptr<Operator> findOperatorFor(const Node* node);
const Operator& getOperatorFor(const Node* node);
inline Operation getOperation(const Node* node) {
- // note: getOperatorFor ensures that getOperatorFor(node).matches(node) == true
- // so the call to selectVariant is always valid.
+ // note: getOperatorFor ensures that getOperatorFor(node).matches(node) ==
+ // true so the call to selectVariant is always valid.
return getOperatorFor(node).getOperation(node);
}
TORCH_API void registerOperator(Operator&& op);
// XXX: this function is meant to be used with string literals only!
-Operator& sig(const char *signature_literal);
+Operator& sig(const char* signature_literal);
struct OperatorSet {
- OperatorSet(std::initializer_list<const char *> sig_literals);
+ OperatorSet(std::initializer_list<const char*> sig_literals);
// XXX: Returns a nullptr if no Operator in the set matches n
- Operator* find(const Node *n) const;
-private:
+ Operator* find(const Node* n) const;
+
+ private:
std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};
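// A minimal usage sketch for OperatorSet (illustrative): build the set once
// from schema literals, then test nodes against it inside a pass.
//
//   static OperatorSet mm_ops({"aten::mm(Tensor self, Tensor mat2) -> Tensor"});
//   if (Operator* op = mm_ops.find(node)) {
//     // node matches the mm schema; op->schema() describes it
//   }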
-
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/batch_mm.h>
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/interned_strings.h>
+#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
-#include <torch/csrc/jit/passes/alias_analysis.h>
-#include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/utils/functional.h>
#include <ATen/ATen.h>
#include <algorithm>
#include <unordered_map>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-// This pass looks for trees in the graph, where leaves are mm ops, and the inner
-// vertices are add nodes. Once we have such a tree they can be reduced to two
-// concats and a single mm (basically into a single multiply of a wide matrix, with
-// a tall matrix).
-// Such patterns show up mostly in backward of RNNs, since the derivative of many
-// uses of matrix multiplies with same weights forms exactly such a tree
-// (note that it's usually also highly imbalanced i.e. has O(n) depth).
+// This pass looks for trees in the graph, where leaves are mm ops, and the
+// inner vertices are add nodes. Once we have such a tree, it can be reduced to
+// two concats and a single mm (basically into a single multiply of a wide
+// matrix with a tall matrix). Such patterns show up mostly in the backward of
+// RNNs, since the derivative of many uses of matrix multiplies with the same
+// weights forms exactly such a tree (note that it's usually also highly
+// imbalanced, i.e. has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+------+ +------+
// Note [Further optimizations]
-// It would be straightforward to extend the TreeToken class to also detect if all
-// MMs had the same lhs/rhs. In such case it's more efficient to expand the lhs
-// and use bmm + sum instead of repeating it in memory via concat.
+// It would be straightforward to extend the TreeToken class to also detect if
+// all MMs had the same lhs/rhs. In that case it's more efficient to expand the
+// lhs and use bmm + sum instead of repeating it in memory via concat.
// Note [Overlapping trees]
// Additionally it wouldn't be too hard to add support for partially overlapping
// trees. Right now it's forbidden in the algorithm (only a single tree will
-// be allowed), so theoretically we might miss some optimization options, especially
-// that the rejected tree could be much larger. I didn't implement that because it's
-// not necessary for the simple RNN cases I saw, so I decided to keep stuff simple.
-// If we ever get around implementing this, the right solution is probably to fuse
-// MMs for the common part, and assume it's an input leaf for the outer two parts
-// (I don't think it's beneficial to recompute, unless the subtree is super small,
-// but let's not get into such details).
+// be allowed), so theoretically we might miss some optimization options,
+// especially since the rejected tree could be much larger. I didn't implement
+// that because it's not necessary for the simple RNN cases I saw, so I decided
+// to keep stuff simple. If we ever get around to implementing this, the right
+// solution is probably to fuse MMs for the common part, and assume it's an
+// input leaf for the outer two parts (I don't think it's beneficial to
+// recompute, unless the subtree is super small, but let's not get into such
+// details).
// The algorithm we're using is simple. We're iterating through the graph in the
-// topological order and labeling nodes with TreeTokens. Then, we look for roots of
-// the trees we formed and fuse them.
+// topological order and labeling nodes with TreeTokens. Then, we look for roots
+// of the trees we formed and fuse them.
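// For intuition, the reduction rests on the block-matrix identity (a two-leaf
// sketch; A1/B1/A2/B2 are hypothetical conformable matrices):
//
//   A1.mm(B1) + A2.mm(B2)
//       == at::cat({A1, A2}, /*dim=*/1).mm(at::cat({B1, B2}, /*dim=*/0))
//
// i.e. one wide-by-tall multiply replaces N mms and N-1 adds.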
// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr size_t min_fusion_size = 4;
bool have_same_shape(at::TensorList inputs) {
auto expected_sizes = inputs[0].sizes();
- return std::all_of(inputs.begin(), inputs.end(),
- [expected_sizes](const at::Tensor& t) {
- return t.sizes() == expected_sizes;
- });
+ return std::all_of(
+ inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
+ return t.sizes() == expected_sizes;
+ });
}
bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
}
-RegisterOperators mm_tree_reduction_reg({
- Operator(
- prim::MMTreeReduce,
- [](const Node* node) {
+RegisterOperators mm_tree_reduction_reg(
+ {Operator(prim::MMTreeReduce, [](const Node* node) {
size_t num_inputs = node->inputs().size();
return [num_inputs](Stack& stack) {
std::vector<at::Tensor> inputs;
size_t side_num_elems = inputs.size() / 2;
auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
- // TODO: checking this is not free, so we should stop if this keeps failing
- if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) && shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
+ // TODO: checking this is not free, so we should stop if this keeps
+ // failing
+ if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
+ shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
auto lhs = at::cat(lhs_inputs, /*dim=*/1);
auto rhs = at::cat(rhs_inputs, /*dim=*/0);
push(stack, at::mm(lhs, rhs));
}
return 0;
};
- })
-});
+ })});
// TreeTokens will be used to label nodes of the graph, if the nodes will fit
// our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
// and build a larger tree.
struct TreeToken {
uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
- Node *node = nullptr;
+ Node* node = nullptr;
bool is_root = false;
- static TreeToken mm(Node *mm) {
+ static TreeToken mm(Node* mm) {
TreeToken token;
token.tree_size = 1;
token.node = mm;
return token;
}
- // NB: the returned token might be invalid, so make sure to check its boolean value!
- static TreeToken transpose(Node *t, TreeToken& inp_token) {
+ // NB: the returned token might be invalid, so make sure to check its boolean
+ // value!
+ static TreeToken transpose(Node* t, TreeToken& inp_token) {
TreeToken token;
- if (!inp_token.node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+ if (!inp_token.node->matches(
+ "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
return token;
}
token.tree_size = 1;
return token;
}
- // NB: the returned token might be invalid, so make sure to check its boolean value!
- static TreeToken add(Node *add, TreeToken& l, TreeToken& r) {
+ // NB: the returned token might be invalid, so make sure to check its boolean
+ // value!
+ static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
TreeToken token;
// See Note [Overlapping trees]
if (&l == &r || !l.is_root || !r.is_root)
token.tree_size = l.tree_size + r.tree_size;
token.node = add;
token.is_root = true;
- l.is_root = r.is_root = false; // Reserve the subtrees, so they can't be used again.
+ l.is_root = r.is_root =
+ false; // Reserve the subtrees, so they can't be used again.
return token;
}
std::vector<Node*> removeTransposesAndGatherMatmuls() {
std::vector<Node*> matmuls;
- std::vector<Node*> queue {node};
+ std::vector<Node*> queue{node};
Graph* graph = node->owningGraph();
while (!queue.empty()) {
- auto n = queue.back(); queue.pop_back();
+ auto n = queue.back();
+ queue.pop_back();
if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
matmuls.push_back(n);
} else if (n->matches("aten::t(Tensor self) -> Tensor")) {
- Node * input_node = n->input()->node();
- JIT_ASSERT(input_node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor"));
+ Node* input_node = n->input()->node();
+ JIT_ASSERT(input_node->matches(
+ "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
// (AB)^T == B^TA^T
- WithInsertPoint insert_guard { input_node };
- Value * A = input_node->inputs()[0];
- Value * B = input_node->inputs()[1];
- Value * AT = graph->insert(aten::t, {A});
- Value * BT = graph->insert(aten::t, {B});
- Value * BTAT = graph->insert(aten::mm, {BT, AT});
+ WithInsertPoint insert_guard{input_node};
+ Value* A = input_node->inputs()[0];
+ Value* B = input_node->inputs()[1];
+ Value* AT = graph->insert(aten::t, {A});
+ Value* BT = graph->insert(aten::t, {B});
+ Value* BTAT = graph->insert(aten::mm, {BT, AT});
n->output()->replaceAllUsesWith(BTAT);
matmuls.push_back(BTAT->node());
- } else if (n->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+ } else if (
+ n->matches(
+ "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
queue.push_back(n->inputs()[0]->node());
queue.push_back(n->inputs()[1]->node());
} else {
if (input_it != tokens.end()) {
tokens[node] = TreeToken::transpose(node, input_it->second);
}
- } else if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
- Node *lhs = node->inputs()[0]->node();
- Node *rhs = node->inputs()[1]->node();
+ } else if (
+ node->matches(
+ "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+ Node* lhs = node->inputs()[0]->node();
+ Node* rhs = node->inputs()[1]->node();
auto lhs_it = tokens.find(lhs);
auto rhs_it = tokens.find(rhs);
// See Note [Overlapping trees] (regarding the uses().size() == 1 check)
// XXX: uses().size() == 1 is also something that guarantees that this
// transform is valid, because we know for sure that none of these
// operands depend on the result of the other. If we were to remove this,
- // we need to compute a transitive closure and actually check the dependencies.
+ // we need to compute a transitive closure and actually check the
+ // dependencies.
if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
- lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) {
+ lhs->output()->uses().size() == 1 &&
+ rhs->output()->uses().size() == 1) {
if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
tokens[node] = token;
}
}
// Merge trees we've found
- for (auto & item : tokens) {
- auto & root = item.second;
+ for (auto& item : tokens) {
+ auto& root = item.second;
if (!root || root.tree_size < min_fusion_size)
continue;
auto matmuls = root.removeTransposesAndGatherMatmuls();
- WithInsertPoint insert_guard {root.node};
- Node * tree_reduce = graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
- for (Node * matmul : matmuls) {
+ WithInsertPoint insert_guard{root.node};
+ Node* tree_reduce =
+ graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
+ for (Node* matmul : matmuls) {
tree_reduce->addInput(matmul->inputs().at(0));
}
- for (Node * matmul : matmuls) {
+ for (Node* matmul : matmuls) {
tree_reduce->addInput(matmul->inputs().at(1));
}
root.node->output()->replaceAllUsesWith(tree_reduce->output());
return other_side_input.numel() <= 1024 * 2048;
}
-RegisterOperators mm_batch_side_reg({
- Operator(
- prim::MMBatchSide,
- [](const Node* node) {
+RegisterOperators mm_batch_side_reg(
+ {Operator(prim::MMBatchSide, [](const Node* node) {
size_t num_other_side_inputs = node->inputs().size() - 1;
Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
return [num_other_side_inputs, single_side](Stack& stack) {
at::Tensor side_input;
std::vector<at::Tensor> other_side_inputs;
other_side_inputs.reserve(num_other_side_inputs);
- for (auto it = stack.end() - num_other_side_inputs; it != stack.end(); ++it) {
+ for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
+ ++it) {
other_side_inputs.push_back(std::move(*it).toTensor());
}
drop(stack, num_other_side_inputs);
pop(stack, side_input);
auto any_other_input = other_side_inputs[0];
- if (have_same_shape(other_side_inputs) && shape_is_fast_for_side(other_side_inputs[0])) {
- auto other_side_input = at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
- auto mm_out = single_side == Side::LHS ? side_input.mm(other_side_input) : other_side_input.mm(side_input);
- auto outputs = at::chunk(mm_out, num_other_side_inputs, /*dim=*/single_side == Side::LHS ? 1 : 0);
- stack.insert(stack.end(), std::make_move_iterator(outputs.begin()),
- std::make_move_iterator(outputs.end()));
+ if (have_same_shape(other_side_inputs) &&
+ shape_is_fast_for_side(other_side_inputs[0])) {
+ auto other_side_input =
+ at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
+ auto mm_out = single_side == Side::LHS
+ ? side_input.mm(other_side_input)
+ : other_side_input.mm(side_input);
+ auto outputs = at::chunk(
+ mm_out,
+ num_other_side_inputs,
+ /*dim=*/single_side == Side::LHS ? 1 : 0);
+ stack.insert(
+ stack.end(),
+ std::make_move_iterator(outputs.begin()),
+ std::make_move_iterator(outputs.end()));
} else {
if (single_side == Side::LHS) {
for (at::Tensor& other : other_side_inputs) {
return 0;
};
- })
-});
+ })});
-std::pair<std::vector<Node*>, std::vector<Node*>>
-gatherIndependentMMUses(Value *value, const AliasDb& alias_db) {
+std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
+ Value* value,
+ const AliasDb& alias_db) {
const auto postprocess = [&](std::vector<Node*> mms) {
if (mms.size() == 0) {
return mms;
}
- std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) { return n->isBefore(m); });
- // Filter out dependent MMs. This algorithm might do very badly if e.g. you have
- // a lot of independent MMs, that depend on the first one, but I doubt this will
- // be a common scenario.
+ std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
+ return n->isBefore(m);
+ });
+ // Filter out dependent MMs. This algorithm might do very badly if e.g. you
+ // have a lot of independent MMs that depend on the first one, but I doubt
+ // this will be a common scenario.
for (size_t i = 0; i < mms.size(); ++i) {
- if (mms[i] == nullptr) continue;
+ if (mms[i] == nullptr)
+ continue;
for (size_t j = i + 1; j < mms.size(); ++j) {
- if (mms[j] == nullptr) continue;
+ if (mms[j] == nullptr)
+ continue;
if (!mms[j]->couldMoveBeforeTopologically(mms[i], alias_db)) {
mms[j] = nullptr;
}
}
}
- return filter(mms, [](Node *n) { return n != nullptr; });
+ return filter(mms, [](Node* n) { return n != nullptr; });
};
- Block * block = value->node()->owningBlock();
+ Block* block = value->node()->owningBlock();
std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
std::vector<Node*> rhses; // Like above, but rhs
for (Use u : value->uses()) {
return std::make_pair(postprocess(lhses), postprocess(rhses));
}
-void BatchMMSide(Block * block, const AliasDb& alias_db) {
+void BatchMMSide(Block* block, const AliasDb& alias_db) {
// NB: 8 is the current loop unrolling factor
static constexpr size_t how_many_is_many = 8;
const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
bool move_ok = mms[i]->moveBeforeTopologicallyValid(mms[i + 1], alias_db);
JIT_ASSERT(move_ok);
}
- WithInsertPoint insert_guard { mms[0] };
+ WithInsertPoint insert_guard{mms[0]};
Graph* graph = mms[0]->owningGraph();
- Node* batch_mm = graph->create(prim::MMBatchSide,
- /*inputs=*/{}, /*num_outputs=*/mms.size());
+ Node* batch_mm = graph->create(
+ prim::MMBatchSide,
+ /*inputs=*/{},
+ /*num_outputs=*/mms.size());
graph->insertNode(batch_mm);
batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
};
std::unordered_set<Value*> considered_values;
- for (Node * node : block->nodes()) {
+ for (Node* node : block->nodes()) {
if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
- for (Value * input : node->inputs()) {
- if (/*bool not_inserted = */!considered_values.emplace(input).second) {
+ for (Value* input : node->inputs()) {
+ if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
continue;
}
auto uses_with_many = gatherIndependentMMUses(input, alias_db);
}
}
}
-
}
bool hasMutableOperators(Block* block) {
BatchMMTreeReduce(graph->block());
BatchMMSide(graph->block(), alias_db);
EliminateDeadCode(graph);
- // It's possible that transpose rearrangements have created sequences of consecutive
- // transposes that didn't exist before.
+ // It's possible that transpose rearrangements have created sequences of
+ // consecutive transposes that didn't exist before.
PeepholeOptimize(graph);
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API void BatchMM(std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/canonicalize.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// Canonicalize a graph, renumbering it so that all structurally equivalent
// graphs have the same numbers.
// and replacing them with normal value names.
// Otherwise, ignores values with unique names.
std::shared_ptr<Graph> Canonicalize(
- const std::shared_ptr<Graph>& graph, bool keep_unique_names) {
+ const std::shared_ptr<Graph>& graph,
+ bool keep_unique_names) {
auto r = std::make_shared<Graph>(graph->current_scope());
std::unordered_map<Value*, Value*> rn_env;
auto rn_fn = [&](Value* v) { return rn_env.at(v); };
for (auto* input : graph->inputs()) {
auto* r_input = r->addInput();
r_input->copyMetadata(input);
- if (!keep_unique_names) r_input->setUniqueName("");
+ if (!keep_unique_names)
+ r_input->setUniqueName("");
rn_env[input] = r_input;
}
for (auto* node : graph->nodes()) {
rn_env[outputs.at(i)] = r_outputs.at(i);
}
if (node->hasAttribute(attr::Subgraph)) {
- r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph), keep_unique_names));
+ r_node->g_(
+ attr::Subgraph,
+ Canonicalize(node->g(attr::Subgraph), keep_unique_names));
}
}
for (auto* output : graph->outputs()) {
}
return r;
-
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API std::shared_ptr<Graph> Canonicalize(
- const std::shared_ptr<Graph>& graph, bool keep_unique_names=true);
+ const std::shared_ptr<Graph>& graph,
+ bool keep_unique_names = true);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/symbolic_variable.h>
-
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct ChunkOutput {
- ChunkOutput(Value * v, size_t o)
- : val(v), offset(o) {};
- Value * val;
+ ChunkOutput(Value* v, size_t o) : val(v), offset(o) {}
+ Value* val;
size_t offset;
};
static c10::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) {
std::vector<ChunkOutput> outputs;
for (auto list_use : chunk->output()->uses()) {
- if (list_use.user->matches("aten::select(Tensor[] list, int idx) -> Tensor", attr::b)) {
- outputs.emplace_back(list_use.user->output(),
- list_use.user->get<int64_t>(attr::b).value());
+ if (list_use.user->matches(
+ "aten::select(Tensor[] list, int idx) -> Tensor", attr::b)) {
+ outputs.emplace_back(
+ list_use.user->output(),
+ list_use.user->get<int64_t>(attr::b).value());
} else if (list_use.user->kind() == prim::ListUnpack) {
- // This sometimes happens if the sizes can't be evenly divided by the number of chunks
- if (static_cast<int64_t>(list_use.user->outputs().size()) != chunk->get<int64_t>(attr::chunks).value()) {
+ // This sometimes happens if the sizes can't be evenly divided by the
+ // number of chunks
+ if (static_cast<int64_t>(list_use.user->outputs().size()) !=
+ chunk->get<int64_t>(attr::chunks).value()) {
return c10::nullopt;
}
auto unpack_outputs = list_use.user->outputs();
// followed by an add so that it can go through the existing optimization,
// shape analysis and differentiation passes for those two individual ops.
// Later, we will fuse together those two ops into a single addmm.
- if (it->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
- /*const_inputs=*/{attr::beta, attr::alpha})) {
+ if (it->matches(
+ "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
+ /*const_inputs=*/{attr::beta, attr::alpha})) {
if (it->get<at::Scalar>(attr::alpha)->toDouble() != 1.0 ||
it->get<at::Scalar>(attr::beta)->toDouble() != 1.0) {
continue;
SymbolicVariable mat2(it->inputs()[2]);
auto mm_result = mat1.mm(mat2);
- // Set this intermediate aten::mm node to have the same output type as the original aten::addmm
- // otherwise the canonicalized graph will have DynamicType as the output of this node which is incorrect
+ // Set this intermediate aten::mm node to have the same output type as the
+ // original aten::addmm; otherwise the canonicalized graph will have
+ // DynamicType as the output of this node, which is incorrect.
(static_cast<Value*>(mm_result))->setType(it->output()->type());
auto result = mat + mm_result;
(static_cast<Value*>(result))->setType(it->output()->type());
it->output()->replaceAllUsesWith(result);
it.destroyCurrent();
- } else if (it->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
- it->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
- it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
- it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
+ } else if (
+ it->matches(
+ "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
+ it->matches(
+ "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
+ it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
+ it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
if (auto other = it->get<at::Tensor>(attr::other)) {
if (other->dim() == 0) {
- WithInsertPoint insert_guard {*it};
+ WithInsertPoint insert_guard{*it};
auto graph = it->owningGraph();
auto new_other = graph->insertConstant(other->item());
std::vector<Value*> inputs = it->inputs().vec();
inputs.at(1) = new_other;
- Value * new_output = graph->insertNode(graph->create(it->kind(), inputs))->output();
+ Value* new_output =
+ graph->insertNode(graph->create(it->kind(), inputs))->output();
it->output()->replaceAllUsesWith(new_output);
}
}
- } else if (it->matches("aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
- /*const_inputs=*/{attr::chunks, attr::dim})) {
+ } else if (it->matches(
+ "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
+ /*const_inputs=*/{attr::chunks, attr::dim})) {
if (auto orig_outputs = getChunkOutputs(*it)) {
WithInsertPoint guard(*it);
- SymbolicVariable self {it->namedInput(attr::self)};
- auto outputs = self.chunk(it->get<int64_t>(attr::chunks).value(),
- it->get<int64_t>(attr::dim).value());
+ SymbolicVariable self{it->namedInput(attr::self)};
+ auto outputs = self.chunk(
+ it->get<int64_t>(attr::chunks).value(),
+ it->get<int64_t>(attr::dim).value());
for (ChunkOutput orig_out : *orig_outputs) {
orig_out.val->replaceAllUsesWith(outputs.at(orig_out.offset));
outputs[orig_out.offset].value()->setType(orig_out.val->type());
EliminateDeadCode(graph);
}
-
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API void CanonicalizeOps(const std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/interned_strings.h>
+#include <torch/csrc/jit/node_hashing.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
-#include <torch/csrc/jit/node_hashing.h>
#include <torch/csrc/utils/functional.h>
#include <torch/csrc/utils/hash.h>
const AliasDb& aliasDb,
std::function<Node*(Node*)> parent_lookup_fn) {
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
- for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
+ for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
auto node = *it;
if (node->hasSideEffects() || node->isNondeterministic() ||
aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
}
}
}
-}
+} // namespace
void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
const auto aliasDb = AliasAnalysis(graph);
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/ir.h>
-#include <unordered_set>
#include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/node_hashing.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <unordered_set>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
-//Very similar to the common subexpression elimination pass
-//Move all constants to the beginning of the graph, and deduplicate
-void ConstantPooling(Block * block, std::unordered_set<Node*, HashNode, EqualNode>& constants) {
+// Very similar to the common subexpression elimination pass:
+// move all constants to the beginning of the graph, and deduplicate.
+void ConstantPooling(
+ Block* block,
+ std::unordered_set<Node*, HashNode, EqualNode>& constants) {
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
auto node = *it;
// node may be moved to a different block so advance iterator now
} // anonymous namespace
-
void ConstantPooling(const std::shared_ptr<Graph>& graph) {
std::unordered_set<Node*, HashNode, EqualNode> constants;
ConstantPooling(graph->block(), constants);
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API void ConstantPooling(const std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ivalue.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/utils/functional.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
std::unordered_set<Symbol> skip_list = {
- prim::If,
- prim::Loop, //TODO: handle Loop
- prim::Constant,
- prim::Undefined,
- prim::None, // it is already a constant and propagating it will lose
- // important type information about which Optional type it is
- // TODO (zach): we should consider skipping tensor factories in the cases
- // where the constant tensor would be large but cheap to create.
- };
+ prim::If,
+ prim::Loop, // TODO: handle Loop
+ prim::Constant,
+ prim::Undefined,
+ prim::None, // it is already a constant and propagating it will lose
+ // important type information about which Optional type it is
+ // TODO (zach): we should consider skipping tensor factories in the cases
+ // where the constant tensor would be large but cheap to create.
+};
std::vector<IValue> runNode(Node* n) {
auto op = getOperation(n);
auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
if (v.isTensor()) {
auto t = std::move(v).toTensor();
- if(t.defined()) {
+ if (t.defined()) {
return IValue(autograd::as_variable_ref(t).data());
} else {
return t;
try {
auto new_output = graph->insertConstant(outputs[i]);
n->outputs()[i]->replaceAllUsesWith(new_output);
- } catch(constant_not_supported_error& err) {
+ } catch (constant_not_supported_error& err) {
// we cannot actually represent the IValue as a constant node,
// so we give up replacing it
}
}
}
-void inlineIf(Block *body, Node * n) {
- for(auto it = body->nodes().begin(); it != body->nodes().end();) {
- Node *body_node = *it;
- //advance iterator because after body_node is moved its next pointer will be
- //to n
+void inlineIf(Block* body, Node* n) {
+ for (auto it = body->nodes().begin(); it != body->nodes().end();) {
+ Node* body_node = *it;
+ // advance iterator because after body_node is moved its next pointer will
+ // be to n
it++;
body_node->moveBefore(n);
}
for (size_t i = 0; i < n->outputs().size(); ++i) {
n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
}
- // NB: destroy the node here, because it might contain side effects, like print
+ // NB: destroy the node here, because it might contain side effects, like
+ // print
n->destroy();
}
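// A sketch of the advance-before-mutating iterator pattern used above (and
// throughout these passes) on a plain std::list: moving or destroying the
// current node would invalidate a naive iterator, so it is bumped first.
// eraseEvens() is a hypothetical stand-in for a pass body.
#include <list>
void eraseEvens(std::list<int>& nodes) {
  for (auto it = nodes.begin(); it != nodes.end();) {
    auto cur = it++; // advance first: std::list::erase only invalidates
                     // iterators to the erased element, so `it` stays valid
    if (*cur % 2 == 0) {
      nodes.erase(cur); // analogue of node->destroy() / moveBefore()
    }
  }
}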
-bool isTrueConstant(Value *val) {
+bool isTrueConstant(Value* val) {
c10::optional<bool> maybe_value = constant_as<bool>(val);
JIT_ASSERT(maybe_value);
return *maybe_value;
}
-void inlineIf(Node *n) {
+void inlineIf(Node* n) {
if (isTrueConstant(n->input())) {
inlineIf(n->blocks()[0], n);
} else {
}
}
-//remove extra outputs from the node
-bool removeExtraNodeOutputs(Node *n) {
+// remove extra outputs from the node
+bool removeExtraNodeOutputs(Node* n) {
JIT_ASSERTM(n->kind() == prim::If, "Only supported for If nodes");
auto true_block = n->blocks()[0];
auto false_block = n->blocks()[1];
auto initial_outputs = true_block->outputs().size();
- for (size_t i = 0; i < true_block->outputs().size(); ) {
- //neither block changes the output value
+ for (size_t i = 0; i < true_block->outputs().size();) {
+ // neither block changes the output value
if (true_block->outputs()[i] == false_block->outputs()[i]) {
n->outputs().at(i)->replaceAllUsesWith(true_block->outputs()[i]);
n->eraseOutput(i);
true_block->eraseOutput(i);
false_block->eraseOutput(i);
} else {
- i++; //increment bc we didn't remove current index
+ i++; // increment bc we didn't remove current index
}
}
- //an output was removed
+ // an output was removed
return initial_outputs != true_block->outputs().size();
}
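// A sketch of the pruning above on plain vectors: when both branches of an If
// produce the same value at position i, that output is redundant and both
// blocks can drop it. Values are modeled as ints here (hypothetical; the real
// pass compares Value* identity).
#include <cstddef>
#include <vector>
bool removeMatchingOutputs(
    std::vector<int>& true_outs,
    std::vector<int>& false_outs) {
  const std::size_t initial = true_outs.size();
  for (std::size_t i = 0; i < true_outs.size();) {
    if (true_outs[i] == false_outs[i]) {
      true_outs.erase(true_outs.begin() + i);
      false_outs.erase(false_outs.begin() + i);
    } else {
      ++i; // increment only when nothing was removed at this index
    }
  }
  return initial != true_outs.size(); // an output was removed
}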
return v->node()->kind() == prim::Constant;
});
bool supported_node = !n->kind().is_onnx() &&
- skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && !n->hasSideEffects() &&
- !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
+ skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
+ !n->hasSideEffects() && !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
auto run_blocks = [&]() {
if (recurse) {
- for (Block * block : n->blocks()) {
+ for (Block* block : n->blocks()) {
ConstantPropagation(block, aliasDb, recurse);
}
}
};
if (n->kind() == prim::If) {
run_blocks();
- //inline node if we can, otherwise check for simplified outputs
+ // inline node if we can, otherwise check for simplified outputs
if (constant_inputs) {
inlineIf(n);
} else {
removeExtraNodeOutputs(n);
}
- //don't rerun run_blocks
+ // don't rerun run_blocks
return;
} else if (constant_inputs && supported_node) {
propagateNode(n);
}
- //TODO handle loop nodes. Even if a loop node contains an if that is
- //inlined its mutated variables currently don't get updated
+ // TODO handle loop nodes. Even if a loop node contains an if that is
+ // inlined its mutated variables currently don't get updated
run_blocks();
}
void ConstantPropagation(Block* block, const AliasDb& aliasDb, bool recurse) {
- for(auto it = block->nodes().begin(); it != block->nodes().end();) {
- Node *n = *it;
- it++; //advance iterator bc the current node may be destroyed
+ for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+ Node* n = *it;
+ it++; // advance iterator bc the current node may be destroyed
ConstantPropagation(n, aliasDb, recurse);
}
}
} // anonymous namespace
-
void ConstantPropagation(std::shared_ptr<Graph>& graph) {
const auto aliasDb = AliasAnalysis(graph);
ConstantPropagation(graph->block(), aliasDb, true);
EliminateDeadCode(graph);
}
-}}
+} // namespace jit
+} // namespace torch
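// A sketch of the core of constant propagation on a toy expression tree, not
// the pass itself: a node whose inputs all fold to literals is evaluated and
// replaced in place (the analogue of propagateNode()); when an If's condition
// folds, one branch would be spliced in, as in inlineIf(). The Expr type is
// hypothetical.
#include <memory>
struct Expr {
  enum Kind { Lit, Add } kind;
  long value = 0;                 // valid when kind == Lit
  std::unique_ptr<Expr> lhs, rhs; // valid when kind == Add
};
void fold(Expr& e) {
  if (e.kind != Expr::Add)
    return;
  fold(*e.lhs);
  fold(*e.rhs);
  if (e.lhs->kind == Expr::Lit && e.rhs->kind == Expr::Lit) {
    e.value = e.lhs->value + e.rhs->value; // evaluate the constant node
    e.kind = Expr::Lit;                    // ... and replace it with a literal
    e.lhs.reset();
    e.rhs.reset();
  }
}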
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API void ConstantPropagation(std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/ir.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ir.h>
#include <cstddef>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// insert GraphExecutor nodes that group together
// subgraphs that are differentiable by the jit's autodiff passes
TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(
const std::shared_ptr<Graph>& graph,
size_t threshold = 2);
-}}
+} // namespace jit
+} // namespace torch
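// One plausible reading of the threshold parameter above: only runs of at
// least `threshold` qualifying nodes are worth outlining into a subgraph.
// This generic run-grouper over ints is a hypothetical illustration, not the
// actual pass logic.
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>
std::vector<std::pair<std::size_t, std::size_t>> groupRuns(
    const std::vector<int>& nodes,
    const std::function<bool(int)>& qualifies,
    std::size_t threshold) {
  std::vector<std::pair<std::size_t, std::size_t>> runs;
  std::size_t i = 0;
  while (i < nodes.size()) {
    if (!qualifies(nodes[i])) {
      ++i;
      continue;
    }
    std::size_t begin = i;
    while (i < nodes.size() && qualifies(nodes[i]))
      ++i;
    if (i - begin >= threshold)
      runs.emplace_back(begin, i); // [begin, i) is a large-enough run
  }
  return runs;
}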
#include <torch/csrc/jit/passes/dead_code_elimination.h>
-#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/ir_views.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// If given a top-level graph, DCE will construct an alias analysis that allows
// for "smarter" dead code elimination (we will eliminate mutable ops if we can
//
// So, prefer to use the graph version if you can.
TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
-TORCH_API void EliminateDeadCode(Block *block, bool recurse=true);
+TORCH_API void EliminateDeadCode(Block* block, bool recurse = true);
// Invoke the user-provided callback on all live values before deleting anything
TORCH_API void EliminateDeadCode(
-#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/passes/erase_number_types.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
static void EraseNumberTypesOnBlock(Block* block) {
for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
// Let DCE cleanup
} break;
default: {
- for(auto o : it->outputs()) {
+ for (auto o : it->outputs()) {
if (o->type()->isSubtypeOf(NumberType::get())) {
o->setType(CompleteTensorType::fromNumberType(o->type()));
} else if (o->type()->isSubtypeOf(BoolType::get())) {
EraseNumberTypesOnBlock(graph->block());
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// Erase NumberType information. This is necessary for and only used in
// exporting to ONNX. This pass ensures that no remaining Values have
// - prim::Constant nodes which are numbers get changed into 0-dim tensors of
// the corresponding type
// - prim::TensorToNum, prim::ImplicitTensorToNum and prim::NumToTensor nodes
-// are erased.
+// are erased.
//
// The pass assumes that DCE will be called sometime after.
TORCH_API void EraseNumberTypes(const std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/graph_fuser.h>
+#include <ATen/ExpandUtils.h>
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/autodiff.h>
+#include <torch/csrc/jit/fuser/interface.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/fuser/interface.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/autodiff.h>
-#include <torch/csrc/jit/assertions.h>
-#include <ATen/ExpandUtils.h>
#include <unordered_map>
#ifdef USE_CUDA
- #include <cuda.h> // for CUDA_VERSION
+#include <cuda.h> // for CUDA_VERSION
#endif
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
// - Produces contiguous outputs
// Some of these restrictions may be relaxable, but you should
// carefully read the code first, as we rely on these assumptions.
-bool isSimpleMap(Node *node) {
- static OperatorSet simple_mappable {{
- "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
-
- "aten::abs(Tensor self) -> Tensor",
- "aten::acos(Tensor self) -> Tensor",
- "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
- "aten::asin(Tensor self) -> Tensor",
- "aten::atan(Tensor self) -> Tensor",
- "aten::atan2(Tensor self, Tensor other) -> Tensor",
- "aten::ceil(Tensor self) -> Tensor",
- "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
- "aten::cos(Tensor self) -> Tensor",
- "aten::cosh(Tensor self) -> Tensor",
- "aten::div(Tensor self, Tensor other) -> Tensor",
- "aten::exp(Tensor self) -> Tensor",
- "aten::expm1(Tensor self) -> Tensor",
- "aten::erf(Tensor self) -> Tensor",
- "aten::erfc(Tensor self) -> Tensor",
- "aten::floor(Tensor self) -> Tensor",
- "aten::fmod(Tensor self, Tensor other) -> Tensor",
- "aten::frac(Tensor self) -> Tensor",
- "aten::lgamma(Tensor self) -> Tensor",
- "aten::log(Tensor self) -> Tensor",
- "aten::log10(Tensor self) -> Tensor",
- "aten::log1p(Tensor self) -> Tensor",
- "aten::log2(Tensor self) -> Tensor",
- "aten::max(Tensor self, Tensor other) -> Tensor",
- "aten::min(Tensor self, Tensor other) -> Tensor",
- "aten::mul(Tensor self, Tensor other) -> Tensor",
- "aten::neg(Tensor self) -> Tensor",
- "aten::pow(Tensor self, Tensor exponent) -> Tensor",
- "aten::pow(Tensor self, Scalar exponent) -> Tensor",
- // See https://github.com/pytorch/pytorch/issues/14674 and make sure you
- // won't make the same mistake before you reenable this.
- //"aten::rand_like(Tensor self) -> Tensor",
- "aten::reciprocal(Tensor self) -> Tensor",
- "aten::relu(Tensor self) -> Tensor",
- "aten::remainder(Tensor self, Tensor other) -> Tensor",
- "aten::round(Tensor self) -> Tensor",
- "aten::rsqrt(Tensor self) -> Tensor",
- "aten::sigmoid(Tensor self) -> Tensor",
- "aten::sin(Tensor self) -> Tensor",
- "aten::sinh(Tensor self) -> Tensor",
- "aten::sqrt(Tensor self) -> Tensor",
- "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
- "aten::tan(Tensor self) -> Tensor",
- "aten::tanh(Tensor self) -> Tensor",
- "aten::trunc(Tensor self) -> Tensor",
- "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
- "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
- "aten::mul(Tensor self, Scalar other) -> Tensor",
- "aten::div(Tensor self, Scalar other) -> Tensor",
-
- "aten::eq(Tensor self, Tensor other) -> Tensor",
- "aten::eq(Tensor self, Scalar other) -> Tensor",
- "aten::ne(Tensor self, Tensor other) -> Tensor",
- "aten::ne(Tensor self, Scalar other) -> Tensor",
- "aten::ge(Tensor self, Tensor other) -> Tensor",
- "aten::ge(Tensor self, Scalar other) -> Tensor",
- "aten::gt(Tensor self, Tensor other) -> Tensor",
- "aten::gt(Tensor self, Scalar other) -> Tensor",
- "aten::le(Tensor self, Tensor other) -> Tensor",
- "aten::le(Tensor self, Scalar other) -> Tensor",
- "aten::lt(Tensor self, Tensor other) -> Tensor",
- "aten::lt(Tensor self, Scalar other) -> Tensor",
-
- "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
-
- "aten::type_as(Tensor self, Tensor other) -> Tensor",
+bool isSimpleMap(Node* node) {
+ static OperatorSet simple_mappable{{
+ "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+
+ "aten::abs(Tensor self) -> Tensor",
+ "aten::acos(Tensor self) -> Tensor",
+ "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+ "aten::asin(Tensor self) -> Tensor",
+ "aten::atan(Tensor self) -> Tensor",
+ "aten::atan2(Tensor self, Tensor other) -> Tensor",
+ "aten::ceil(Tensor self) -> Tensor",
+ "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
+ "aten::cos(Tensor self) -> Tensor",
+ "aten::cosh(Tensor self) -> Tensor",
+ "aten::div(Tensor self, Tensor other) -> Tensor",
+ "aten::exp(Tensor self) -> Tensor",
+ "aten::expm1(Tensor self) -> Tensor",
+ "aten::erf(Tensor self) -> Tensor",
+ "aten::erfc(Tensor self) -> Tensor",
+ "aten::floor(Tensor self) -> Tensor",
+ "aten::fmod(Tensor self, Tensor other) -> Tensor",
+ "aten::frac(Tensor self) -> Tensor",
+ "aten::lgamma(Tensor self) -> Tensor",
+ "aten::log(Tensor self) -> Tensor",
+ "aten::log10(Tensor self) -> Tensor",
+ "aten::log1p(Tensor self) -> Tensor",
+ "aten::log2(Tensor self) -> Tensor",
+ "aten::max(Tensor self, Tensor other) -> Tensor",
+ "aten::min(Tensor self, Tensor other) -> Tensor",
+ "aten::mul(Tensor self, Tensor other) -> Tensor",
+ "aten::neg(Tensor self) -> Tensor",
+ "aten::pow(Tensor self, Tensor exponent) -> Tensor",
+ "aten::pow(Tensor self, Scalar exponent) -> Tensor",
+ // See https://github.com/pytorch/pytorch/issues/14674 and make sure you
+ // won't make the same mistake before you reenable this.
+ //"aten::rand_like(Tensor self) -> Tensor",
+ "aten::reciprocal(Tensor self) -> Tensor",
+ "aten::relu(Tensor self) -> Tensor",
+ "aten::remainder(Tensor self, Tensor other) -> Tensor",
+ "aten::round(Tensor self) -> Tensor",
+ "aten::rsqrt(Tensor self) -> Tensor",
+ "aten::sigmoid(Tensor self) -> Tensor",
+ "aten::sin(Tensor self) -> Tensor",
+ "aten::sinh(Tensor self) -> Tensor",
+ "aten::sqrt(Tensor self) -> Tensor",
+ "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+ "aten::tan(Tensor self) -> Tensor",
+ "aten::tanh(Tensor self) -> Tensor",
+ "aten::trunc(Tensor self) -> Tensor",
+ "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+ "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+ "aten::mul(Tensor self, Scalar other) -> Tensor",
+ "aten::div(Tensor self, Scalar other) -> Tensor",
+
+ "aten::eq(Tensor self, Tensor other) -> Tensor",
+ "aten::eq(Tensor self, Scalar other) -> Tensor",
+ "aten::ne(Tensor self, Tensor other) -> Tensor",
+ "aten::ne(Tensor self, Scalar other) -> Tensor",
+ "aten::ge(Tensor self, Tensor other) -> Tensor",
+ "aten::ge(Tensor self, Scalar other) -> Tensor",
+ "aten::gt(Tensor self, Tensor other) -> Tensor",
+ "aten::gt(Tensor self, Scalar other) -> Tensor",
+ "aten::le(Tensor self, Tensor other) -> Tensor",
+ "aten::le(Tensor self, Scalar other) -> Tensor",
+ "aten::lt(Tensor self, Tensor other) -> Tensor",
+ "aten::lt(Tensor self, Scalar other) -> Tensor",
+
+ "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
+
+ "aten::type_as(Tensor self, Tensor other) -> Tensor",
}};
if (!simple_mappable.find(node)) {
return false;
}
// Check that all non-tensor inputs are constant
- for (Value * input : node->inputs()) {
+ for (Value* input : node->inputs()) {
if (input->type()->isSubtypeOf(DynamicType::get())) {
continue;
}
return true;
}
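// A simplified analogue of the OperatorSet lookup above: the real set matches
// full operator schemas, while this hypothetical version matches bare op
// names only, to show the shape of the check.
#include <string>
#include <unordered_set>
bool isSimpleMapByName(const std::string& op_name) {
  static const std::unordered_set<std::string> simple = {
      "aten::abs", "aten::add", "aten::mul", "aten::relu", "aten::tanh"};
  return simple.count(op_name) != 0;
}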
-Value * broadcastSizes(at::ArrayRef<Value*> sizes) {
+Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
JIT_ASSERT(!sizes.empty());
- Graph * graph = sizes[0]->owningGraph();
- Node * broadcast_n = graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
+ Graph* graph = sizes[0]->owningGraph();
+ Node* broadcast_n =
+ graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
broadcast_n->output()->setType(ListType::ofInts());
return broadcast_n->output();
}
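// A self-contained sketch of the shape arithmetic a BroadcastSizes node
// stands for, under the usual right-aligned broadcasting rules (a dimension
// of size 1 stretches; any other mismatch is an error). An illustration, not
// the runtime implementation.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>
std::vector<int64_t> broadcastTwo(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  const std::size_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> out(ndim);
  for (std::size_t i = 0; i < ndim; ++i) { // walk dimensions right-to-left
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1)
      throw std::runtime_error("shapes are not broadcastable");
    out[ndim - 1 - i] = std::max(da, db);
  }
  return out; // e.g. {3, 1} x {4} -> {3, 4}
}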
struct GraphFuser {
- Block * block_;
+ Block* block_;
std::shared_ptr<Graph> graph_;
GraphFuser(Block* block, std::shared_ptr<Graph> graph)
: block_(block), graph_(std::move(graph)) {}
- value_list tensorInputs(Node * node) {
- return filter(node->inputs(), [](Value * v) {
+ value_list tensorInputs(Node* node) {
+ return filter(node->inputs(), [](Value* v) {
return v->type()->isSubtypeOf(DynamicType::get());
});
}
- bool isFusable(Node * node) {
+ bool isFusable(Node* node) {
// We don't want to bother with cross-block node movements, as they
// are not necessarily correct.
- if (node->owningBlock() != block_) return false;
+ if (node->owningBlock() != block_)
+ return false;
return node->kind() == prim::FusionGroup || isSimpleMap(node);
}
- bool isFusableCatNode(Node * node) {
+ bool isFusableCatNode(Node* node) {
if (node->kind() != aten::cat)
return false;
if (!node->is_constant(attr::dim))
return false;
auto tensors_node = node->namedInput(attr::tensors)->node();
- if (tensors_node->kind() != prim::ListConstruct) return false;
- // NB: Note that technically other uses of the list aren't a big problem for us.
- // It would be enough to place the prim::FusedConcat before the prim::ListConstruct, and
- // allUsersAreThisConsumerOrOccurAfterIt would still be satisfied. However, I don't expect this
- // to be necessary any time soon, and so we're simply assuming that we don't have to deal with it.
- if (tensors_node->output()->uses().size() > 1) return false;
+ if (tensors_node->kind() != prim::ListConstruct)
+ return false;
+ // NB: Note that technically other uses of the list aren't a big problem for
+ // us. It would be enough to place the prim::FusedConcat before the
+ // prim::ListConstruct, and allUsersAreThisConsumerOrOccurAfterIt would
+ // still be satisfied. However, I don't expect this to be necessary any time
+ // soon, and so we're simply assuming that we don't have to deal with it.
+ if (tensors_node->output()->uses().size() > 1)
+ return false;
return true;
}
// Can this node produce an _output_ of a fusion group?
- // all Fusable nodes can do this, but additionally Concat, which normally cannot be fused
- // because it is not a simple map, can be put in a fusion group
- // as long as no items in the group read the output of concat
- bool isFusableAsExitNode(Node * node) {
+ // all Fusable nodes can do this, but additionally Concat, which normally
+ // cannot be fused because it is not a simple map, can be put in a fusion
+ // group as long as no items in the group read the output of concat
+ bool isFusableAsExitNode(Node* node) {
return isFusable(node) || isFusableOnlyAsExitNode(node);
}
- bool isFusableOnlyAsExitNode(Node * node) {
+ bool isFusableOnlyAsExitNode(Node* node) {
return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
}
- bool calculatesSize(Node * node) {
+ bool calculatesSize(Node* node) {
return node->matches("aten::size(Tensor self) -> int[]");
}
- bool allUsersAreThisConsumerOrCalcSizes(Node * consumer, Value * producer) {
+ bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) {
auto defining_node = producer->node();
- for(auto o : defining_node->outputs()) {
- for(auto u : o->uses()) {
- if(u.user != consumer && !calculatesSize(u.user))
+ for (auto o : defining_node->outputs()) {
+ for (auto u : o->uses()) {
+ if (u.user != consumer && !calculatesSize(u.user))
return false;
}
}
return true;
}
- bool mustRemainAsFusionGroupOutput(Value * producer) {
+ bool mustRemainAsFusionGroupOutput(Value* producer) {
if (producer->node()->kind() != prim::FusionGroup) {
return false;
}
auto subgraph = producer->node()->g(attr::Subgraph);
- auto * node = subgraph->outputs().at(producer->offset())->node();
+ auto* node = subgraph->outputs().at(producer->offset())->node();
return isFusableOnlyAsExitNode(node);
}
- Graph & getSubgraph(Node * n) {
+ Graph& getSubgraph(Node* n) {
JIT_ASSERT(n->kind() == prim::FusionGroup);
return *n->g(attr::Subgraph);
}
- void mergeFusionGroups(Node *consumer_group, Node *producer_group) {
+ void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
// Now we have two fusion groups!
- // Revert the fusion - place all inner nodes of producer back in the outer graph.
+ // Revert the fusion - place all inner nodes of producer back in the outer
+ // graph.
std::vector<Node*> temporary_nodes;
auto producer_subgraph = &getSubgraph(producer_group);
// Clone all nodes
for (auto inner : producer_subgraph->nodes()) {
- Node * outer = block_->owningGraph()->createClone(inner, [&](Value * k) -> Value* {
- return inner_to_outer.at(k);
- });
+ Node* outer = block_->owningGraph()->createClone(
+ inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); });
outer->insertBefore(producer_group);
temporary_nodes.emplace_back(outer);
auto inner_outputs = inner->outputs();
producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
}
producer_group->destroy();
- producer_group = nullptr; // Just to get a clear error in case someone uses it
+ producer_group =
+ nullptr; // Just to get a clear error in case someone uses it
// Inline the temporary nodes into the first group
auto consumer_subgraph = &getSubgraph(consumer_group);
- for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend(); ++it) {
- Node *node = *it;
- Node *merged = mergeNodeIntoGroup(consumer_group, node);
+ for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend();
+ ++it) {
+ Node* node = *it;
+ Node* merged = mergeNodeIntoGroup(consumer_group, node);
// If any of the outputs are still used then we need to add them
auto outputs = node->outputs();
for (size_t i = 0; i < outputs.size(); ++i) {
auto output = outputs[i];
- if (output->uses().size() == 0) continue;
+ if (output->uses().size() == 0)
+ continue;
consumer_subgraph->registerOutput(merged->outputs()[i]);
auto new_output = consumer_group->addOutput();
output->replaceAllUsesWith(new_output);
// insert a producer node into a consuming fusion group.
// DOES NOT WORK if n is a consumer of an output of the fusion group
// returns the node _inside_ the group that represents the node
- Node * mergeNodeIntoGroup(Node* group, Node * n) {
+ Node* mergeNodeIntoGroup(Node* group, Node* n) {
JIT_ASSERT(n->kind() != prim::FusionGroup);
- auto & subgraph = getSubgraph(group);
+ auto& subgraph = getSubgraph(group);
// map from nodes in the surrounding graph to parameters in the fusion
// group's subgraph that correspond to them
- std::unordered_map<Value*,Value*> inputs_map;
+ std::unordered_map<Value*, Value*> inputs_map;
size_t i = 0;
JIT_ASSERT(group->inputs().size() == subgraph.inputs().size());
- for(auto input : group->inputs()) {
+ for (auto input : group->inputs()) {
inputs_map[input] = subgraph.inputs()[i++];
}
- // add n's inputs to the fusion group's input list if we don't already have them
+ // add n's inputs to the fusion group's input list if we don't already have
+ // them
WithInsertPoint guard(*subgraph.nodes().begin());
for (auto input : n->inputs()) {
if (inputs_map.count(input) == 0) {
inputs_map[input] = in_group;
group->addInput(input);
} else {
- // We don't support passing in scalars as arguments to fused kernels, so we generally
- // don't allow fusing tensor-scalar operations unless the scalar is constant. In those
- // cases we inline the constants directly in the body of the fused group.
+ // We don't support passing in scalars as arguments to fused kernels,
+ // so we generally don't allow fusing tensor-scalar operations unless
+ // the scalar is constant. In those cases we inline the constants
+ // directly in the body of the fused group.
JIT_ASSERT(input->node()->kind() == prim::Constant);
- Node * in_const = subgraph.createClone(input->node(), [](Value*) -> Value* { throw std::runtime_error("unexpected input"); });
+ Node* in_const =
+ subgraph.createClone(input->node(), [](Value*) -> Value* {
+ throw std::runtime_error("unexpected input");
+ });
subgraph.insertNode(in_const);
inputs_map[input] = in_const->output();
}
}
}
// copy n into the graph, remapping its inputs to internal nodes
- Node * in_graph = subgraph.createClone(n,[&](Value * k)-> Value* {
- return inputs_map[k];
- });
+ Node* in_graph = subgraph.createClone(
+ n, [&](Value* k) -> Value* { return inputs_map[k]; });
// if n's outputs are already inputs to the fusion group,
// we need to remove them because n is now inside the fusion group.
//
auto inputs = group->inputs();
for (size_t i = 0; i < n->outputs().size(); ++i) {
auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]);
- if(it != inputs.end()) {
+ if (it != inputs.end()) {
size_t p = it - inputs.begin();
group->removeInput(p);
subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]);
// turn consumer node n into a fusion group with just n inside
// to prepare for fusion and replace uses of n with the new group
- Node * createSingletonFusionGroup(Node * n) {
+ Node* createSingletonFusionGroup(Node* n) {
auto group = block_->owningGraph()->createFusionGroup();
// propagate position information for the new node so we can always
// have a valid mapping
group->insertBefore(n);
- Node * mergedNode = mergeNodeIntoGroup(group,n);
+ Node* mergedNode = mergeNodeIntoGroup(group, n);
getSubgraph(group).registerOutput(mergedNode->output());
auto sel = group->addOutput();
sel->copyMetadata(n->output());
}
// TODO: remove this and use WithInsertPoint instead
- void insertAt(Node ** insertion_point, Node * n) {
+ void insertAt(Node** insertion_point, Node* n) {
n->insertAfter(*insertion_point);
*insertion_point = n;
}
Node* consumer,
Value* producer,
const AliasDb& aliasDb) {
- // this handles cases where producer can be moved _into_ the fusion group of consumer.
+ // this handles cases where producer can be moved _into_ the fusion group of
+ // consumer.
// TODO: extend to fusion of consumer into _producer's_ fusion blob
// if the consumer allInputsAreThisProducer(consumer,producer)
// we can move the consumer up into the producer.
- // but this requires better handling of merging fusion groups so it is not done now
+ // but this requires better handling of merging fusion groups so it is not
+ // done now
Node* real_consumer = consumer->kind() == aten::cat
? consumer->namedInput(attr::tensors)->node()
: consumer;
auto group = consumer;
if (consumer->kind() == aten::cat) {
- Graph * graph = consumer->owningGraph();
- Node * list_construct = consumer->namedInput(attr::tensors)->node();
+ Graph* graph = consumer->owningGraph();
+ Node* list_construct = consumer->namedInput(attr::tensors)->node();
int64_t dim = consumer->get<int64_t>(attr::dim).value();
- Node * fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())->i_(attr::dim, dim);
+ Node* fused_cat =
+ graph->create(prim::FusedConcat, list_construct->inputs())
+ ->i_(attr::dim, dim);
fused_cat->insertBefore(list_construct);
fused_cat->output()->copyMetadata(consumer->output());
consumer->output()->replaceAllUsesWith(fused_cat->output());
return group;
}
JIT_ASSERT(producer->node()->outputs().size() == 1);
- Node * merged = mergeNodeIntoGroup(group, producer->node());
+ Node* merged = mergeNodeIntoGroup(group, producer->node());
// remaining uses of this producer can occur because we allow
// fusion in cases where uses remain after the consumer
// if these exist, re-route them to the version of producer
// created in FusionGroup
- if(producer->uses().size() != 0) {
+ if (producer->uses().size() != 0) {
getSubgraph(group).registerOutput(merged->output());
- Value * new_producer = group->addOutput();
+ Value* new_producer = group->addOutput();
new_producer->copyMetadata(producer);
producer->replaceAllUsesWith(new_producer);
}
return false;
}
// Does the chunk have constant chunks/dim?
- auto * chunk = producer->node();
+ auto* chunk = producer->node();
if (chunk->kind() != prim::ConstantChunk)
- return false;
+ return false;
// And all uses of the chunk are in this consumer
for (auto s : chunk->outputs()) {
for (auto u : s->uses()) {
return c10::nullopt;
}
size_t input_index = it - group->inputs().begin();
- auto & subgraph = getSubgraph(group);
- auto * subgraph_input = subgraph.inputs().at(input_index);
+ auto& subgraph = getSubgraph(group);
+ auto* subgraph_input = subgraph.inputs().at(input_index);
// If subgraph_input is an input to prim::ConstantChunk, it will have 1 use
- auto * node = subgraph_input->uses().at(0).user;
+ auto* node = subgraph_input->uses().at(0).user;
if (node->kind() == prim::ConstantChunk) {
JIT_ASSERT(subgraph_input->uses().size() == 1);
return node;
}
void fuseChunkByReusingExistingFusedChunk(
- Node * group, Node * chunk, Node * existingFusedChunk) {
+ Node* group,
+ Node* chunk,
+ Node* existingFusedChunk) {
if (chunk->outputs().size() != existingFusedChunk->outputs().size()) {
return;
}
- auto & subgraph = getSubgraph(group);
+ auto& subgraph = getSubgraph(group);
for (size_t i = 0; i < chunk->outputs().size(); ++i) {
// Find the input to the FusionGroup (group)
- auto * replacement_val = existingFusedChunk->outputs().at(i);
- auto * val = chunk->outputs().at(i);
+ auto* replacement_val = existingFusedChunk->outputs().at(i);
+ auto* val = chunk->outputs().at(i);
auto it = std::find(group->inputs().begin(), group->inputs().end(), val);
auto input_index = it - group->inputs().begin();
}
// There are two invariants for prim::ConstantChunk:
- // (1) the tensor input to prim::ConstantChunk must be an input to the fusion group
- // (2) no two ConstantChunks in the same FusionGroup can share a tensor input.
- graph_node_list::iterator fuseChunk(Node * consumer, Value * producer) {
- auto * chunk = producer->node();
+  // (1) the tensor input to prim::ConstantChunk must be an input to the
+  //     fusion group
+  // (2) no two ConstantChunks in the same FusionGroup can share a tensor
+  //     input.
+ graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) {
+ auto* chunk = producer->node();
JIT_ASSERT(consumer->kind() == prim::FusionGroup);
JIT_ASSERT(chunk->kind() == prim::ConstantChunk);
// if producer's input is already an input to a prim::ConstantChunk node,
// we cannot add a new prim::ConstantChunk node because of invariant (2).
- auto * chunked_tensor = producer->node()->input();
+ auto* chunked_tensor = producer->node()->input();
if (auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) {
- fuseChunkByReusingExistingFusedChunk(consumer, chunk, *existingFusedChunk);
+ fuseChunkByReusingExistingFusedChunk(
+ consumer, chunk, *existingFusedChunk);
return consumer->reverseIterator();
}
}
}
// Sort in reverse topological order
- std::sort(result.begin(), result.end(), [&](Value * a, Value * b) {
+ std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
return a->node()->isAfter(b->node());
});
return result;
}
- graph_node_list::iterator scanNodeForChunks(Node * consumer) {
+ graph_node_list::iterator scanNodeForChunks(Node* consumer) {
if (consumer->kind() == prim::FusionGroup) {
auto inputs = sortReverseTopological(consumer->inputs());
- for(auto producer : inputs) {
+ for (auto producer : inputs) {
if (!canFuseChunk(consumer, producer)) {
continue;
}
return ++consumer->reverseIterator();
}
- void insertExplicitBroadcast(Node *node) {
- WithInsertPoint insert_guard { node };
+ void insertExplicitBroadcast(Node* node) {
+ WithInsertPoint insert_guard{node};
auto tensors = tensorInputs(node);
- auto new_tensors = SymbolicVariable::broadcast_tensors(fmap<SymbolicVariable>(tensors));
+ auto new_tensors =
+ SymbolicVariable::broadcast_tensors(fmap<SymbolicVariable>(tensors));
// Replace tensors inputs with broadcasted values
auto new_tensors_it = new_tensors.begin();
}
}
- Node * promoteChunkToBroadcastingChunk(Node * chunk) {
+ Node* promoteChunkToBroadcastingChunk(Node* chunk) {
JIT_ASSERT(chunk->kind() == prim::ConstantChunk);
size_t nchunks = chunk->i(attr::chunks);
- Node * bchunk = chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
+ Node* bchunk =
+ chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
bchunk->addInput(chunk->input());
for (size_t i = 0; i < nchunks; ++i) {
- auto * old_output = chunk->outputs().at(i);
- auto * new_output = bchunk->outputs().at(i);
+ auto* old_output = chunk->outputs().at(i);
+ auto* new_output = bchunk->outputs().at(i);
new_output->copyMetadata(old_output);
old_output->replaceAllUsesWith(new_output);
}
// we exit the fusion group.
//
// NB: The intermediate BroadcastingChunk is important for moving chunks past
- // more than one operation: the graph fuser is not able to easily move operations
- // around broadcast_tensors + chunk nodes. Let f, g, h be fusible ops
+ // more than one operation: the graph fuser is not able to easily move
+ // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusible
+ // ops
// x = f(v, w)
// z = g(x, y)
// a, b = chunk(z)
// b = g(bx, by)
// c = h(a, b)
- bool tryToMoveChunk(Node * consumer, Value * producer) {
+ bool tryToMoveChunk(Node* consumer, Value* producer) {
// is the output from a chunk/bchunk node?
- auto * chunk = producer->node();
- if (chunk->kind() != prim::ConstantChunk && chunk->kind() != prim::BroadcastingChunk)
+ auto* chunk = producer->node();
+ if (chunk->kind() != prim::ConstantChunk &&
+ chunk->kind() != prim::BroadcastingChunk)
return false;
- // try to find a producer to move after the chunk/bchunk. The producer must be
- // fusible into the consumer.
+ // try to find a producer to move after the chunk/bchunk. The producer must
+ // be fusible into the consumer.
auto it = std::find_if(
chunk->inputs().begin(),
chunk->inputs().end(),
- [&](Value * producer_for_chunk) {
+ [&](Value* producer_for_chunk) {
return isFusable(producer_for_chunk->node()) &&
allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
});
if (it == chunk->inputs().end()) {
return false;
}
- Value * producer_for_chunk = *it;
+ Value* producer_for_chunk = *it;
size_t producer_index = it - chunk->inputs().begin();
// all uses of the chunk must be in this consumer
}
}
// multiple return operators
- Node * producer_for_chunk_node = producer_for_chunk->node();
+ Node* producer_for_chunk_node = producer_for_chunk->node();
JIT_ASSERT(producer_for_chunk_node->outputs().size() == 1);
// Convert chunk to bchunk, if it isn't one already. The bchunk represents a
// broadcast and one or more chunk operations.
- auto * bchunk = chunk;
+ auto* bchunk = chunk;
if (chunk->kind() == prim::ConstantChunk) {
bchunk = promoteChunkToBroadcastingChunk(chunk);
}
std::vector<Value*> producer_chunk_outputs;
for (size_t i = 0; i < nchunks; i++) {
- producer_chunk_outputs.push_back(bchunk->output(nchunks * producer_index + i));
+ producer_chunk_outputs.push_back(
+ bchunk->output(nchunks * producer_index + i));
}
// Add each of op's operands to the bchunk node.
std::vector<std::vector<Value*>> chunked_inputs;
for (auto input : producer_for_chunk_node->inputs()) {
- // XXX: we only work with pointwise ops in here, so we know it is valid to push
- // the concat only through tensor arguments (and all other args can be safely ignored).
+ // XXX: we only work with pointwise ops in here, so we know it is valid to
+ // push the concat only through tensor arguments (and all other args can
+ // be safely ignored).
if (!input->type()->isSubtypeOf(DynamicType::get()))
continue;
bchunk->addInput(input);
chunked_inputs.emplace_back(); // alas, to not be C++17
for (auto chunk_sel : producer_chunk_outputs) {
- Value * input_chunk_sel = bchunk->addOutput();
+ Value* input_chunk_sel = bchunk->addOutput();
input_chunk_sel->setType(chunk_sel->type());
chunked_inputs.back().push_back(input_chunk_sel);
}
// and then rewrite the graph to use them!
for (auto chunk_sel : producer_chunk_outputs) {
auto original_inputs = producer_for_chunk_node->inputs();
- Node * chunked_op = block_->owningGraph()->create(producer_for_chunk_node->kind());
+ Node* chunked_op =
+ block_->owningGraph()->create(producer_for_chunk_node->kind());
chunked_op->copyAttributes(*producer_for_chunk_node);
chunked_op->output()->setType(chunk_sel->type());
auto chunked_inputs_it = chunked_inputs.begin();
for (Value* original_input : original_inputs) {
if (original_input->type()->isSubtypeOf(DynamicType::get())) {
JIT_ASSERT(chunked_inputs_it != chunked_inputs.end());
- chunked_op->addInput(chunked_inputs_it->at(chunk_sel->offset() % nchunks));
+ chunked_op->addInput(
+ chunked_inputs_it->at(chunk_sel->offset() % nchunks));
++chunked_inputs_it;
} else {
chunked_op->addInput(original_input);
bchunk->eraseOutput(nchunks * producer_index);
}
- // The output of producer_for_chunk_node could have been used in some aten::size
- // operators, so we need to clean those up as well (we simply broadcast all its tensor inputs).
+ // The output of producer_for_chunk_node could have been used in some
+ // aten::size operators, so we need to clean those up as well (we simply
+ // broadcast all its tensor inputs).
auto size_calc_uses = producer_for_chunk_node->output()->uses();
if (!size_calc_uses.empty()) {
- auto tensor_inputs = filter(producer_for_chunk_node->inputs(),
- [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
- auto tensor_sizes = fmap(tensor_inputs,
- [](Value * v) { return v->owningGraph()->insert(aten::size, {v}); });
+ auto tensor_inputs = filter(
+ producer_for_chunk_node->inputs(),
+ [](Value* v) { return v->type()->isSubtypeOf(DynamicType::get()); });
+ auto tensor_sizes = fmap(tensor_inputs, [](Value* v) {
+ return v->owningGraph()->insert(aten::size, {v});
+ });
JIT_ASSERT(!tensor_sizes.empty());
- Value * output_size = tensor_sizes.size() == 1 ? tensor_sizes[0] : broadcastSizes(tensor_sizes);
+ Value* output_size = tensor_sizes.size() == 1
+ ? tensor_sizes[0]
+ : broadcastSizes(tensor_sizes);
for (Use u : size_calc_uses) {
u.user->output()->replaceAllUsesWith(output_size);
u.user->destroy();
std::pair<graph_node_list::iterator, bool> scanNode(
Node* consumer,
const AliasDb& aliasDb) {
- if(isFusableAsExitNode(consumer)) {
- auto consumer_inputs = consumer->kind() == aten::cat ?
- consumer->namedInput(attr::tensors)->node()->inputs() :
- consumer->inputs();
+ if (isFusableAsExitNode(consumer)) {
+ auto consumer_inputs = consumer->kind() == aten::cat
+ ? consumer->namedInput(attr::tensors)->node()->inputs()
+ : consumer->inputs();
// handle inputs in reverse topological order as well...
// otherwise in f(a, a+b) it will appear that a is used twice if we consider
// the f-a fusion before the f-(a+b) fusion.
auto inputs = sortReverseTopological(consumer_inputs);
- for(auto producer : inputs) {
+ for (auto producer : inputs) {
// Don't fuse if producer must come from a FusionGroup exit node
- if (mustRemainAsFusionGroupOutput(producer)) continue;
- if(tryToMoveChunk(consumer,producer)) {
+ if (mustRemainAsFusionGroupOutput(producer))
+ continue;
+ if (tryToMoveChunk(consumer, producer)) {
// the chunk before this consumer was re-arranged to allow fusion,
// we scan this consumer again to perform the fusion
return std::make_pair(consumer->reverseIterator(), true);
void replaceIntermediateBroadcastingChunks() {
for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
- auto * node = *it;
- ++it; // We might delete node, so increment the iterator now.
+ auto* node = *it;
+ ++it; // We might delete node, so increment the iterator now.
if (node->kind() != prim::BroadcastingChunk) {
continue;
}
- auto * bchunk = node;
+ auto* bchunk = node;
insertExplicitBroadcast(bchunk);
- auto * graph = block_->owningGraph();
+ auto* graph = block_->owningGraph();
size_t nchunks = bchunk->i(attr::chunks);
WithInsertPoint guard(bchunk->next());
// Split the bchunk into bchunks.inputs().size() number of chunk nodes.
- for (size_t input_offset = 0; input_offset < bchunk->inputs().size(); input_offset++) {
+ for (size_t input_offset = 0; input_offset < bchunk->inputs().size();
+ input_offset++) {
auto* input = bchunk->inputs().at(input_offset);
- Node * new_chunk = graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
+ Node* new_chunk =
+ graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
new_chunk->copyAttributes(*bchunk);
- for (size_t output_offset = 0; output_offset < nchunks; output_offset++) {
+ for (size_t output_offset = 0; output_offset < nchunks;
+ output_offset++) {
auto new_output = new_chunk->addOutput();
- auto old_output = bchunk->outputs().at(input_offset * nchunks + output_offset);
+ auto old_output =
+ bchunk->outputs().at(input_offset * nchunks + output_offset);
new_output->copyMetadata(old_output);
old_output->replaceAllUsesWith(new_output);
}
}
}
- bool usedOnlyInSize(Value * v) {
- const auto & uses = v->uses();
- return std::all_of(uses.begin(), uses.end(),
- [](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); });
+ bool usedOnlyInSize(Value* v) {
+ const auto& uses = v->uses();
+ return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
+ return u.user->matches("aten::size(Tensor self) -> int[]");
+ });
}
- // Builds up expressions that compute shapes of all intermediates (and outputs)
- // of the fusion group, based on the sizes of inputs. You should run DCE to remove
- // those that you end up not using.
- std::unordered_map<Value*, Value*> buildShapeExpressions(Node * fusion_group) {
- WithInsertPoint insert_guard { fusion_group->next() };
+ // Builds up expressions that compute shapes of all intermediates (and
+ // outputs) of the fusion group, based on the sizes of inputs. You should run
+ // DCE to remove those that you end up not using.
+ std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
+ WithInsertPoint insert_guard{fusion_group->next()};
std::unordered_map<Value*, Value*> shape_of;
- Graph * graph = fusion_group->owningGraph();
+ Graph* graph = fusion_group->owningGraph();
auto subgraph = fusion_group->g(attr::Subgraph);
auto inputs = fusion_group->inputs();
// When we have a guarantee that an output won't be removed, because it's
// used in expressions that don't involve size checks, we can use its size
- // instead of computing a long chain of broadcasts, starting from the beginning
- // of the kernel.
+ // instead of computing a long chain of broadcasts, starting from the
+ // beginning of the kernel.
auto outputs = fusion_group->outputs();
auto soutputs = subgraph->outputs();
JIT_ASSERT(outputs.size() == soutputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
- if (usedOnlyInSize(outputs[i])) continue;
+ if (usedOnlyInSize(outputs[i]))
+ continue;
shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
}
- for (Node * n : subgraph->nodes()) {
- // XXX: Use of shape_of.emplace is crucial to the output shape optimization!
+ for (Node* n : subgraph->nodes()) {
+ // XXX: Use of shape_of.emplace is crucial to the output shape
+ // optimization!
if (n->kind() == prim::FusedConcat) {
// This is a bit more involved, because we have to account for the case
// when inputs have different shapes, but fortunately those tensors are
continue;
}
if (n->kind() == prim::ConstantChunk) {
- Node * sizes_node = graph->insertNode(graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
+ Node* sizes_node = graph->insertNode(
+ graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
sizes_node->i_(attr::dim, n->i(attr::dim));
sizes_node->i_(attr::chunks, n->i(attr::chunks));
- Value * regular_size = sizes_node->outputs().at(0);
- Value * last_size = sizes_node->outputs().at(1);
+ Value* regular_size = sizes_node->outputs().at(0);
+ Value* last_size = sizes_node->outputs().at(1);
regular_size->setType(ListType::ofInts());
last_size->setType(ListType::ofInts());
auto outputs = n->outputs();
- for (Value * o : outputs.slice(0, outputs.size() - 1)) {
+ for (Value* o : outputs.slice(0, outputs.size() - 1)) {
shape_of.emplace(o, regular_size);
}
shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
continue;
}
- auto tensor_inputs = filter(n->inputs(),
- [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
- auto shapes = fmap(tensor_inputs, [&](Value * v) { return shape_of.at(v); });
+ auto tensor_inputs = filter(n->inputs(), [](Value* v) {
+ return v->type()->isSubtypeOf(DynamicType::get());
+ });
+ auto shapes =
+ fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
JIT_ASSERT(!shapes.empty());
- shape_of.emplace(n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
+ shape_of.emplace(
+ n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
}
return shape_of;
}
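// A sketch of the arithmetic behind the regular_size/last_size pair that
// prim::ChunkSizes encodes, along the chunked dimension: every chunk gets
// ceil(dim / chunks) elements except the last, which takes the remainder.
// Assumes the dimension is large enough to yield exactly `chunks` pieces.
#include <cstdint>
#include <utility>
std::pair<int64_t, int64_t> chunkSizes(int64_t dim, int64_t chunks) {
  int64_t regular = (dim + chunks - 1) / chunks; // ceil division
  int64_t last = dim - regular * (chunks - 1);   // whatever remains
  return {regular, last}; // chunkSizes(10, 3) == {4, 2}: pieces of 4, 4, 2
}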
- void removeOutputsUsedOnlyInSize(Node * fusion_group) {
- if (fusion_group->kind() != prim::FusionGroup) return;
+ void removeOutputsUsedOnlyInSize(Node* fusion_group) {
+ if (fusion_group->kind() != prim::FusionGroup)
+ return;
auto subgraph = fusion_group->g(attr::Subgraph);
auto shape_of = buildShapeExpressions(fusion_group);
// where f, g, h, l are simple map ops.
// The first iteration will fuse %4 and %3, and see that %1 is an input, but
// can't be fused, because it has a different use before the fusion group
- // in our topological ordering. Then, %2 will be considered, and fused with %1.
- // If we do another iteration, the algorithm will consider the fusion of these
- // two groups and fix the situation.
+ // in our topological ordering. Then, %2 will be considered, and fused with
+ // %1. If we do another iteration, the algorithm will consider the fusion of
+ // these two groups and fix the situation.
bool any_changed = true;
while (any_changed) {
any_changed = false;
}
// Remove outputs that have been added only because we need their size
- for (Node * n : block_->nodes()) {
+ for (Node* n : block_->nodes()) {
removeOutputsUsedOnlyInSize(n);
}
- for (Node * node : block_->nodes()) {
- for (Block * sub_block : node->blocks()) {
+ for (Node* node : block_->nodes()) {
+ for (Block* sub_block : node->blocks()) {
GraphFuser(sub_block, graph_).run();
}
}
}
};
-void PeepholeOptimizeShapeExpressions(Block * block) {
+void PeepholeOptimizeShapeExpressions(Block* block) {
auto nodes = block->nodes();
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
- Node * node = *it;
- for (Block * subblock : node->blocks()) {
+ Node* node = *it;
+ for (Block* subblock : node->blocks()) {
PeepholeOptimizeShapeExpressions(subblock);
}
if (node->kind() == prim::BroadcastSizes) {
// Deduplicate inputs, but use their unique() values to ensure
// this process only depends on the graph.
std::map<size_t, Value*> unique_to_value;
- for (Value * input : node->inputs()) {
+ for (Value* input : node->inputs()) {
unique_to_value.emplace(input->unique(), input);
}
if (unique_to_value.size() != node->inputs().size()) {
std::vector<Value*> inputs;
inputs.reserve(unique_to_value.size());
- for (auto & entry : unique_to_value) {
+ for (auto& entry : unique_to_value) {
inputs.push_back(entry.second);
}
if (inputs.size() == 1) {
node->output()->replaceAllUsesWith(inputs[0]);
} else {
- WithInsertPoint insert_guard { node };
+ WithInsertPoint insert_guard{node};
node->output()->replaceAllUsesWith(broadcastSizes(inputs));
}
it.destroyCurrent();
continue;
}
// Compose simple chains of broadcasts into a single node.
- const auto & uses = node->output()->uses();
+ const auto& uses = node->output()->uses();
if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
- Node * user = uses[0].user;
+ Node* user = uses[0].user;
user->removeInput(uses[0].offset);
- // NB: we don't care about deduplication in here, as we will visit user later.
- for (Value * i : node->inputs()) {
+ // NB: we don't care about deduplication in here, as we will visit user
+ // later.
+ for (Value* i : node->inputs()) {
user->addInput(i);
}
it.destroyCurrent();
} // anonymous namespace
void FuseGraph(std::shared_ptr<Graph>& graph) {
- // NYI on Windows
- #ifndef _WIN32
+// NYI on Windows
+#ifndef _WIN32
GraphFuser(graph->block(), graph).run();
// After FuseGraph some common subexpressions may come back
// Improve the quality of shape propagation code that was left
PeepholeOptimizeShapeExpressions(graph->block());
- #endif
+#endif
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// NB: Be sure to run DCE before fusion, because dead instructions
// can prevent fusion opportunities from being exploited.
// On Windows this pass is a no-op (NYI).
TORCH_API void FuseGraph(std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-TORCH_API void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold = 5);
+TORCH_API void InlineAutodiffSubgraphs(
+ std::shared_ptr<Graph>& graph,
+ size_t threshold = 5);
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/inplace_check.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-void CheckInplace(Block * block) {
+void CheckInplace(Block* block) {
for (auto node : block->nodes()) {
if (node->kind() == prim::PythonOp && node->hasAttribute(attr::inplace)) {
if (node->i(attr::inplace)) {
- throw std::runtime_error(std::string("inplace ") +
- static_cast<PythonOp*>(node)->name() +
- " not supported in the JIT");
+ throw std::runtime_error(
+ std::string("inplace ") + static_cast<PythonOp*>(node)->name() +
+ " not supported in the JIT");
}
}
}
CheckInplace(graph->block());
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API void CheckInplace(std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
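// The check above in miniature, over a hypothetical Op record instead of
// PythonOp nodes: scan the list and reject anything flagged as in-place.
#include <stdexcept>
#include <string>
#include <vector>
struct Op {
  std::string name;
  bool inplace = false;
};
void checkInplaceToy(const std::vector<Op>& ops) {
  for (const Op& op : ops) {
    if (op.inplace)
      throw std::runtime_error(
          "inplace " + op.name + " not supported in the JIT");
  }
}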
#include <torch/csrc/jit/passes/loop_unrolling.h>
-#include <torch/csrc/jit/interned_strings.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/interned_strings.h>
#include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
static constexpr int64_t kMaxBodySize = 32;
static constexpr int64_t kMaxBodyRepeats = 64;
-bool isTrueConstant(Value *val) {
+bool isTrueConstant(Value* val) {
c10::optional<bool> maybe_value = constant_as<bool>(val);
return maybe_value && *maybe_value;
}
bool isForLoop(Node* node) {
if (node->kind() != prim::Loop)
return false;
- Value *start_cond = node->inputs().at(1);
- Value *continue_cond = node->blocks().at(0)->outputs().at(0);
+ Value* start_cond = node->inputs().at(1);
+ Value* continue_cond = node->blocks().at(0)->outputs().at(0);
return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
}
-// Counts the size of this block, stopping and returning once reaches limit instructions.
-int64_t limitedBlockSize(Block *body, int64_t limit) {
+// Counts the size of this block, stopping and returning once it reaches
+// limit instructions.
+int64_t limitedBlockSize(Block* body, int64_t limit) {
auto it = body->nodes().begin();
auto end = body->nodes().end();
for (int64_t i = 0; i < limit; ++i, ++it) {
- for (Block *subblock : it->blocks()) {
+ for (Block* subblock : it->blocks()) {
i += limitedBlockSize(subblock, limit - i);
}
if (it == end) {
return limit;
}
-bool isSmallBlock(Block *body) {
+bool isSmallBlock(Block* body) {
return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
}
-// XXX: This function can only be called with a loop that is guaranteed to execute EXACTLY ONCE.
-void inlineBody(Node *loop) {
+// XXX: This function can only be called with a loop that is guaranteed to
+// execute EXACTLY ONCE.
+void inlineBody(Node* loop) {
auto graph = loop->owningGraph();
auto body = loop->blocks().at(0);
- WithInsertPoint insert_point_guard { loop };
+ WithInsertPoint insert_point_guard{loop};
std::unordered_map<Value*, Value*> value_map;
- auto get_value = [&](Value *v) {
+ auto get_value = [&](Value* v) {
auto it = value_map.find(v);
if (it != value_map.end())
return it->second;
value_map[body->inputs()[i - 1]] = loop->inputs()[i];
}
- for (Node *orig : body->nodes()) {
- Node *clone = graph->insertNode(graph->createClone(orig, get_value));
+ for (Node* orig : body->nodes()) {
+ Node* clone = graph->insertNode(graph->createClone(orig, get_value));
for (size_t i = 0; i < orig->outputs().size(); ++i) {
value_map[orig->outputs()[i]] = clone->outputs()[i];
}
}
for (size_t i = 0; i < loop->outputs().size(); ++i) {
- loop->outputs().at(i)->replaceAllUsesWith(get_value(body->outputs().at(i + 1)));
+ loop->outputs().at(i)->replaceAllUsesWith(
+ get_value(body->outputs().at(i + 1)));
}
- // XXX: it is extremely important to destroy the loop in here. DCE might not be able
- // to conclude that it's safe, because the loop might contain side effects.
+ // XXX: it is extremely important to destroy the loop in here. DCE might not
+ // be able to conclude that it's safe, because the loop might contain side
+ // effects.
loop->destroy();
}
-void repeatBody(Block *body, int64_t times) {
+void repeatBody(Block* body, int64_t times) {
// We will be adding nodes to the body, so cache the initial start and end.
// XXX: they are both inclusive, because the exclusive body_end would point to
- // return_node, which would move further away if we were to add nodes, and we
- // would enter an infinite loop.
+ // return_node, which would move further away if we were to add nodes,
+ // and we would enter an infinite loop.
auto body_start = body->nodes().begin();
auto body_end = std::prev(body->nodes().end());
auto graph = body->owningGraph();
- WithInsertPoint insert_point_guard { body };
+ WithInsertPoint insert_point_guard{body};
std::unordered_map<Value*, Value*> value_map;
- auto get_value = [&](Value *v) {
+ auto get_value = [&](Value* v) {
auto it = value_map.find(v);
if (it != value_map.end())
return it->second;
for (int64_t i = 1; i < times; ++i) {
// Update loop-carried values
- // NB: note that we don't need to worry about the loop counter, because we've
- // replaced it with a loop-carried variable
+ // NB: note that we don't need to worry about the loop counter, because
+ // we've replaced it with a loop-carried variable
JIT_ASSERT(body->inputs().size() == body->outputs().size());
for (size_t i = 1; i < body->inputs().size(); ++i) {
value_map[body->inputs()[i]] = get_value(body->outputs()[i]);
// Clone the nodes
for (auto it = body_start; it != std::next(body_end); ++it) {
- Node *orig = *it;
- Node *clone = graph->insertNode(graph->createClone(orig, get_value));
+ Node* orig = *it;
+ Node* clone = graph->insertNode(graph->createClone(orig, get_value));
for (size_t i = 0; i < orig->outputs().size(); ++i) {
value_map[orig->outputs()[i]] = clone->outputs()[i];
}
for (int64_t i = new_outputs.size() - 1; i >= 0; --i) {
body->eraseOutput(i);
}
- for (Value *output : new_outputs) {
+ for (Value* output : new_outputs) {
body->registerOutput(output);
}
- // It's likely that we have some dead nodes now - for example the "true" constant
- // that prevents the loop from breaking. We shouldn't wait too long before removing
- // them because they might artificially increase the loop size and prevent outer loop
- // unrolling.
+ // It's likely that we have some dead nodes now - for example the "true"
+ // constant that prevents the loop from breaking. We shouldn't wait too long
+ // before removing them because they might artificially increase the loop size
+ // and prevent outer loop unrolling.
EliminateDeadCode(body, false);
}
-// Replaces the builtin loop counter with a "mutable" variable outside of the loop.
-void replaceLoopCounter(Node *loop) {
- Graph *graph = loop->owningGraph();
- Block *body = loop->blocks().at(0);
+// Replaces the builtin loop counter with a "mutable" variable outside of the
+// loop.
+void replaceLoopCounter(Node* loop) {
+ Graph* graph = loop->owningGraph();
+ Block* body = loop->blocks().at(0);
WithInsertPoint guard(loop);
Value* init_counter = graph->insertConstant(0);
loop->insertInput(2, init_counter);
loop->insertOutput(0)->setType(IntType::get());
- Value * internal_counter = body->insertInput(1)->setType(init_counter->type());
+ Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
body->inputs()[0]->replaceAllUsesWith(internal_counter);
- WithInsertPoint insertPointGuard{ body->return_node() };
+ WithInsertPoint insertPointGuard{body->return_node()};
Value* result = graph->insert(aten::add, {internal_counter, 1});
body->insertOutput(1, result);
}
-void unroll(Node *loop) {
- Graph *graph = loop->owningGraph();
- Block *body = loop->blocks().at(0);
+void unroll(Node* loop) {
+ Graph* graph = loop->owningGraph();
+ Block* body = loop->blocks().at(0);
if (!isSmallBlock(body))
return;
- // We will be using a "mutable" counter outside of the loop instead of the default
- // one, because this will allow us to share it between the unrolled loop and its epilogue.
- // This is necessary only if the loop counter is actually used in the body.
+ // We will be using a "mutable" counter outside of the loop instead of the
+ // default one, because this will allow us to share it between the unrolled
+ // loop and its epilogue. This is necessary only if the loop counter is
+ // actually used in the body.
if (body->inputs()[0]->uses().size() > 0)
replaceLoopCounter(loop);
- // Some optimization for constant-length loops. If we know they won't run too many
- // times, then we can unroll them entirely.
- Value *trip_count = loop->inputs().at(0);
+ // Some optimization for constant-length loops. If we know they won't run too
+ // many times, then we can unroll them entirely.
+ Value* trip_count = loop->inputs().at(0);
int64_t const_len = constant_as<int64_t>(trip_count).value_or(-1);
if (const_len != -1 && const_len < kMaxBodyRepeats) {
repeatBody(body, const_len);
return;
}
- WithInsertPoint insert_point_guard { loop };
+ WithInsertPoint insert_point_guard{loop};
// Clone the loop before we unroll it. The clone will become the epilogue.
- Node *loop_epilogue = graph->createClone(loop, [](Value *v) { return v; })
- ->insertAfter(loop);
+ Node* loop_epilogue =
+ graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
for (size_t i = 0; i < loop->outputs().size(); ++i) {
loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
// Change the iteration counts of both loops
Value* iter_count = loop->inputs().at(0);
- Value* unrolled_iter_count = graph->insert(aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
+ Value* unrolled_iter_count = graph->insert(
+ aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
loop->replaceInput(0, unrolled_iter_count);
- loop_epilogue->replaceInput(0, graph->insert(aten::sub, {iter_count, graph->insert(aten::mul,{unrolled_iter_count , kUnrollFactor})}));
+ loop_epilogue->replaceInput(
+ 0,
+ graph->insert(
+ aten::sub,
+ {iter_count,
+ graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
}
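
A standalone check that the trip-count split above loses no iterations: the unrolled loop runs iter_count / kUnrollFactor times, rounded toward zero exactly as aten::__round_to_zero_floordiv does, and the epilogue runs the remainder. The kUnrollFactor value below is illustrative, not taken from the pass.

#include <cassert>
#include <cstdint>

int64_t unrolledTripCount(int64_t iter_count, int64_t unroll_factor) {
  return iter_count / unroll_factor; // C++ '/' already rounds toward zero
}

int64_t epilogueTripCount(int64_t iter_count, int64_t unroll_factor) {
  return iter_count -
      unrolledTripCount(iter_count, unroll_factor) * unroll_factor;
}

int main() {
  constexpr int64_t kUnrollFactor = 8; // assumption: illustrative value only
  for (int64_t n = 0; n < 100; ++n) {
    // every original iteration runs exactly once across the two loops
    assert(
        unrolledTripCount(n, kUnrollFactor) * kUnrollFactor +
            epilogueTripCount(n, kUnrollFactor) ==
        n);
  }
  return 0;
}
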
-void UnrollLoops(Block *block) {
+void UnrollLoops(Block* block) {
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
- // XXX: unroll might destroy the current node, so we need to pre-increment the iterator
- Node *node = *it; ++it;
- for (Block *subblock : node->blocks()) {
+ // XXX: unroll might destroy the current node, so we need to pre-increment
+ // the iterator
+ Node* node = *it;
+ ++it;
+ for (Block* subblock : node->blocks()) {
UnrollLoops(subblock);
}
if (isForLoop(node)) {
EliminateDeadCode(graph);
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API void UnrollLoops(std::shared_ptr<Graph>& graph);
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/lower_grad_of.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
void LowerGradOf(Graph& g) {
- for(auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
- if(it->kind() == prim::GradOf) {
+ for (auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
+ if (it->kind() == prim::GradOf) {
// if any_defined(inputs):
// outputs = <original_computation>
// else:
auto cond = g.insertNode(g.create(prim::AnyDefined, it->inputs()))
->output()
->setType(IntType::get());
- auto if_stat = g.insertNode(
- g.create(prim::If, {cond}, it->outputs().size()));
+ auto if_stat =
+ g.insertNode(g.create(prim::If, {cond}, it->outputs().size()));
if_stat->addBlock()->cloneFrom(
it->blocks().at(0), [](Value* v) { return v; });
auto else_block = if_stat->addBlock();
}
}
-}}
+} // namespace jit
+} // namespace torch
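
A hedged ATen-level sketch of the control flow LowerGradOf emits: the cloned GradOf body runs only when at least one incoming gradient is defined, and otherwise every output stays undefined. The identity body and the output count here are purely illustrative.

#include <ATen/ATen.h>
#include <vector>

std::vector<at::Tensor> gradOfLowered(const std::vector<at::Tensor>& inputs) {
  bool any_defined = false; // what prim::AnyDefined computes over the inputs
  for (const auto& t : inputs)
    any_defined = any_defined || t.defined();
  if (any_defined) {
    return inputs; // stand-in for the cloned original computation
  }
  // else-branch: undefined tensors (sized to the inputs only for illustration)
  return std::vector<at::Tensor>(inputs.size(), at::Tensor());
}
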
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// This pass removes 'grad_of' nodes, replacing them with conditionals of
// the form:
// outputs = undefineds
TORCH_API void LowerGradOf(Graph& g);
-}}
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/utils/functional.h>
-#include <torch/csrc/jit/assertions.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
// this is to assert we are only doing modifications when we know
// we can flatten tuples
std::unordered_set<Symbol> white_list = {
- prim::If,
- prim::Loop,
- prim::TupleUnpack,
- prim::TupleConstruct,
- prim::TupleIndex,
- prim::TupleSlice,
- prim::Param,
- prim::Return,
+ prim::If,
+ prim::Loop,
+ prim::TupleUnpack,
+ prim::TupleConstruct,
+ prim::TupleIndex,
+ prim::TupleSlice,
+ prim::Param,
+ prim::Return,
};
-void removeTupleNodes(Node *n, bool must_remove_tuples) {
- if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex
- && n->kind() != prim::TupleSlice) {
+void removeTupleNodes(Node* n, bool must_remove_tuples) {
+ if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
+ n->kind() != prim::TupleSlice) {
return;
}
auto construct = n->input()->node();
return;
}
if (n->kind() == prim::TupleUnpack) {
- for(size_t i = 0; i < n->outputs().size(); ++i) {
+ for (size_t i = 0; i < n->outputs().size(); ++i) {
n->outputs()[i]->replaceAllUsesWith(construct->inputs().at(i));
}
} else if (n->kind() == prim::TupleIndex) {
}
}
-} //anonymous namespace
+} // anonymous namespace
static void LowerAllTuples(Block* block);
static void VisitNode(Node* n, Node* insert_point) {
- auto & graph = *n->owningGraph();
+ auto& graph = *n->owningGraph();
// tuple construction operators will become dead when the unpacks are replaced
- if(n->kind() == prim::TupleConstruct) {
+ if (n->kind() == prim::TupleConstruct) {
return;
}
- // note: changing the second argument to false changes this pass from a complete lowering
- // pass to one that removes tuples when possible. When tuples are first-class
- // in the interpreter, we should still run this pass to remove extraneous uses
+ // note: changing the second argument to false changes this pass from a
+ // complete lowering pass to one that removes tuples when possible. When
+ // tuples are first-class in the interpreter, we should still run this pass to
+ // remove extraneous uses
- if(n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
+ if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
n->kind() == prim::TupleSlice) {
- removeTupleNodes(n, /*must_remove_tuples*/true);
- return;
+ removeTupleNodes(n, /*must_remove_tuples*/ true);
+ return;
}
// flatten the input list op(a, tup, b) --> op(a, t0, t1, b)
- for(size_t i = 0; i < n->inputs().size();) {
+ for (size_t i = 0; i < n->inputs().size();) {
auto input = n->inputs()[i];
- if(TupleTypePtr tt = input->type()->cast<TupleType>()) {
- JIT_ASSERTM(white_list.count(n->kind()) > 0, "tuple appears in op that does not forward tuples");
- JIT_ASSERTM(input->node()->kind() == prim::TupleConstruct, "tuple use not matched to tuple construct");
- for(size_t j = 0; j < tt->elements().size(); ++j) {
+ if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
+ JIT_ASSERTM(
+ white_list.count(n->kind()) > 0,
+ "tuple appears in op that does not forward tuples");
+ JIT_ASSERTM(
+ input->node()->kind() == prim::TupleConstruct,
+ "tuple use not matched to tuple construct");
+ for (size_t j = 0; j < tt->elements().size(); ++j) {
n->insertInput(i + 1 + j, input->node()->inputs().at(j));
}
n->removeInput(i);
++i;
}
}
- for(auto b : n->blocks()) {
+ for (auto b : n->blocks()) {
LowerAllTuples(b);
}
// flatten the outputs list
- for(size_t i = 0; i < n->outputs().size();) {
- Value * output = n->outputs()[i];
+ for (size_t i = 0; i < n->outputs().size();) {
+ Value* output = n->outputs()[i];
// (a, b, tup, c) -> (a, b, t0, t1, c)
// and:
// tup = (t0, t1)
// is placed at the current insertion point
- if(TupleTypePtr tt = output->type()->cast<TupleType>()) {
- JIT_ASSERTM(white_list.count(n->kind()) > 0, "tuple appears in op that does not forward tuples");
- for(size_t j = 0; j < tt->elements().size(); j++) {
+ if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
+ JIT_ASSERTM(
+ white_list.count(n->kind()) > 0,
+ "tuple appears in op that does not forward tuples");
+ for (size_t j = 0; j < tt->elements().size(); j++) {
n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
}
- auto new_tup = graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
+ auto new_tup =
+ graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
new_tup->insertBefore(insert_point);
insert_point = new_tup;
output->replaceAllUsesWith(new_tup->output());
// _outputs_ of normal instructions, since the param_node represents the
// parameters as outputs, we can handle it by simply visiting the node
VisitNode(block->param_node(), *block->nodes().begin());
- for(auto it = block->nodes().begin(), end = block->nodes().end(); it != end;) {
+ for (auto it = block->nodes().begin(), end = block->nodes().end();
+ it != end;) {
auto n = *it++;
VisitNode(n, *it);
}
VisitNode(block->return_node(), nullptr);
}
-
static void EnsureNoTuples(ArrayRef<Value*> values) {
- for (Value * v : values) {
- JIT_ASSERTM(v->type()->kind() != TypeKind::TupleType,
- "Couldn't lower all tuples.");
+ for (Value* v : values) {
+ JIT_ASSERTM(
+ v->type()->kind() != TypeKind::TupleType, "Couldn't lower all tuples.");
}
}
}
void LowerSimpleTuples(Block* block) {
- for(auto n : block->nodes()) {
- removeTupleNodes(n, /*must_remove_tuples*/false);
- for(auto b : n->blocks()) {
+ for (auto n : block->nodes()) {
+ removeTupleNodes(n, /*must_remove_tuples*/ false);
+ for (auto b : n->blocks()) {
LowerSimpleTuples(b);
}
}
EliminateDeadCode(graph);
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// removes tuples where TupleConstruct and TupleUnpack are matched
// but leaves tuples in place across if statements, loops, and as inputs/outputs
TORCH_API void LowerSimpleTuples(Block* block);
-}}
+} // namespace jit
+} // namespace torch
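
The observation that makes the matched-pair removal safe, checked in plain C++: unpacking a tuple built from (a, b) yields a and b unchanged, so each TupleUnpack output can be rewired to the corresponding TupleConstruct input and the pair dropped.

#include <cassert>
#include <tuple>

int main() {
  int a = 1, b = 2;
  auto tup = std::make_tuple(a, b); // plays the role of prim::TupleConstruct
  auto [x, y] = tup;                // plays the role of prim::TupleUnpack
  assert(x == a && y == b);         // outputs equal the construct's inputs
  return 0;
}
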
-#include <torch/csrc/utils/pybind.h>
-#include <torch/csrc/jit/passes/onnx.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/utils/functional.h>
-#include <unordered_map>
+#include <torch/csrc/utils/pybind.h>
#include <sstream>
+#include <unordered_map>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// Transform PythonOps into Nodes that match ONNX semantics.
-std::shared_ptr<Graph> ToONNX(std::shared_ptr<Graph>& graph, ::torch::onnx::OperatorExportTypes operator_export_type) {
+std::shared_ptr<Graph> ToONNX(
+ std::shared_ptr<Graph>& graph,
+ ::torch::onnx::OperatorExportTypes operator_export_type) {
auto new_graph = std::make_shared<Graph>(graph->current_scope());
std::unordered_map<Value*, Value*> env;
BlockToONNX(graph->block(), new_graph->block(), operator_export_type, env);
return new_graph;
}
-void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExportTypes operator_export_type, std::unordered_map<Value*, Value*> env) {
+void BlockToONNX(
+ Block* old_block,
+ Block* new_block,
+ ::torch::onnx::OperatorExportTypes operator_export_type,
+ std::unordered_map<Value*, Value*> env) {
torch::autograd::SymbolicContext ctx{};
ctx.block = new_block;
py::object onnx_symbolic = py::module::import("torch.onnx.symbolic");
// Returns a node that n maps to in the new graph
- auto envFn = [&env](Value * n) -> Value* {
+ auto envFn = [&env](Value* n) -> Value* {
auto it = env.find(n);
JIT_ASSERTM(it != env.end(), "Dangling node reference");
JIT_ASSERTM(it->second, "Unused node was subsequently used");
// Put the new outputs in our environment map, and copy the type from the
// input graph if they were not set by the symbolic. This is called only
// with results of symbolic call (not for nodes that are just cloned).
- auto setOutputs = [&](const std::string& op_name, Node * node, const value_list & outputs) {
+ auto setOutputs = [&](const std::string& op_name,
+ Node* node,
+ const value_list& outputs) {
auto old_outputs = node->outputs();
// Count all outputs, excluding Handles
auto num_old_outputs = old_outputs.size();
if (outputs.size() != num_old_outputs) {
std::ostringstream ss;
- ss << "symbolic for " << op_name << " produced an incorrect number of outputs (expected ";
+ ss << "symbolic for " << op_name
+ << " produced an incorrect number of outputs (expected ";
ss << num_old_outputs << ", but got " << outputs.size() << ")";
throw std::runtime_error(ss.str());
}
env[old] = nullptr;
if (!old->uses().empty()) {
std::ostringstream ss;
- ss << "symbolic for " << op_name << " returned None for the output " << i;
+ ss << "symbolic for " << op_name << " returned None for the output "
+ << i;
ss << " (indicating conversion for that particular output is not supported), ";
ss << "but the network uses this output later";
// TODO: Say what actually used it
};
// Clone the node and add it to the new graph
- auto cloneNode = [&](Node * node) {
- auto n_ = ctx.block->appendNode(ctx.block->owningGraph()->createClone(node, envFn));
- for(size_t i = 0; i < node->outputs().size(); i++) {
+ auto cloneNode = [&](Node* node) {
+ auto n_ = ctx.block->appendNode(
+ ctx.block->owningGraph()->createClone(node, envFn));
+ for (size_t i = 0; i < node->outputs().size(); i++) {
// n_->outputs()[i]->setType(node->outputs()[i]->type());
env[node->outputs()[i]] = n_->outputs()[i];
}
};
// Cast output of symbolic() python implementation
- auto processSymbolicOutput = [&](const std::string& op_name, Node* n, const py::object& raw_output) {
+ auto processSymbolicOutput = [&](const std::string& op_name,
+ Node* n,
+ const py::object& raw_output) {
if (raw_output.ptr() == Py_None) {
cloneNode(n);
return;
py::tuple py_inputs(n->inputs().size());
Py_ssize_t input_nr = 0;
for (auto* input : n->inputs()) {
- py_inputs[input_nr++] = py::cast(envFn(input));
+ py_inputs[input_nr++] = py::cast(envFn(input));
}
WithInsertPoint insert_point_guard(ctx.block);
WithCurrentScope scope_guard(*ctx.block->owningGraph(), n->scope());
- py::object raw_output = onnx.attr("_run_symbolic_function")(ctx.block->owningGraph(), n, py_inputs, env, operator_export_type);
+ py::object raw_output = onnx.attr("_run_symbolic_function")(
+ ctx.block->owningGraph(), n, py_inputs, env, operator_export_type);
// TODO: Assert it's an ATen identifier???
// (Sometimes it's not...)
};
auto callPySymbolicMethod = [&](PythonOp* op) {
-
// Test if there is a symbolic function; bail if there is not
auto pyobj = py::handle(op->pyobj.get());
auto func = op->autogradFunction();
- if(func) {
+ if (func) {
pyobj = func->get();
}
- if(!py::hasattr(pyobj, "symbolic")) {
+ if (!py::hasattr(pyobj, "symbolic")) {
cloneNode(op);
return;
}
for (auto arg_type : op->cconv) {
py::object obj;
if (arg_type == 'c') {
- JIT_ASSERTM(scalar_it != op->scalar_args.end(), "expected too many scalar args");
- obj = py::reinterpret_borrow<py::object>(py::handle((scalar_it++)->get()));
+ JIT_ASSERTM(
+ scalar_it != op->scalar_args.end(),
+ "expected too many scalar args");
+ obj = py::reinterpret_borrow<py::object>(
+ py::handle((scalar_it++)->get()));
} else if (arg_type == 'd') {
JIT_ASSERTM(node_it != inputs.end(), "expected too many inputs");
obj = py::cast(envFn(*node_it++));
// Call the symbolic function
// Use a little trampoline function so we can give good error messages
// upon argument mismatch
- py::object raw_output = onnx.attr("_run_symbolic_method")(op->name(), pyobj.attr("symbolic"), py_symbolic_args);
+ py::object raw_output = onnx.attr("_run_symbolic_method")(
+ op->name(), pyobj.attr("symbolic"), py_symbolic_args);
processSymbolicOutput(op->name(), op, raw_output);
};
// Finally, visit all nodes in the graph
for (auto node : old_block->nodes()) {
IR_IFM(node, PythonOp)
- callPySymbolicMethod(value);
+ callPySymbolicMethod(value);
IR_ELSE()
- callPySymbolicFunction(node);
+ callPySymbolicFunction(node);
IR_END()
}
for (auto output : old_block->outputs()) {
EliminateDeadCode(ctx.block);
}
-}}
+} // namespace jit
+} // namespace torch
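
A minimal standalone version of the env/envFn pattern used throughout BlockToONNX: every old-graph value must already map to a non-null new-graph value before it is consumed. `lookupOrFail` is a hypothetical name; the assertion messages mirror the JIT_ASSERTMs above.

#include <cassert>
#include <unordered_map>

template <typename V>
V* lookupOrFail(const std::unordered_map<V*, V*>& env, V* old_value) {
  auto it = env.find(old_value);
  assert(it != env.end() && "Dangling node reference");
  assert(it->second != nullptr && "Unused node was subsequently used");
  return it->second;
}

int main() {
  int a = 0, b = 0;
  std::unordered_map<int*, int*> env{{&a, &b}};
  assert(lookupOrFail(env, &a) == &b);
  return 0;
}
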
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/onnx/onnx.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-TORCH_API std::shared_ptr<Graph> ToONNX(std::shared_ptr<Graph>& state, ::torch::onnx::OperatorExportTypes operator_export_type);
-TORCH_API void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExportTypes operator_export_type, std::unordered_map<Value*, Value*> env);
+TORCH_API std::shared_ptr<Graph> ToONNX(
+ std::shared_ptr<Graph>& state,
+ ::torch::onnx::OperatorExportTypes operator_export_type);
+TORCH_API void BlockToONNX(
+ Block* old_block,
+ Block* new_block,
+ ::torch::onnx::OperatorExportTypes operator_export_type,
+ std::unordered_map<Value*, Value*> env);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-void FixupONNXLoops(Block *block) {
- for (auto *node : block->nodes()) {
+void FixupONNXLoops(Block* block) {
+ for (auto* node : block->nodes()) {
if (node->kind() == torch::jit::onnx::Loop) {
JIT_ASSERT(node->blocks().size() == 1);
- auto *sub_block = node->blocks()[0];
+ auto* sub_block = node->blocks()[0];
sub_block->insertInput(1, "cond");
}
- for (Block * block : node->blocks()) {
+ for (Block* block : node->blocks()) {
FixupONNXLoops(block);
}
}
FixupONNXLoops(graph->block());
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
void FixupONNXLoops(std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <c10/util/Optional.h>
typedef SSIZE_T ssize_t;
#endif
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-bool isRNN(const Node *node) {
+bool isRNN(const Node* node) {
auto k = node->kind();
return k == onnx::RNN || k == onnx::LSTM || k == onnx::GRU;
}
-bool isNopTranspose(const std::vector<int64_t> & perm) {
+bool isNopTranspose(const std::vector<int64_t>& perm) {
for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++)
if (perm[i] != i)
return false;
// iteration would have folded all the transposes up to that point. Thus,
// `ret[i] = t1[t2[i]]` says "the output of t2 at position i takes the value of
// the input tensor index contained in t1 at position `t2[i]``".
-std::vector<int64_t> composeTransposes(const std::vector<int64_t> & t1,
- const std::vector<int64_t> & t2) {
+std::vector<int64_t> composeTransposes(
+ const std::vector<int64_t>& t1,
+ const std::vector<int64_t>& t2) {
JIT_ASSERT(t1.size() == t2.size());
std::vector<int64_t> ret;
ret.reserve(t1.size());
return to.size() - from.size();
}
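
A standalone check of the composition rule quoted above, ret[i] = t1[t2[i]]: applying permutation t1 and then t2 to a tensor's dims equals applying the composed permutation once. `compose` and `apply` are hypothetical helpers mirroring composeTransposes.

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> compose(
    const std::vector<int64_t>& t1,
    const std::vector<int64_t>& t2) {
  std::vector<int64_t> ret;
  ret.reserve(t2.size());
  for (int64_t i : t2)
    ret.push_back(t1[i]); // ret[k] = t1[t2[k]]
  return ret;
}

std::vector<int64_t> apply(
    const std::vector<int64_t>& dims,
    const std::vector<int64_t>& perm) {
  std::vector<int64_t> out;
  for (int64_t p : perm)
    out.push_back(dims[p]); // output dim i takes input dim perm[i]
  return out;
}

int main() {
  std::vector<int64_t> dims{10, 20, 30};
  std::vector<int64_t> t1{2, 0, 1}, t2{1, 2, 0};
  assert(apply(apply(dims, t1), t2) == apply(dims, compose(t1, t2)));
  return 0;
}
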
-void fuseBroadcast(Block *b) {
- for(auto n : b->nodes()) {
- for (auto *child_block : n->blocks()) {
+void fuseBroadcast(Block* b) {
+ for (auto n : b->nodes()) {
+ for (auto* child_block : n->blocks()) {
fuseBroadcast(child_block);
}
}
}
-void fuseConsecutiveTransposes(Block *b) {
- for(auto n : b->nodes()) {
- for (auto *child_block : n->blocks()) {
+void fuseConsecutiveTransposes(Block* b) {
+ for (auto n : b->nodes()) {
+ for (auto* child_block : n->blocks()) {
fuseConsecutiveTransposes(child_block);
}
- if (n->kind() == onnx::Transpose && n->input()->node()->kind() == onnx::Transpose) {
+ if (n->kind() == onnx::Transpose &&
+ n->input()->node()->kind() == onnx::Transpose) {
auto origInput = n->input();
- n->is_(attr::perm, composeTransposes(origInput->node()->is(attr::perm), n->is(attr::perm)));
+ n->is_(
+ attr::perm,
+ composeTransposes(
+ origInput->node()->is(attr::perm), n->is(attr::perm)));
n->replaceInput(0, origInput->node()->input());
if (origInput->uses().size() == 0) {
origInput->node()->destroy();
}
}
-void eliminateNopTranspose(Block *b) {
- for(auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+void eliminateNopTranspose(Block* b) {
+ for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
auto n = *it;
- for (auto *child_block : n->blocks()) {
+ for (auto* child_block : n->blocks()) {
eliminateNopTranspose(child_block);
}
if (n->kind() == onnx::Transpose) {
}
}
-void fuseTransposeIntoGemm(Block *b) {
- static const std::vector<int64_t> simpleTransPerm({1,0});
+void fuseTransposeIntoGemm(Block* b) {
+ static const std::vector<int64_t> simpleTransPerm({1, 0});
- for(auto n : b->nodes()) {
- for (auto *child_block : n->blocks()) {
+ for (auto n : b->nodes()) {
+ for (auto* child_block : n->blocks()) {
fuseTransposeIntoGemm(child_block);
}
if (n->kind() == onnx::Gemm) {
- for (size_t i : {0,1}) {
+ for (size_t i : {0, 1}) {
auto inp = n->inputs()[i];
auto trans = i == 0 ? attr::transA : attr::transB;
- if (inp->node()->kind() == onnx::Transpose && inp->node()->is(attr::perm) == simpleTransPerm) {
+ if (inp->node()->kind() == onnx::Transpose &&
+ inp->node()->is(attr::perm) == simpleTransPerm) {
n->replaceInput(i, inp->node()->input());
n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
if (inp->uses().size() == 0) {
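
A 2x2 numeric check of the fusion above: Gemm over a pre-transposed A with transA unset matches Gemm over the original A with transA set, so the {1, 0} Transpose can be dropped and the flag flipped (two absorbed transposes cancel via the !n->i(trans) toggle). This is an illustration, not the operator itself.

#include <array>
#include <cassert>

using Mat2 = std::array<std::array<double, 2>, 2>;

Mat2 matmul(const Mat2& a, const Mat2& b, bool transA) {
  Mat2 out{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k)
        out[i][j] += (transA ? a[k][i] : a[i][k]) * b[k][j];
  return out;
}

Mat2 transpose(const Mat2& a) {
  return {{{a[0][0], a[1][0]}, {a[0][1], a[1][1]}}};
}

int main() {
  Mat2 a{{{1, 2}, {3, 4}}}, b{{{5, 6}, {7, 8}}};
  assert(
      matmul(transpose(a), b, /*transA=*/false) ==
      matmul(a, b, /*transA=*/true));
  return 0;
}
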
// entirely by pairing them with their inverse PadPacked. If the
// input graph does not pair the operations, export will fail.
-void pushPackingPastRnn(Block *b) {
+void pushPackingPastRnn(Block* b) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
auto* n = *it;
- for (auto *child_block : n->blocks()) {
+ for (auto* child_block : n->blocks()) {
pushPackingPastRnn(child_block);
}
// For now, only handle the case where there is one consumer.
continue;
}
- Node * rnn = n->outputs()[0]->uses()[0].user;
+ Node* rnn = n->outputs()[0]->uses()[0].user;
if (!isRNN(rnn)) {
continue;
}
- if(rnn->owningBlock() != n->owningBlock())
+ if (rnn->owningBlock() != n->owningBlock())
continue;
- // Packing only has an effect on a network when its outputs are actually used,
- // so we can remove it here.
- if (rnn->outputs().at(0)->uses().empty() && n->outputs().at(1)->uses().size() == 1) {
+ // Packing only has an effect on a network when its outputs are actually
+ // used, so we can remove it here.
+ if (rnn->outputs().at(0)->uses().empty() &&
+ n->outputs().at(1)->uses().size() == 1) {
n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));
n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
it.destroyCurrent();
// The rnn is followed by a transpose and a reshape (if
// bidirectional), or by a squeeze (if unidirectional).
- Node * next = rnn->outputs().at(0)->uses().at(0).user;
+ Node* next = rnn->outputs().at(0)->uses().at(0).user;
if (next->kind() == onnx::Transpose) {
next = next->outputs().at(0)->uses().at(0).user;
if (next->kind() != onnx::Reshape) {
n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
// and insert new PackPadded after the RNN
- Node * newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
+ Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
newPackPadded->insertAfter(next);
// make things consume from the new PackPadded
// unhygenic way, Pytorch ends up propagating an incorrect type.
// Until a long-term cleanup comes around, we can fix this by
// resetting the size to the correct value.
- CompleteTensorTypePtr oldType = rnn->inputs().at(0)->type()->cast<CompleteTensorType>();
+ CompleteTensorTypePtr oldType =
+ rnn->inputs().at(0)->type()->cast<CompleteTensorType>();
if (oldType) {
std::vector<int64_t> new_sizes;
new_sizes.push_back(oldType->sizes().at(0));
void removeNopPacking(Block* graph) {
for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
auto* n = *it;
- for (auto *child_block : n->blocks()) {
+ for (auto* child_block : n->blocks()) {
removeNopPacking(child_block);
}
// of its input.
for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
auto* n = *it;
- for (auto *child_block : n->blocks()) {
+ for (auto* child_block : n->blocks()) {
removeNopPacking(child_block);
}
}
}
-void fixDefaultRNNState(Graph* graph, Node * n, int input_index) {
+void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
auto initial_state = n->inputs()[input_index];
// The RNN code in pytorch accepts an optional hidden state. When it
// with something that doesn't fix the batch size. Note that for
// multi-layer RNNs there will be a Slice operation between the
// Constant and the RNN.
- bool needsFixing =
- initial_state->node()->kind() == onnx::Constant ||
- (initial_state->node()->kind() == onnx::Slice &&
- initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);
+ bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
+ (initial_state->node()->kind() == onnx::Slice &&
+ initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);
if (!needsFixing) {
return;
}
- Node * shape_of_input = graph->create(onnx::Shape, 1);
+ Node* shape_of_input = graph->create(onnx::Shape, 1);
shape_of_input->insertBefore(n);
shape_of_input->addInput(n->inputs()[0]);
- Node * gather_indices = graph->create(onnx::Constant, 1);
+ Node* gather_indices = graph->create(onnx::Constant, 1);
gather_indices->insertBefore(n);
gather_indices->t_(attr::value, at::scalar_to_tensor(at::Scalar(1)));
- Node * batch_size = graph->create(onnx::Gather, 1);
+ Node* batch_size = graph->create(onnx::Gather, 1);
batch_size->insertBefore(n);
batch_size->addInput(shape_of_input->outputs()[0]);
batch_size->addInput(gather_indices->outputs()[0]);
- Node * unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1);
+ Node* unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1);
unsqueezed_batch_size->insertBefore(n);
unsqueezed_batch_size->addInput(batch_size->outputs()[0]);
unsqueezed_batch_size->is_(attr::axes, {0});
- Node * hidden_size = graph->create(onnx::Constant, 1);
+ Node* hidden_size = graph->create(onnx::Constant, 1);
hidden_size->insertBefore(n);
- hidden_size->t_(attr::value, at::full({1}, n->i(attr::hidden_size), at::kLong)); // at::Scalar(n->i(attr::hidden_size)).toTensor());
-
- Node * num_directions = graph->create(onnx::Constant, 1);
+ hidden_size->t_(
+ attr::value,
+ at::full(
+ {1},
+ n->i(attr::hidden_size),
+ at::kLong)); // at::Scalar(n->i(attr::hidden_size)).toTensor());
+
+ Node* num_directions = graph->create(onnx::Constant, 1);
num_directions->insertBefore(n);
- num_directions->t_(attr::value, scalar_to_tensor(at::Scalar(n->hasAttribute(attr::direction) && n->s(attr::direction) == "bidirectional" ? 2 : 1)));
-
- Node * unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1);
+ num_directions->t_(
+ attr::value,
+ scalar_to_tensor(at::Scalar(
+ n->hasAttribute(attr::direction) &&
+ n->s(attr::direction) == "bidirectional"
+ ? 2
+ : 1)));
+
+ Node* unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1);
unsqueezed_num_directions->insertBefore(n);
unsqueezed_num_directions->addInput(num_directions->outputs()[0]);
unsqueezed_num_directions->is_(attr::axes, {0});
- Node * concated_dims = graph->create(onnx::Concat, 1);
+ Node* concated_dims = graph->create(onnx::Concat, 1);
concated_dims->insertBefore(n);
concated_dims->i_(attr::axis, 0);
concated_dims->addInput(unsqueezed_num_directions->outputs()[0]);
concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
concated_dims->addInput(hidden_size->outputs()[0]);
- Node * constant_fill = graph->create(onnx::ConstantFill, 1);
+ Node* constant_fill = graph->create(onnx::ConstantFill, 1);
constant_fill->insertBefore(n);
constant_fill->i_(attr::input_as_shape, 1);
constant_fill->addInput(concated_dims->outputs()[0]);
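
What the inserted Shape/Gather/Unsqueeze/Concat chain computes, written directly: the replacement initial state is a constant-filled tensor of shape [num_directions, batch_size, hidden_size], with batch_size read dynamically from dim 1 of the input. `defaultStateShape` is a hypothetical helper for illustration.

#include <cstdint>
#include <vector>

std::vector<int64_t> defaultStateShape(
    const std::vector<int64_t>& input_sizes, // what onnx::Shape produces
    int64_t hidden_size,
    bool bidirectional) {
  int64_t num_directions = bidirectional ? 2 : 1;
  int64_t batch_size = input_sizes.at(1); // Gather with indices = 1
  // the Concat(axis = 0) of the three unsqueezed scalars, in this order:
  return {num_directions, batch_size, hidden_size};
}
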
void fixDefaultRnnHiddenState(Block* b) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
auto* n = *it;
- for (auto *child_block : n->blocks()) {
+ for (auto* child_block : n->blocks()) {
fixDefaultRnnHiddenState(child_block);
}
}
}
-void fixDefaultLstmCellState(Block *b) {
+void fixDefaultLstmCellState(Block* b) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
auto* n = *it;
- for (auto *child_block : n->blocks()) {
+ for (auto* child_block : n->blocks()) {
fixDefaultLstmCellState(child_block);
}
}
static void speculateOps(Block* block) {
- for(auto it = block->nodes().begin(), end = block->nodes().end();
- it != end;) {
- Node * n = *it;
- ++it; //note: increment first so that it is safe to move the node if needed
+ for (auto it = block->nodes().begin(), end = block->nodes().end();
+ it != end;) {
+ Node* n = *it;
+ ++it; // note: increment first so that it is safe to move the node if needed
- for(auto b : n->blocks()) {
+ for (auto b : n->blocks()) {
speculateOps(b);
}
- if(!isSafeToSpeculate(n))
+ if (!isSafeToSpeculate(n))
continue;
// XXX - only works for nodes with a single input
// move node n outside of the control flow it is nested in
auto node_input = n->input()->node();
- if(node_input->owningBlock() == n->owningBlock())
+ if (node_input->owningBlock() == n->owningBlock())
continue;
// find the control flow node in the same block as node_input that contains
// Node n
auto control_flow_node = n->owningBlock()->owningNode();
- while(control_flow_node->owningBlock() != node_input->owningBlock())
+ while (control_flow_node->owningBlock() != node_input->owningBlock())
control_flow_node = control_flow_node->owningBlock()->owningNode();
// put the node right before this flow node
n->moveBefore(control_flow_node);
}
}
-static void replaceInputWithList(Node *node, size_t i, ArrayRef<Value*> to) {
+static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
node->removeInput(i);
for (auto* to_val : to) {
JIT_ASSERT(to_val->owningGraph() == node->owningGraph());
for (auto* input : n->inputs()) {
if (input->node()->kind() == prim::ListConstruct) {
auto* lc_node = input->node();
- TypePtr elem = lc_node->output()->type()->cast<ListType>()->getElementType();
+ TypePtr elem =
+ lc_node->output()->type()->cast<ListType>()->getElementType();
if (elem->cast<IntType>()) {
- // ListConstruct Int[] output case, we need to transfrom to ONNX Concat to ensure
- // the output is a single tensor(dynamic) type in order to be consumed as inputs
+ // ListConstruct Int[] output case, we need to transform to ONNX
+ // Concat to ensure the output is a single tensor (dynamic) type in
+ // order to be consumed as inputs
std::vector<Value*> unsqueezed;
- Graph *g = block->owningGraph();
- for (auto* input: lc_node->inputs()) {
+ Graph* g = block->owningGraph();
+ for (auto* input : lc_node->inputs()) {
Node* unsqueezed_node = g->create(onnx::Unsqueeze, 1);
unsqueezed_node->insertBefore(lc_node);
unsqueezed_node->addInput(input);
}
Node* concat_node = g->create(onnx::Concat, 1);
concat_node->i_(attr::axis, 0);
- for(auto v: unsqueezed) {
+ for (auto v : unsqueezed) {
concat_node->addInput(v);
}
concat_node->insertBefore(lc_node);
- // make concat node output as new input, then ListConstruct should become dead
- replacements.emplace_back(i, std::vector<Value*>({concat_node->output()}));
+ // use the concat node's output as the new input; the ListConstruct
+ // should then become dead
+ replacements.emplace_back(
+ i, std::vector<Value*>({concat_node->output()}));
} else {
- // Tensor lists are used mostly for inputs to cat/stack. They are already handled
- // in those symbolics, and should become dead afterwards.
+ // Tensor lists are used mostly for inputs to cat/stack. They are
+ // already handled in those symbolics, and should become dead
+ // afterwards.
replacements.emplace_back(
i,
std::vector<Value*>(
lc_node->inputs().begin(), lc_node->inputs().end()));
}
-
}
i++;
}
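
A hedged ATen sketch of what the Unsqueeze-plus-Concat rewrite produces for an Int[] ListConstruct: each integer becomes a one-element tensor and the axis-0 concat yields a single dynamic 1-D tensor. `intsToTensor` is a hypothetical name, and a non-empty list is assumed.

#include <ATen/ATen.h>
#include <cstdint>
#include <vector>

at::Tensor intsToTensor(const std::vector<int64_t>& xs) {
  std::vector<at::Tensor> parts; // assumes xs is non-empty
  for (int64_t x : xs)           // one onnx::Unsqueeze per list element
    parts.push_back(at::scalar_to_tensor(at::Scalar(x)).unsqueeze(0));
  return at::cat(parts, 0);      // the single onnx::Concat on axis 0
}
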
//
// At the moment, here are the optimizations it does:
// - This optimization fuses expand calls into ONNX operators, because it is
-// easier for non-strided backends to more efficiently do broadcasts if this is
-// local information. This optimization is not useful for PyTorch as 'expand'
-// is free.
+// easier for non-strided backends to more efficiently do broadcasts if this
+// is local information. This optimization is not useful for PyTorch as
+// 'expand' is free.
// - Fusing of consecutive transposes
// - Elimination of NOP transposes
// - Fusing of transposes into Gemm
eraseListConstruct(graph->block());
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
static void PrepareDivisionForONNXOnBlock(Block* block) {
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
if (it->matches("aten::div(int a, int b) -> float")) {
// Cast to Float before dividing
- std::vector<Value*> floattensor_inputs = fmap(it->inputs(), [&](Value* input) {
- auto* longtensor = subgraph->insertNode(subgraph->createNumToTensor(input))->output();
- auto* nonblocking = subgraph->insertConstant(0);
- auto* cast = subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
- return subgraph->insertNode(cast)->output();
- });
+ std::vector<Value*> floattensor_inputs =
+ fmap(it->inputs(), [&](Value* input) {
+ auto* longtensor =
+ subgraph->insertNode(subgraph->createNumToTensor(input))
+ ->output();
+ auto* nonblocking = subgraph->insertConstant(0);
+ auto* cast =
+ subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
+ return subgraph->insertNode(cast)->output();
+ });
it->replaceInput(0, floattensor_inputs[0]);
it->replaceInput(1, floattensor_inputs[1]);
- it->output()->setType(CompleteTensorType::fromNumberType(FloatType::get()));
+ it->output()->setType(
+ CompleteTensorType::fromNumberType(FloatType::get()));
}
}
}
PrepareDivisionForONNXOnBlock(graph->block());
}
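
The rewrite above in plain C++ terms: true division of two ints is obtained by casting both operands to floating point first, which is exactly what the inserted NumToTensor and aten::_cast_Float nodes arrange before the div.

#include <cassert>
#include <cstdint>

double trueDivide(int64_t a, int64_t b) {
  // cast both operands first, mirroring the two _cast_Float insertions
  return static_cast<double>(a) / static_cast<double>(b);
}

int main() {
  assert(trueDivide(1, 2) == 0.5); // plain integer division would give 0
  return 0;
}
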
-}}
-
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// Prepare division ops for ONNX export. This is necessary for and only used
// by ONNX export.
//
TORCH_API void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/dead_code_elimination.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// The intent for this optimization pass is to catch all of the small, easy to
// catch peephole optimizations you might be interested in doing.
//
// The parameter `addmm_fusion_enabled` exists because, as it is today, fusing
// add + mm has no benefit within PyTorch running ATen ops. However, we rely on
-// seeing the fused version of addmm for ONNX export, since after ONNX translation
-// we would see redundant Gemm ops with sub-optimal inputs. This flag is exposed
-// so that ONNX export can pass `true` to get the fused behavior, but normal
-// JIT peephole optimization is left alone.
-void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) {
+// seeing the fused version of addmm for ONNX export, since after ONNX
+// translation we would see redundant Gemm ops with sub-optimal inputs. This
+// flag is exposed so that ONNX export can pass `true` to get the fused
+// behavior, but normal JIT peephole optimization is left alone.
+void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
auto* node = *it;
- for (Block * sub_block : node->blocks()) {
+ for (Block* sub_block : node->blocks()) {
PeepholeOptimizeImpl(sub_block, addmm_fusion_enabled);
}
- // XXX: remember that if you want to simplify an expression by combining multiple nodes
- // into a different one, then you need to check that they all belong to the given block
- if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
- /*const_inputs=*/attr::size)) {
+ // XXX: remember that if you want to simplify an expression by combining
+ // multiple nodes into a different one, then you need to check that they all
+ // belong to the given block
+ if (node->matches(
+ "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
+ /*const_inputs=*/attr::size)) {
// x.expand(x.size()) == x
- if (auto input_type = node->namedInput(attr::self)->type()->cast<CompleteTensorType>()) {
+ if (auto input_type = node->namedInput(attr::self)
+ ->type()
+ ->cast<CompleteTensorType>()) {
auto expanded_sizes = node->get<std::vector<int64_t>>(attr::size);
if (expanded_sizes == input_type->sizes()) {
node->output()->replaceAllUsesWith(node->namedInput(attr::self));
}
} else if (node->matches("aten::t(Tensor self) -> Tensor")) {
// x.t().t() == x
- Node *input_node = node->input()->node();
+ Node* input_node = node->input()->node();
if (input_node->matches("aten::t(Tensor self) -> Tensor")) {
node->output()->replaceAllUsesWith(input_node->input());
}
- } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
// x.type_as(y) == x iff x.type() == y.type()
auto self_type = node->input(0)->type()->cast<TensorType>();
auto other_type = node->input(1)->type()->cast<TensorType>();
self_type->device() == other_type->device()) {
node->output()->replaceAllUsesWith(node->input(0));
}
- } else if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
- /*const_inputs=*/attr::alpha)) {
+ } else if (
+ node->matches(
+ "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+ /*const_inputs=*/attr::alpha)) {
// z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z
- // This optimization has been disabled at the moment, because it's not helpful at all
- // until we will be able to represent torch.addmm(a, b, c, out=a). That's because addmm
- // dispatches internally to gemm, which computes:
+ // This optimization has been disabled at the moment, because it's not
+ // helpful at all until we are able to represent torch.addmm(a, b, c,
+ // out=a). That's because addmm dispatches internally to gemm, which
+ // computes:
// C = beta * C + alpha * A @ B
// but aten::addmm(a, b, c, 1, 1) is really:
// D = beta * C + alpha * A @ B
- // and because it works out of place on C, we're only trading off an explicit add for
- // a copy inside the addmm function. Note that it doesn't even result in fewer reads,
- // because mm won't even load C (because beta == 0 for it).
- if (addmm_fusion_enabled && node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
+ // and because it works out of place on C, we're only trading off an
+ // explicit add for a copy inside the addmm function. Note that it doesn't
+ // even result in fewer reads, because mm won't even load C (because beta
+ // == 0 for it).
+ if (addmm_fusion_enabled &&
+ node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
// Look for mm from both sides of the add
for (size_t mm_side = 0; mm_side < 2; mm_side++) {
+ // Add will accept tensors of mismatched scalar types, as long as one
+ // of them is a scalar. Addmm will throw in that case, so we can only
+ // perform this fusion if we're sure that it is correct, and for that
+ // we need the add_mat_type. An alternative would be to insert a
+ // type_as conditional on the tensor shape being a scalar, but that
+ // might add overhead, and make analysis harder.
+ auto add_mat_type =
+ node->input(1 - mm_side)->type()->cast<TensorType>();
+ if (!add_mat_type)
+ continue;
- // Add will accept tensors of mismatched scalar types, as long as one of them is a scalar.
- // Addmm will throw in that case, so we can only perform this fusion if we're sure
- // that it is correct, and for that we need the add_mat_type.
- // An alternative would be to insert a type_as conditional on the tensor shape being a
- // scalar, but that might add overhead, and make analysis harder.
- auto add_mat_type = node->input(1 - mm_side)->type()->cast<TensorType>();
- if (!add_mat_type) continue;
-
- if (node->input(mm_side)->node()->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+ if (node->input(mm_side)->node()->matches(
+ "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
WithInsertPoint guard(node);
auto mm_node = node->input(mm_side)->node();
if (!mat_type) {
mat_type = mat2.value()->type()->cast<TensorType>();
}
- // We insert the type_as if we're sure that the added element is a scalar, and we
- // either don't know what is the type of the multiplied matrices, or know the type,
- // and know that it's mismatched.
- if (add_mat_type->dim() == 0 && (!mat_type || add_mat_type->scalarType() != mat_type->scalarType())) {
+ // We insert the type_as if we're sure that the added element is a
+ // scalar, and we either don't know the type of the
+ // multiplied matrices, or know the type, and know that it's
+ // mismatched.
+ if (add_mat_type->dim() == 0 &&
+ (!mat_type ||
+ add_mat_type->scalarType() != mat_type->scalarType())) {
add_mat = add_mat.type_as(mat1);
}
}
}
}
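
A 1x1 sanity check of the identity the fusion relies on: with alpha == beta == 1, z + x.mm(y) equals addmm(z, x, y), i.e. beta * C + alpha * A @ B, which is why the add can be folded once the alpha == 1 check above passes.

#include <cassert>

int main() {
  double x = 2, y = 3, z = 5;         // 1x1 "matrices"
  double mm = x * y;                  // aten::mm(x, y)
  double addmm = 1 * z + 1 * (x * y); // beta * C + alpha * A @ B
  assert(z + mm == addmm);            // z + x.mm(y) == z.addmm(x, y)
  return 0;
}
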
- // TODO: this doesn't work with Scalar-Tensor ops! We should canonicalize those
- } else if (node->matches("aten::mul(Tensor self, Scalar other) -> Tensor", /*const_inputs=*/attr::other) ||
- node->matches("aten::div(Tensor self, Scalar other) -> Tensor", /*const_inputs=*/attr::other)) {
+ // TODO: this doesn't work with Scalar-Tensor ops! We should canonicalize
+ // those
+ } else if (
+ node->matches(
+ "aten::mul(Tensor self, Scalar other) -> Tensor",
+ /*const_inputs=*/attr::other) ||
+ node->matches(
+ "aten::div(Tensor self, Scalar other) -> Tensor",
+ /*const_inputs=*/attr::other)) {
// x * 1 == x / 1 == x
if (node->get<at::Scalar>(attr::other)->toDouble() == 1) {
node->output()->replaceAllUsesWith(node->input(0));
}
- } else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", /*const_inputs=*/{attr::alpha, attr::other}) ||
- node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", /*const_inputs=*/{attr::alpha, attr::other})) {
+ } else if (
+ node->matches(
+ "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+ /*const_inputs=*/{attr::alpha, attr::other}) ||
+ node->matches(
+ "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+ /*const_inputs=*/{attr::alpha, attr::other})) {
// x + 0 == x - 0 == x
if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 &&
node->get<at::Scalar>(attr::other)->toDouble() == 0) {
node->output()->replaceAllUsesWith(node->input(0));
}
- } else if (node->kind() == prim::Float || node->kind() == prim::Int || node->kind() == prim::ImplicitTensorToNum) {
+ } else if (
+ node->kind() == prim::Float || node->kind() == prim::Int ||
+ node->kind() == prim::ImplicitTensorToNum) {
Node* input_node = node->input()->node();
if (input_node->kind() == prim::NumToTensor) {
node->output()->replaceAllUsesWith(input_node->input());
}
- } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+ } else if (
+ node->matches(
+ "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
auto uses = node->output()->uses();
for (Use u : uses) {
- if (u.user->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+ if (u.user->matches(
+ "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
u.user->replaceInput(0, node->inputs().at(0));
}
}
EliminateDeadCode(block);
}
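
The scalar identities driving the branches above, checked directly; these are exactly the conditions the pass tests (other == 1 for mul/div, alpha == 1 and other == 0 for add/sub).

#include <cassert>

int main() {
  double x = 3.5, one = 1, zero = 0, alpha = 1;
  assert(x * one == x && x / one == x);                   // x * 1 == x / 1 == x
  assert(x + zero * alpha == x && x - zero * alpha == x); // x +/- 0 == x
  return 0;
}
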
-void PeepholeOptimize(const std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled) {
+void PeepholeOptimize(
+ const std::shared_ptr<Graph>& graph,
+ bool addmm_fusion_enabled) {
PeepholeOptimize(graph->block(), addmm_fusion_enabled);
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-TORCH_API void PeepholeOptimize(const std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled=false);
-TORCH_API void PeepholeOptimize(Block* block, bool addmm_fusion_enabled=false);
+TORCH_API void PeepholeOptimize(
+ const std::shared_ptr<Graph>& graph,
+ bool addmm_fusion_enabled = false);
+TORCH_API void PeepholeOptimize(
+ Block* block,
+ bool addmm_fusion_enabled = false);
-}}
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/attributes.h>
+#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/generic_if.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/ir_views.h>
-#include <torch/csrc/jit/export.h>
+#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/module.h>
-
namespace torch {
namespace jit {
void printQuotedString(std::ostream& stmt, const std::string& str) {
stmt << "\"";
- for(auto s : str) {
+ for (auto s : str) {
switch (s) {
case '\\':
stmt << "\\\\";
// C++ io has stateful formatting settings. Messing with
// them is probably worse than doing this manually.
char buf[4] = "000";
- buf[2] += s % 8; s /= 8;
- buf[1] += s % 8; s /= 8;
+ buf[2] += s % 8;
+ s /= 8;
+ buf[1] += s % 8;
+ s /= 8;
buf[0] += s;
stmt << "\\" << buf;
}
}
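
The manual octal escape above, isolated and exercised once: a byte is emitted as a backslash plus three octal digits, with the least-significant digit filled first. `octalEscape` is a hypothetical wrapper around the same buffer trick.

#include <cassert>
#include <string>

std::string octalEscape(unsigned char s) {
  char buf[4] = "000";
  buf[2] += s % 8;
  s /= 8;
  buf[1] += s % 8;
  s /= 8;
  buf[0] += s;
  return std::string("\\") + buf;
}

int main() {
  assert(octalEscape('\n') == "\\012"); // byte 10 prints as octal 012
  return 0;
}
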
static bool isValidIdentifierChar(char c, size_t pos) {
- return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
+ return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
}
-static bool isValidIdentifier(const std::string & name) {
+static bool isValidIdentifier(const std::string& name) {
if (name.size() == 0)
return false;
for (size_t i = 0; i < name.size(); ++i) {
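
A quick check of the identifier rule above: digits are accepted anywhere except position 0. `validIdChar` is a local copy of the predicate for illustration.

#include <cassert>
#include <cctype>
#include <cstddef>

static bool validIdChar(char c, size_t pos) {
  return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
}

int main() {
  assert(validIdChar('a', 0) && validIdChar('_', 0));
  assert(!validIdChar('1', 0)); // a digit may not start an identifier
  assert(validIdChar('1', 3));  // but is fine at any later position
  return 0;
}
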
using QualifiedNamePtr = c10::intrusive_ptr<QualifiedName>;
struct QualifiedName : c10::intrusive_ptr_target {
QualifiedName(QualifiedNamePtr prefix, std::string name)
- : prefix_(std::move(prefix)), name_(std::move(name)) {}
+ : prefix_(std::move(prefix)), name_(std::move(name)) {}
QualifiedNamePtr prefix_;
std::string name_;
static QualifiedNamePtr create(QualifiedNamePtr prefix, std::string name) {
- return c10::make_intrusive<QualifiedName>(std::move(prefix), std::move(name));
+ return c10::make_intrusive<QualifiedName>(
+ std::move(prefix), std::move(name));
}
static QualifiedNamePtr create(std::string name) {
- return c10::make_intrusive<QualifiedName>(QualifiedNamePtr(), std::move(name));
+ return c10::make_intrusive<QualifiedName>(
+ QualifiedNamePtr(), std::move(name));
}
std::string str() const {
std::stringstream ss;
emit(ss);
return ss.str();
}
-private:
+
+ private:
void emit(std::ostream& out) const {
if (isValidIdentifier(name_)) {
if (prefix_) {
const script::Module& module,
const QualifiedNamePtr& prefix,
std::unordered_map<at::Tensor*, QualifiedNamePtr>& result) {
-
for (const auto& elem : module.get_parameters()) {
const script::NamedParameter& param = elem.value();
result[param.slot()] = QualifiedName::create(prefix, param.name);
}
}
- // some names are valid identifiers but off limits because
- // they are keywords or namespaces used in the output
- const static std::unordered_set<std::string> reserved_names = {
+// some names are valid identifiers but off limits because
+// they are keywords or namespaces used in the output
+const static std::unordered_set<std::string> reserved_names = {
// identifiers in the environment while parsing
"_", // avoid the confusing unnamed _
"aten",
"while",
"with",
"yield",
- };
+};
struct PythonPrintPass {
std::ostream& out;
// we only do this if
// (1) it is a constant, or
// (2) the temporary is unnamed, is single output, is used once,
- // and would appear in the same order when the expression tree is reparsed.
+ // and would appear in the same order when the expression tree is
+ // reparsed.
// The last case can be checked
// becuase when we emit a expresion tree in the parser,
- // we do a left-to-right postorder traversal of the expression tree (emit children, then emit op).
- // The reverse of this is a right-to-left preorder traversal of the tree.
- // By doing a right-to-left preorder traversal of the inputs of a node,
- // while also scanning the list of emitted nodes backward, we can see if
- // they line up with what would happen when parsed the node as an expression. While they line
- // up we collapse them into an inline expression.
+ // we do a left-to-right postorder traversal of the expression tree (emit
+ // children, then emit op). The reverse of this is a right-to-left preorder
+ // traversal of the tree. By doing a right-to-left preorder traversal of the
+ // inputs of a node, while also scanning the list of emitted nodes backward,
+ // we can see if they line up with what would happen when parsing the node
+ // as an expression. While they line up, we collapse them into an inline
+ // expression.
- // The inductive step is that the right-most input should be produced by the node
- // immediatly before the current node if it is in tree order.
+ // The inductive step is that the right-most input should be produced by the
+ // node immediately before the current node if it is in tree order.
bool isConstantLike(Node* n) {
- switch(n->kind()) {
+ switch (n->kind()) {
case prim::Constant:
case prim::Undefined:
case prim::None:
bool canInline(Value* v) {
Node* n = v->node();
- // there must be only 1 values, otherwise we need an assignment to handle the multiple outout values
+ // there must be only 1 value, otherwise we need an assignment to handle
+ // the multiple output values
if (n->outputs().size() != 1)
return false;
// if it is used more than once, then we need a variable
if (n->blocks().size() != 0)
return false;
// if it is a loop-carried input, we need a variable
- // otherwise the condition or trip count may be emitted in the wrong order w.r.t. to it
+ // otherwise the condition or trip count may be emitted in the wrong order
+ // w.r.t. it
if (use.user->kind() == prim::Loop && use.offset >= 2)
return false;
return true;
}
- // block_point is the current node in the reverse linear scan of the emitted nodes
- // v is the current value in the tree traversal that may match with block_point's output.
+ // block_point is the current node in the reverse linear scan of the
+ // emitted nodes. v is the current value in the tree traversal that may
+ // match with block_point's output.
Node* scanValue(Node* block_point, Value* v) {
Node* n = v->node();
JIT_ASSERT(isConstantLike(n) || output_inline_.count(n) == 0);
- if (n == block_point && canInline(v)) { // the node must be at the expected point of the typical tree traversal
+ if (n == block_point &&
+ canInline(v)) { // the node must be at the expected point of the typical
+ // tree traversal
// recursively see if we can inline the inputs to this input
block_point = scanNode(block_point);
output_inline_.insert(n);
Node* previousNonConstant(Node* n) {
do {
n = n->prev();
- } while(isConstantLike(n));
+ } while (isConstantLike(n));
return n;
}
Node* scanNode(Node* n) {
// don't bother to scan nodes we have already determined to be inline
- if(output_inline_.count(n)) {
+ if (output_inline_.count(n)) {
return n;
}
- for(auto b : n->blocks()) {
+ for (auto b : n->blocks()) {
scanBlock(b);
}
Node* block_point = previousNonConstant(n);
- for(auto it = n->inputs().rbegin(),
- end = n->inputs().rend(); it != end; ++it) {
+ for (auto it = n->inputs().rbegin(), end = n->inputs().rend(); it != end;
+ ++it) {
block_point = scanValue(block_point, *it);
}
return block_point;
void scanBlock(Block* b) {
scanNode(b->return_node());
- for(auto node : b->nodes().reverse()) {
+ for (auto node : b->nodes().reverse()) {
scanNode(node);
}
}
// ConstantPool, which is also N^2 in the size of the constants,
// because it doesn't hash any information about the tensors.
// We will probably need to optimize this at some point using hashing.
- for(size_t i = 0; i < tensor_table_.size(); ++i) {
+ for (size_t i = 0; i < tensor_table_.size(); ++i) {
if (t.type() == tensor_table_[i].type() && t.equal(tensor_table_[i])) {
return i;
}
std::unordered_set<Node*> seen_constants;
void buildConstantList(Node* n, std::vector<Node*>& constants) {
- for(auto input : n->inputs()) {
- if (isConstantLike(input->node()) && seen_constants.count(input->node()) == 0) {
+ for (auto input : n->inputs()) {
+ if (isConstantLike(input->node()) &&
+ seen_constants.count(input->node()) == 0) {
constants.push_back(input->node());
seen_constants.insert(input->node());
}
}
- for(auto b : n->blocks()) {
+ for (auto b : n->blocks()) {
buildConstantList(b, constants);
}
}
void buildConstantList(Block* b, std::vector<Node*>& constants) {
- for(auto n : b->nodes())
+ for (auto n : b->nodes())
buildConstantList(n, constants);
buildConstantList(b->return_node(), constants);
}
// anything we have used.
size_t next_id = 0;
- std::string genNameImpl(const std::string& candidate, std::unordered_set<std::string>& used) {
+ std::string genNameImpl(
+ const std::string& candidate,
+ std::unordered_set<std::string>& used) {
std::string name = candidate;
- while(used.count(name) || reserved_names.count(name)) {
+ while (used.count(name) || reserved_names.count(name)) {
name = candidate + std::to_string(next_id++);
}
used.insert(name);
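
The uniquing scheme above in isolation (reserved-name handling omitted): a taken candidate receives a monotonically increasing numeric suffix until it is fresh, and next_id persists across calls so later collisions keep counting upward. `genName` is a hypothetical trimmed copy.

#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>

std::string genName(
    const std::string& candidate,
    std::unordered_set<std::string>& used,
    size_t& next_id) {
  std::string name = candidate;
  while (used.count(name))
    name = candidate + std::to_string(next_id++);
  used.insert(name);
  return name;
}

int main() {
  std::unordered_set<std::string> used;
  size_t next_id = 0;
  assert(genName("x", used, next_id) == "x");
  assert(genName("x", used, next_id) == "x0");
  assert(genName("x", used, next_id) == "x1");
  return 0;
}
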
std::stringstream ss;
if (candidate.size() == 0 || isdigit(candidate[0]))
ss << "_";
- for(char c : candidate) {
+ for (char c : candidate) {
if (isupper(c) || islower(c) || isdigit(c) || c == '_')
ss << c;
else
assignValue(v, useOf(w));
}
void assignValuesToTheirUniqueNames(at::ArrayRef<Value*> values) {
- for(auto v : values) {
+ for (auto v : values) {
assignValue(v, genUniqueNameFor(v));
}
}
ResourceGuard WithIndented() {
level++;
- return ResourceGuard([this]{
- level--;
- });
+ return ResourceGuard([this] { level--; });
}
template <class T0, class T1, class F>
- void zipWith(
- at::ArrayRef<T0> list_a,
- at::ArrayRef<T1> list_b,
- F action) const {
+ void zipWith(at::ArrayRef<T0> list_a, at::ArrayRef<T1> list_b, F action)
+ const {
auto it_a = list_a.begin();
auto it_b = list_b.begin();
}
}
- void printValueList(std::ostream& stmt, at::ArrayRef<Value*> list, const char* begin = "", const char* end = "") {
+ void printValueList(
+ std::ostream& stmt,
+ at::ArrayRef<Value*> list,
+ const char* begin = "",
+ const char* end = "") {
stmt << begin;
auto delimiter = "";
for (auto* value : list) {
stmt << end;
}
- void printAssignment(
- at::ArrayRef<Value*> lhs,
- at::ArrayRef<Value*> rhs) {
- if(lhs.size() > 0) {
+ void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) {
+ if (lhs.size() > 0) {
indent();
printValueList(out, lhs);
out << " = ";
}
}
- // our way of encoding loops makes them difficult to turn back into python syntax.
- // we have to check properties of the condition and trip count inputs to
- // figure out which one it initially was
+ // our way of encoding loops makes them difficult to turn back into python
+ // syntax. We have to check properties of the condition and trip count inputs
+ // to figure out which one it initially was
static bool shouldEmitAsForLoop(LoopView stmt) {
- auto trip_count = toIValue(stmt.maxTripCount());
- auto cond_input = toIValue(stmt.inputCond());
- auto cond_next = toIValue(stmt.nextCond());
-
- bool condition_is_always_true = cond_input && cond_input->toBool() && cond_next &&
- cond_next->toBool();
- bool trip_count_is_specified = !trip_count || // trip is not a constant
- trip_count->toInt() != std::numeric_limits<int64_t>::max() || // it is a constant but not the default one
- stmt.currentTripCount()->uses().size() > 0; // it is actually being used in the body.
-
- if (condition_is_always_true) {
- // if the trip count was not specified this was a user-written while True:
- return trip_count_is_specified;
- } else {
- // this must be a while loop, but check that there isn't _also_ a trip count
- if (trip_count_is_specified) {
- throw script::ErrorReport(stmt.node()->getSourceLocation())
- << "loop cannot be printed as python because it has gone through an optimization "
- << "that combined while and for loops. File a bug.";
- }
- return false;
+ auto trip_count = toIValue(stmt.maxTripCount());
+ auto cond_input = toIValue(stmt.inputCond());
+ auto cond_next = toIValue(stmt.nextCond());
+
+ bool condition_is_always_true =
+ cond_input && cond_input->toBool() && cond_next && cond_next->toBool();
+ bool trip_count_is_specified = !trip_count || // trip is not a constant
+ trip_count->toInt() !=
+ std::numeric_limits<int64_t>::max() || // it is a constant but not
+ // the default one
+ stmt.currentTripCount()->uses().size() >
+ 0; // it is actually being used in the body.
+
+ if (condition_is_always_true) {
+ // if the trip count was not specified this was a user-written while True:
+ return trip_count_is_specified;
+ } else {
+ // this must be a while loop, but check that there isn't _also_ a trip
+ // count
+ if (trip_count_is_specified) {
+ throw script::ErrorReport(stmt.node()->getSourceLocation())
+ << "loop cannot be printed as python because it has gone through an optimization "
+ << "that combined while and for loops. File a bug.";
}
+ return false;
+ }
}
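
The decision rule above reduced to a predicate, under the assumption that the two booleans have already been computed as in shouldEmitAsForLoop: a constant-true condition with a real trip count prints as a for loop; a constant-true condition without one is a user-written `while True:`; anything else is a while loop, and must not also carry a trip count.

#include <cassert>

bool emitAsFor(bool condition_always_true, bool trip_count_specified) {
  if (condition_always_true)
    return trip_count_specified; // false here means: print `while True:`
  // a genuine while loop must not also have a trip count (reported upstream)
  assert(!trip_count_specified);
  return false;
}

int main() {
  assert(emitAsFor(true, true));    // counted loop -> for
  assert(!emitAsFor(true, false));  // while True:
  assert(!emitAsFor(false, false)); // ordinary while
  return 0;
}
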
void printLoop(LoopView stmt) {
-
// Loop carried dependencies are handled by assigning their initial
// values to the node->outputs() before the loop,
// and assign node->outputs() to the new values at the end of each trip.
-
bool emit_as_for_loop = shouldEmitAsForLoop(stmt);
assignValuesToTheirUniqueNames(stmt.carriedOutputs());
// the condition is always True
size_t offset = emit_as_for_loop ? 1 : 0;
auto body_block = stmt.bodyBlock();
- ArrayRef<Value*> loop_carried_block_inputs = body_block->inputs().slice(offset);
+ ArrayRef<Value*> loop_carried_block_inputs =
+ body_block->inputs().slice(offset);
printBlock(body_block, loop_carried_block_inputs.size() > 0);
- printAssignment(loop_carried_block_inputs, body_block->outputs().slice(offset));
+ printAssignment(
+ loop_carried_block_inputs, body_block->outputs().slice(offset));
}
}
// this node is safe to inline, so assign the output value
// to that expression directly
// guard against really long lines
- if (output_inline_.count(node) > 0 && ss.str().size() + level * 2 < 40) {
+ if (output_inline_.count(node) > 0 &&
+ ss.str().size() + level * 2 < 40) {
assignValue(node->output(), ss.str());
return;
}
const char* the_type,
size_t list_size,
const IValue& the_list) {
- if(list_size == 0) {
+ if (list_size == 0) {
stmt << "annotate(List[" << the_type << "], [])";
} else {
stmt << the_list;
}
void printConstant(std::ostream& stmt, const IValue& v) {
- if(v.isTensor()) {
+ if (v.isTensor()) {
stmt << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
- } else if(v.isString()) {
+ } else if (v.isString()) {
printQuotedString(stmt, v.toStringRef());
- } else if(v.isDevice()) {
+ } else if (v.isDevice()) {
std::stringstream ss;
ss << v.toDevice();
stmt << "torch.device(";
printQuotedString(stmt, ss.str());
stmt << ")";
- } else if(v.isTensorList()) {
+ } else if (v.isTensorList()) {
stmt << "[";
const char* delim = "";
- for(const auto& t : v.toTensorListRef()) {
+ for (const auto& t : v.toTensorListRef()) {
stmt << delim << "CONSTANTS.c" << getOrAddTensorConstant(t);
delim = ", ";
}
stmt << "]";
- } else if(v.isBoolList()) {
- printMaybeAnnotatedConstantList(stmt, "bool", v.toBoolListRef().size(), v);
- } else if(v.isIntList()) {
+ } else if (v.isBoolList()) {
+ printMaybeAnnotatedConstantList(
+ stmt, "bool", v.toBoolListRef().size(), v);
+ } else if (v.isIntList()) {
printMaybeAnnotatedConstantList(stmt, "int", v.toIntListRef().size(), v);
- } else if(v.isDoubleList()) {
- printMaybeAnnotatedConstantList(stmt, "float", v.toDoubleListRef().size(), v);
+ } else if (v.isDoubleList()) {
+ printMaybeAnnotatedConstantList(
+ stmt, "float", v.toDoubleListRef().size(), v);
} else {
stmt << v;
}
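// A minimal sketch of why printMaybeAnnotatedConstantList exists: an empty
// list literal has no elements to infer a type from, so the printer must wrap
// it in annotate(...). A hypothetical standalone version for int lists:
#include <cstdint>
#include <ostream>
#include <vector>

void printIntListConstantSketch(
    std::ostream& stmt,
    const std::vector<int64_t>& v) {
  if (v.empty()) {
    stmt << "annotate(List[int], [])"; // type must be spelled out
    return;
  }
  const char* delim = "[";
  for (int64_t x : v) {
    stmt << delim << x; // element type is inferable from the literal
    delim = ", ";
  }
  stmt << "]";
}
// e.g. {} prints as annotate(List[int], []); {1, 2} prints as [1, 2]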
// Prints the RHS value of a Node, e.g. `aten.add(x, y)`
void printRHS(std::ostream& stmt, Node* node) {
- switch(node->kind()) {
+ switch (node->kind()) {
case PythonOp::Kind: {
auto value = static_cast<const PythonOp*>(node);
if (enforce_importable_) {
// XXX - when None has an Optional[T] type, we must ensure that type
// can be recovered on parsing. It cannot be recovered if it will be
- // matched to schema with free variables. If it is used only in places where
- // there is schema and the scheme has no free variables, then we can
- // recover it without annotation. Otherwise, we annotate None with the right
- // optional type
+ // matched to schema with free variables. If it is used only in places
+ // where there is schema and the schema has no free variables, then we
+ // can recover it without annotation. Otherwise, we annotate None with
+ // the right optional type
const auto& uses = node->output()->uses();
bool all_usable_schema =
std::all_of(uses.begin(), uses.end(), [](const Use& u) {
return false;
}
return !schema->arguments()
- .at(u.offset)
- .type()
- ->hasFreeVariables();
+ .at(u.offset)
+ .type()
+ ->hasFreeVariables();
}
return false;
});
if (all_usable_schema) {
stmt << "None";
} else {
- stmt << "annotate(" << node->output()->type()->python_str() << ", None)";
+ stmt << "annotate(" << node->output()->type()->python_str()
+ << ", None)";
}
} break;
case prim::ImplicitTensorToNum: {
printValueList(stmt, node->inputs(), "bool(", ")");
} break;
case prim::Print: {
- printValueList(stmt, node->inputs(), "print(",")");
+ printValueList(stmt, node->inputs(), "print(", ")");
} break;
case prim::TupleConstruct: {
printValueList(
stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
} break;
case prim::TupleIndex: {
- stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::index) << "]";
+ stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::index)
+ << "]";
} break;
case prim::TupleSlice: {
stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::beg) << ":"
// to infer the type on import
if (node->inputs().size() == 0 &&
!node->output()->type()->isSubtypeOf(DynamicType::get())) {
- stmt << "annotate(" << node->output()->type()->python_str() << ", [])";
+ stmt << "annotate(" << node->output()->type()->python_str()
+ << ", [])";
} else {
printValueList(stmt, node->inputs(), "[", "]");
}
// the subgraph gets emitted as another function
auto name = genMethodName("__forked_function");
std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
- worklist.emplace_back([graph, name, this] {
- printFunctionDefinition(*graph, name);
- });
+ worklist.emplace_back(
+ [graph, name, this] { printFunctionDefinition(*graph, name); });
// and we put a call to fork which invokes that function.
stmt << "fork(self." << name;
- for(Value* v : node->inputs()) {
+ for (Value* v : node->inputs()) {
stmt << ", " << useOf(v);
}
stmt << ")";
} break;
case prim::Function: {
if (enforce_importable_) {
- throw script::ErrorReport(node->getSourceLocation()) << "closures are not exportable";
+ throw script::ErrorReport(node->getSourceLocation())
+ << "closures are not exportable";
}
auto name = genMethodName("__lambda");
std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
- worklist.emplace_back([graph, name, this] {
- printFunctionDefinition(*graph, name);
- });
+ worklist.emplace_back(
+ [graph, name, this] { printFunctionDefinition(*graph, name); });
stmt << "self." << name;
} break;
default: {
// doing it here ensures we do not have to fix up archives later
stmt << "torch." << kind.toUnqualString() << "(";
} else {
- stmt << "ops." << kind.ns().toUnqualString() << "." << kind.toUnqualString() << "(";
+ stmt << "ops." << kind.ns().toUnqualString() << "."
+ << kind.toUnqualString() << "(";
}
const FunctionSchema& schema = node->schema();
for (size_t i = 0; i < node->inputs().size(); ++i) {
- if (i > 0) {
- stmt << ", ";
+ if (i > 0) {
+ stmt << ", ";
+ }
+ auto v = useOf(node->inputs().at(i));
+ // print the kwarg name if it is a kwarg only argument.
+ if (i < schema.arguments().size()) {
+ auto arg = schema.arguments().at(i);
+ if (arg.kwarg_only()) {
+ stmt << arg.name() << "=";
}
- auto v = useOf(node->inputs().at(i));
- // print the kwarg name if it is a kwarg only argument.
- if (i < schema.arguments().size()) {
- auto arg = schema.arguments().at(i);
- if (arg.kwarg_only()) {
- stmt << arg.name() << "=";
- }
- } else {
- // vararg functions like format can have extra arguments
- JIT_ASSERT(schema.is_vararg());
- }
- stmt << v;
+ } else {
+ // vararg functions like format can have extra arguments
+ JIT_ASSERT(schema.is_vararg());
+ }
+ stmt << v;
}
stmt << ")";
} break;
}
std::ostream& printBlock(Block* root, bool block_has_other_statements) {
- // pythons weird 'pass' syntax creates a bunch of places where we have to check
- // if this block would be empty. But not everything in a block is a node.
- // Sometimes if, loop, and return statements will follow this block
+ // Python's weird 'pass' syntax creates a bunch of places where we have to
+ // check if this block would be empty. But not everything in a block is a
+ // node. Sometimes if, loop, and return statements will follow this block
// and block_has_other_statements == true.
if (!block_has_other_statements &&
root->nodes().begin() == root->nodes().end()) {
return out;
}
- void printDefaultValue(const TypePtr& typ, std::ostream& stmt, const IValue& value) {
- // xxx - many weak script modules store default values for broadcasting lists
- // that are not actually the same type as the argument. We can only serialize
- // default values that will implicitly convert to their declared return type
- // since we do not need to serialize these built-in modules with their defaults,
- // we just drop them for now.
+ void printDefaultValue(
+ const TypePtr& typ,
+ std::ostream& stmt,
+ const IValue& value) {
+ // xxx - many weak script modules store default values for broadcasting
+ // lists that are not actually the same type as the argument. We can only
+ // serialize default values that will implicitly convert to their declared
+ // return type. Since we do not need to serialize these built-in modules
+ // with their defaults, we just drop them for now.
if (typ->kind() == ListType::Kind &&
(value.isInt() || value.isDouble() || value.isBool())) {
return;
}
stmt << "=";
if (value.isTensor() && !value.toTensor().defined()) {
- // XXX - because undefined tensors are not stored as None, we need special handling.
- // otherwise they get printed as CONSTANTS.c0 and then cannot be recreated because
- // constant nodes cannot have an undefined value in them.
- // The right solution is to make None of type Tensor actually be an IValue None.
+ // XXX - because undefined tensors are not stored as None, we need special
+ // handling. Otherwise they get printed as CONSTANTS.c0 and then cannot be
+ // recreated because constant nodes cannot have an undefined value in
+ // them. The right solution is to make None of type Tensor actually be an
+ // IValue None.
stmt << "None";
return;
}
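// A hedged standalone sketch of the special case above, using
// std::optional<double> in place of a possibly-undefined at::Tensor: the
// "undefined" default must serialize as None rather than as a constant slot.
// The function name is hypothetical.
#include <optional>
#include <ostream>

void printDefaultValueSketch(
    std::ostream& stmt,
    const std::optional<double>& v) {
  stmt << "=";
  if (!v.has_value()) { // stands in for an undefined tensor
    stmt << "None";
    return;
  }
  stmt << *v;
}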
const std::string& name,
const std::vector<c10::optional<IValue>>& defaults = {},
const std::vector<std::string>& param_names = {}) {
-
used_names_.clear(); // each graph can reuse local names
// we always print constants at the top of the function, in the order
// last param_names.size() arguments to the graph are parameters and not
// actual inputs, we will print these as, e.g. self.foo.bar
// while we print the true_inputs out as parameters
- auto true_inputs = graph.inputs().slice(0, graph.inputs().size() - param_names.size());
+ auto true_inputs =
+ graph.inputs().slice(0, graph.inputs().size() - param_names.size());
auto param_names_it = param_names.begin();
- for(auto param : graph.inputs().slice(true_inputs.size())) {
+ for (auto param : graph.inputs().slice(true_inputs.size())) {
assignValue(param, *param_names_it++);
}
assignValuesToTheirUniqueNames(true_inputs);
std::ostream& out_,
std::vector<at::Tensor>& tensor_table,
bool enforce_importable)
- : out(out_), tensor_table_(tensor_table), enforce_importable_(enforce_importable) {}
+ : out(out_),
+ tensor_table_(tensor_table),
+ enforce_importable_(enforce_importable) {}
// TODO: we should consider forcing functions to return a single value
// instead of handling this tuple logic both in the compiler and the printer
const std::vector<c10::optional<IValue>>& defaults = {},
const std::vector<std::string>& param_names = {}) {
printFunctionDefinition(graph, name, defaults, param_names);
- while(!worklist.empty()) {
+ while (!worklist.empty()) {
out << "\n\n";
auto work = worklist.back();
worklist.pop_back();
}
}
void printMethod(script::Method& method) {
- std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;;
- createTensorToParameterNameMap(method.owner(), QualifiedName::create("self"), parameter_names);
+ std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+ createTensorToParameterNameMap(
+ method.owner(), QualifiedName::create("self"), parameter_names);
printMethod(method, parameter_names);
}
void printMethod(
[&](at::Tensor* slot) { return parameter_names.at(slot)->str(); });
const std::string& name = method.name();
Graph& graph = *method.graph();
- auto defaults = fmap(method.getSchema().arguments(), [](const Argument& arg) {
- return arg.default_value();
- });
+ auto defaults = fmap(
+ method.getSchema().arguments(),
+ [](const Argument& arg) { return arg.default_value(); });
printFunction(graph, name, defaults, param_names);
}
void printModule(script::Module& module) {
- std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;;
- createTensorToParameterNameMap(module, QualifiedName::create("self"), parameter_names);
- for(auto& method : module.get_methods()) {
+ std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+ createTensorToParameterNameMap(
+ module, QualifiedName::create("self"), parameter_names);
+ for (auto& method : module.get_methods()) {
const std::string& name = method.value()->name();
// we skip __forked_functions because they actually get inlined into their
- // callers, exporting them again will lead to more code generated on each export
+ // callers; exporting them again will lead to more code generated on each
+ // export
if (name.find("__forked_function") == 0) {
continue;
}
}
};
-TORCH_API void PythonPrint(std::ostream& out, const Graph& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+TORCH_API void PythonPrint(
+ std::ostream& out,
+ const Graph& graph,
+ std::vector<at::Tensor>& tensor_table,
+ bool enforce_importable) {
PythonPrintPass pp(out, tensor_table, enforce_importable);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printFunction(const_cast<Graph&>(graph), "graph");
}
-TORCH_API void PythonPrint(std::ostream& out, const script::Method& method, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+TORCH_API void PythonPrint(
+ std::ostream& out,
+ const script::Method& method,
+ std::vector<at::Tensor>& tensor_table,
+ bool enforce_importable) {
PythonPrintPass pp(out, tensor_table, enforce_importable);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printMethod(const_cast<script::Method&>(method));
}
-TORCH_API void PythonPrint(std::ostream& out, const script::Module& module, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+TORCH_API void PythonPrint(
+ std::ostream& out,
+ const script::Module& module,
+ std::vector<at::Tensor>& tensor_table,
+ bool enforce_importable) {
PythonPrintPass pp(out, tensor_table, enforce_importable);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printModule(const_cast<script::Module&>(module));
// schema to editing this list here. These cases should only be things
// that require special handling because they do not fit normal schema
const static std::unordered_set<Symbol> handled = {
- prim::Constant,
- prim::fork,
- prim::ListConstruct,
- prim::ListUnpack,
- prim::None,
- prim::Print,
- prim::PythonOp,
- prim::TupleConstruct,
- prim::TupleIndex,
- prim::TupleSlice,
- prim::TupleUnpack,
- prim::Undefined,
+ prim::Constant,
+ prim::fork,
+ prim::ListConstruct,
+ prim::ListUnpack,
+ prim::None,
+ prim::Print,
+ prim::PythonOp,
+ prim::TupleConstruct,
+ prim::TupleIndex,
+ prim::TupleSlice,
+ prim::TupleUnpack,
+ prim::Undefined,
};
// WARNING: by adding a value to this set, you are asserting that your
// to be correctly printed for export (a process that happens before
// optimization passes run)
const static std::unordered_set<Symbol> unneeded = {
- onnx::Reshape, // only used in onnx
- onnx::Shape, // only used in onnx
- prim::AnyDefined, // temporarily inserted by autograd
- prim::AutogradAdd, // temporarily inserted by autograd
- prim::ConstantChunk, // optimization pass adds it
- prim::DifferentiableGraph, // optimization pass adds it
- prim::BroadcastSizes, // optimization pass (fuser) adds it
- prim::ChunkSizes, // optimization pass (fuser) adds it
- prim::Drop, // used in interpreter only
- prim::FusedConcat, // optimization pass adds it
- prim::FusionGroup, // optimization pass adds it
- prim::Load, // used in interpreter only
- prim::MMTreeReduce, // used as an optimization
- prim::MMBatchSide, // used as an optimization
- prim::Store, // used in interpreter only
+ onnx::Reshape, // only used in onnx
+ onnx::Shape, // only used in onnx
+ prim::AnyDefined, // temporarily inserted by autograd
+ prim::AutogradAdd, // temporarily inserted by autograd
+ prim::ConstantChunk, // optimization pass adds it
+ prim::DifferentiableGraph, // optimization pass adds it
+ prim::BroadcastSizes, // optimization pass (fuser) adds it
+ prim::ChunkSizes, // optimization pass (fuser) adds it
+ prim::Drop, // used in interpreter only
+ prim::FusedConcat, // optimization pass adds it
+ prim::FusionGroup, // optimization pass adds it
+ prim::Load, // used in interpreter only
+ prim::MMTreeReduce, // used as an optimization
+ prim::MMBatchSide, // used as an optimization
+ prim::Store, // used in interpreter only
};
#include <iostream>
#include <vector>
-
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace script {
- struct Method;
- struct Module;
-}
+struct Method;
+struct Module;
+} // namespace script
-TORCH_API void PythonPrint(std::ostream& out, const Graph& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
-TORCH_API void PythonPrint(std::ostream& out, const script::Method& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
-TORCH_API void PythonPrint(std::ostream& out, const script::Module& module, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
+TORCH_API void PythonPrint(
+ std::ostream& out,
+ const Graph& graph,
+ std::vector<at::Tensor>& tensor_table,
+ bool enforce_importable = false);
+TORCH_API void PythonPrint(
+ std::ostream& out,
+ const script::Method& graph,
+ std::vector<at::Tensor>& tensor_table,
+ bool enforce_importable = false);
+TORCH_API void PythonPrint(
+ std::ostream& out,
+ const script::Module& module,
+ std::vector<at::Tensor>& tensor_table,
+ bool enforce_importable = false);
TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
-}}
+} // namespace jit
+} // namespace torch
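// A hedged usage sketch for the overloads declared above (assumes a graph
// obtained elsewhere, e.g. from tracing or script compilation; the helper
// name is hypothetical):
#include <sstream>
#include <string>
#include <vector>

std::string serializeGraphSketch(
    const torch::jit::Graph& graph,
    std::vector<at::Tensor>& tensor_table) {
  std::ostringstream src;
  torch::jit::PythonPrint(
      src, graph, tensor_table, /*enforce_importable=*/true);
  // src.str() is the Python serialization; tensor_table now holds the
  // CONSTANTS.cN tensors that must be serialized alongside it.
  return src.str();
}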
-#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/remove_expands.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
static void RemoveExpands(Block* block) {
for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
RemoveExpands(graph->block());
}
-
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
TORCH_API void RemoveExpands(const std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/remove_inplace_ops.h>
namespace torch {
namespace jit {
}
}
}
-}
+} // namespace
void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
RemoveInplaceOps(graph->block());
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/type.h>
#include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/type.h>
#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
-bool getRequiresGrad(Value * value) {
+bool getRequiresGrad(Value* value) {
return value->requires_grad();
}
-void setRequiresGrad(Value * value, bool req_value) {
+void setRequiresGrad(Value* value, bool req_value) {
if (auto type = value->type()->cast<TensorType>()) {
value->setType(type->withRequiresGrad(req_value));
}
}
-void setRequiresGrad(at::ArrayRef<Value*> outputs, const std::vector<bool>& values) {
+void setRequiresGrad(
+ at::ArrayRef<Value*> outputs,
+ const std::vector<bool>& values) {
JIT_ASSERT(outputs.size() == values.size());
for (size_t i = 0; i < values.size(); ++i) {
setRequiresGrad(outputs[i], values[i]);
}
}
-void setRequiresGrad(Node * node, const std::vector<bool>& values) {
+void setRequiresGrad(Node* node, const std::vector<bool>& values) {
setRequiresGrad(node->outputs(), values);
}
return a;
}
-
void PropagateRequiresGradSimpleNode(Node* node) {
static const OperatorSet comparison_ops = {
- "aten::lt(Tensor self, Tensor other) -> Tensor",
- "aten::le(Tensor self, Tensor other) -> Tensor",
- "aten::gt(Tensor self, Tensor other) -> Tensor",
- "aten::ge(Tensor self, Tensor other) -> Tensor",
- "aten::eq(Tensor self, Tensor other) -> Tensor",
- "aten::ne(Tensor self, Tensor other) -> Tensor",
- "aten::lt(Tensor self, Scalar other) -> Tensor",
- "aten::le(Tensor self, Scalar other) -> Tensor",
- "aten::gt(Tensor self, Scalar other) -> Tensor",
- "aten::ge(Tensor self, Scalar other) -> Tensor",
- "aten::eq(Tensor self, Scalar other) -> Tensor",
- "aten::ne(Tensor self, Scalar other) -> Tensor",
+ "aten::lt(Tensor self, Tensor other) -> Tensor",
+ "aten::le(Tensor self, Tensor other) -> Tensor",
+ "aten::gt(Tensor self, Tensor other) -> Tensor",
+ "aten::ge(Tensor self, Tensor other) -> Tensor",
+ "aten::eq(Tensor self, Tensor other) -> Tensor",
+ "aten::ne(Tensor self, Tensor other) -> Tensor",
+ "aten::lt(Tensor self, Scalar other) -> Tensor",
+ "aten::le(Tensor self, Scalar other) -> Tensor",
+ "aten::gt(Tensor self, Scalar other) -> Tensor",
+ "aten::ge(Tensor self, Scalar other) -> Tensor",
+ "aten::eq(Tensor self, Scalar other) -> Tensor",
+ "aten::ne(Tensor self, Scalar other) -> Tensor",
};
if (comparison_ops.find(node)) {
return setRequiresGrad(node->output(), false);
- } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
+ } else if (node->matches(
+ "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
return setRequiresGrad(node->output(), node->input(0)->requires_grad());
} else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
return setRequiresGrad(node->output(), false);
auto inputs = node->inputs();
auto outputs = node->outputs();
- bool should_require = std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
+ bool should_require =
+ std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
for (Value* output : outputs) {
if (auto type = output->type()->cast<TensorType>()) {
- setRequiresGrad(output, should_require && at::isFloatingType(type->scalarType()));
+ setRequiresGrad(
+ output, should_require && at::isFloatingType(type->scalarType()));
}
}
}
-void PropagateRequiresGrad(Block * block);
+void PropagateRequiresGrad(Block* block);
-void PropagateRequiresGrad(Node * node) {
+void PropagateRequiresGrad(Node* node) {
if (node->kind() == prim::If) {
auto blocks = node->blocks();
auto true_block = blocks.at(0);
PropagateRequiresGrad(true_block);
PropagateRequiresGrad(false_block);
- auto outputs_require =
- bitwiseOr(fmap(true_block->outputs(), getRequiresGrad),
- fmap(false_block->outputs(), getRequiresGrad));
+ auto outputs_require = bitwiseOr(
+ fmap(true_block->outputs(), getRequiresGrad),
+ fmap(false_block->outputs(), getRequiresGrad));
setRequiresGrad(node, outputs_require);
} else if (node->kind() == prim::Loop) {
auto body = node->blocks().at(0);
- std::vector<bool> body_inputs_require = fmap(node->inputs().slice(2), getRequiresGrad);
- std::vector<bool> body_outputs_require (node->outputs().size(), false);
+ std::vector<bool> body_inputs_require =
+ fmap(node->inputs().slice(2), getRequiresGrad);
+ std::vector<bool> body_outputs_require(node->outputs().size(), false);
while (body_inputs_require != body_outputs_require) {
- body_inputs_require = bitwiseOr(body_inputs_require, body_outputs_require);
- setRequiresGrad(body->param_node()->outputs().slice(1), body_inputs_require);
+ body_inputs_require =
+ bitwiseOr(body_inputs_require, body_outputs_require);
+ setRequiresGrad(
+ body->param_node()->outputs().slice(1), body_inputs_require);
PropagateRequiresGrad(body);
- body_outputs_require = fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
+ body_outputs_require =
+ fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
}
setRequiresGrad(node, body_outputs_require);
}
}
-void PropagateRequiresGrad(Block * block) {
- for (Node * node : block->nodes()) {
+void PropagateRequiresGrad(Block* block) {
+ for (Node* node : block->nodes()) {
PropagateRequiresGrad(node);
}
}
PropagateRequiresGrad(graph->block());
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
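// A standalone sketch (hypothetical names) of the iteration
// PropagateRequiresGrad runs on prim::Loop above: requires_grad flags of
// loop-carried values are or-ed together and the body re-propagated until
// the input and output flags agree.
#include <cassert>
#include <functional>
#include <vector>

std::vector<bool> bitwiseOrSketch(
    std::vector<bool> a,
    const std::vector<bool>& b) {
  assert(a.size() == b.size());
  for (size_t i = 0; i < a.size(); ++i)
    a[i] = a[i] || b[i];
  return a;
}

// `body` stands in for one propagation pass over the loop body, mapping
// flags on the block inputs to flags on the block outputs.
std::vector<bool> propagateLoopSketch(
    std::vector<bool> inputs_require,
    const std::function<std::vector<bool>(const std::vector<bool>&)>& body) {
  std::vector<bool> outputs_require(inputs_require.size(), false);
  while (inputs_require != outputs_require) {
    inputs_require = bitwiseOrSketch(inputs_require, outputs_require);
    outputs_require = body(inputs_require);
  }
  return outputs_require;
}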
#include <memory>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct Graph;
struct ArgumentSpec;
TORCH_API void PropagateRequiresGrad(std::shared_ptr<Graph>& graph);
-}}
-
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/argument_spec.h>
-#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/autograd/variable.h>
#include <utility>
#include <vector>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct propagation_error : std::exception {};
-#define SHAPE_ASSERT(cond) if (!(cond)) throw propagation_error()
+#define SHAPE_ASSERT(cond) \
+ if (!(cond)) \
+ throw propagation_error()
namespace {
}
}
}
+
private:
const AliasDb aliasDb_;
"aten::inverse(Tensor self) -> Tensor",
};
-
// Check if this node depends on a value that has been mutated previously. If
// it has, then it's not safe to run this node in isolation, since we don't
// know whether the dependency has been executed.
}
return false;
};
- auto list_node = ((cat_node->kind() == prim::FusedConcat)
- ? cat_node
- : cat_node->namedInput(attr::tensors)->node());
- if (list_node->kind() == prim::ListConstruct
- || cat_node->kind() == prim::FusedConcat) {
+ auto list_node =
+ ((cat_node->kind() == prim::FusedConcat)
+ ? cat_node
+ : cat_node->namedInput(attr::tensors)->node());
+ if (list_node->kind() == prim::ListConstruct ||
+ cat_node->kind() == prim::FusedConcat) {
auto tensors = list_node->inputs();
if (!tensors.empty()) {
if (propagate_complete(cat_node, tensors)) {
return; // correct num type is already set
case prim::NumToTensor: {
TypePtr typ = node->input()->type();
- if (typ->isSubtypeOf(IntType::get()) || typ->isSubtypeOf(BoolType::get())) {
+ if (typ->isSubtypeOf(IntType::get()) ||
+ typ->isSubtypeOf(BoolType::get())) {
node->output()->setType(TensorType::create(at::kLong, at::kCPU, 0));
} else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
node->output()->setType(TensorType::create(at::kDouble, at::kCPU, 0));
return;
}
- if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")
- || node->kind() == prim::FusedConcat) {
+ if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
+ node->kind() == prim::FusedConcat) {
return PropagateCatShape(node);
}
// primitive/tensor outputs.
bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
- static const auto broadcast =
- [](std::vector<TensorTypePtr>& tensor_types, size_t arg_for_type) -> TensorTypePtr {
+ static const auto broadcast = [](std::vector<TensorTypePtr>& tensor_types,
+ size_t arg_for_type) -> TensorTypePtr {
if (tensor_types.size() == 1) {
return tensor_types[0];
}
return {};
}};
- // aten::where is special in that its return type is the second argument's (self)
- // type rather than the that of condition
+ // aten::where is special in that its return type is the second argument's
+ // (self) type rather than that of the condition
static const register_formula_for where_op{
{
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
- },
+ },
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types = gatherTensorTypes<TensorType>(node)) {
return {broadcast(*maybe_tensor_types, 1)};
// Requirements:
// dims : preserved from the first argument
- // scalar type : preserved from the first argument (doesn't have to match other arguments)
- // device : always matching and preserved
+ // scalar type : preserved from the first argument (doesn't have to
+ //               match other arguments)
+ // device : always matching and preserved
// tensor inputs : *
// tensor outputs : 1
// NB: those ops (with slight adjustments) are good candidates for restarts.
node, /*num_reduce_dim=*/1, /*integer_upcast=*/true);
}};
-
// Requirements:
- // dims : preserved if keepdim == false, dim->size() smaller otherwise
- // scalar type : preserved
- // device : preserved
- // tensor inputs : 1
- // tensor outputs : 1
+ // dims : preserved if keepdim == false, dim->size() smaller otherwise
+ // scalar type : preserved
+ // device : preserved
+ // tensor inputs : 1
+ // tensor outputs : 1
// Additionally:
// - First input should be the only tensor input
// - has a bool keepdim argument
- static const register_formula_for multidim_reduce_ops {
+ static const register_formula_for multidim_reduce_ops{
{
"aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
"aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
},
- [](Node * node) -> type_vec_t {
+ [](Node* node) -> type_vec_t {
if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
- return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/dim->size(), /*integer_upcast=*/false);
+ return multidim_reduce_with_postprocess(
+ node, /*num_reduce_dim=*/dim->size(), /*integer_upcast=*/false);
}
return {};
}};
setUnshapedType(node);
return false;
}
-
};
} // anonymous namespace
namespace {
void EraseShapeInformation(at::ArrayRef<Value*> vals) {
- for (Value * v : vals) {
+ for (Value* v : vals) {
v->setType(unshapedType(v->type()));
}
}
-void EraseShapeInformation(Block * b) {
+void EraseShapeInformation(Block* b) {
EraseShapeInformation(b->inputs());
EraseShapeInformation(b->outputs());
- for (Node * n : b->nodes()) {
+ for (Node* n : b->nodes()) {
EraseShapeInformation(n->outputs());
- for (Block *sb : n->blocks()) {
+ for (Block* sb : n->blocks()) {
EraseShapeInformation(sb);
}
}
EraseShapeInformation(graph->block());
}
-}}
+} // namespace jit
+} // namespace torch
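// A standalone sketch of the general NumPy-style broadcasting rule that
// shape analysis applies when merging tensor operand shapes (not the
// `broadcast` helper's actual body, which is elided in this patch): align
// shapes from the trailing dimension; sizes must match or one must be 1.
// The function name is hypothetical.
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> broadcastShapesSketch(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  size_t ndim = a.size() > b.size() ? a.size() : b.size();
  std::vector<int64_t> out(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1)
      throw std::runtime_error("shapes are not broadcastable");
    out[ndim - 1 - i] = (da == 1) ? db : da; // handles zero-size dims too
  }
  return out;
}
// broadcastShapesSketch({3, 1, 5}, {4, 5}) == {3, 4, 5}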
#pragma once
-#include <memory>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <memory>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct Graph;
TORCH_API void EraseShapeInformation(const std::shared_ptr<Graph>& graph);
TORCH_API void PropagateInputShapes(const std::shared_ptr<Graph>& graph);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/specialize_undef.h>
#include <torch/csrc/jit/symbolic_variable.h>
-namespace torch { namespace jit {
-
+namespace torch {
+namespace jit {
// propagate undefined information through a gradient graph and
// remove grad_of blocks if present.
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
-void specializeUndef(Graph & g) {
+void specializeUndef(Graph& g) {
enum class State { Defined, Undefined, Unknown };
std::unordered_map<Value*, State> state;
}
}
- for(auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
+ for (auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
auto n = *it;
- switch(n->kind()) {
+ switch (n->kind()) {
case prim::GradOf: {
auto all_undefined =
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
return state[v] == State::Undefined;
});
// Property 1: if all the gradInputs to the GradOf are undefined
- // then the gradOutputs are also zero and will be represented as undefined nodes
- if(all_undefined) {
+ // then the gradOutputs are also zero and will be represented as
+ // undefined nodes
+ if (all_undefined) {
auto undef = g.createUndefined()->insertAfter(n)->output();
- for(auto o : n->outputs()) {
+ for (auto o : n->outputs()) {
o->replaceAllUsesWith(undef);
}
} else {
- // Property 2: GradOfs are required to correctly handle combinations
- // of defined and undefined inputs. They are expected to produce defined
- // output tensors in this case.
+ // Property 2: GradOfs are required to correctly handle combinations
+ // of defined and undefined inputs. They are expected to produce
+ // defined output tensors in this case.
- // Remove the GradOf, splicing its body back into the surrounding block
+ // Remove the GradOf, splicing its body back into the surrounding
+ // block
auto body = n->blocks().at(0);
- for(auto input : n->inputs()){
+ for (auto input : n->inputs()) {
// we should never get into a situation when specializing a GradOf
// where we do not know if a value is defined since at the top level
// a gradient graph is composed of Linear nodes and AutogradAdds
JIT_ASSERT(state[input] != State::Unknown);
}
// hoist the nodes in the GradOf body to be before the linear block
- for(auto it = body->nodes().begin(); it != body->nodes().end();) {
+ for (auto it = body->nodes().begin(); it != body->nodes().end();) {
auto block_node = *it++;
block_node->moveBefore(n);
}
- for(size_t i = 0; i < n->outputs().size(); ++i)
+ for (size_t i = 0; i < n->outputs().size(); ++i)
n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
}
it.destroyCurrent();
auto a = n->input(0);
auto b = n->input(1);
// if one is undefined, we can just drop the add
- if(state[a] == State::Undefined) {
+ if (state[a] == State::Undefined) {
// Undef + b == b
n->output()->replaceAllUsesWith(b);
it.destroyCurrent();
- } else if(state[b] == State::Undefined) {
+ } else if (state[b] == State::Undefined) {
// a + Undef == a
n->output()->replaceAllUsesWith(a);
it.destroyCurrent();
- } else if(state[a] == State::Defined && state[b] == State::Defined) {
- // when both are defined, we can use a normal, optimizable add instruction
+ } else if (state[a] == State::Defined && state[b] == State::Defined) {
+ // when both are defined, we can use a normal, optimizable add
+ // instruction
WithInsertPoint guard(n);
Value* new_add = toVar(a) + toVar(b);
state[new_add] = State::Defined;
state[n->output()] = State::Undefined;
} break;
default:
- for(auto o : n->outputs()) {
+ for (auto o : n->outputs()) {
state[o] = State::Unknown;
}
break;
}
}
-}}
+} // namespace jit
+} // namespace torch
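// A standalone sketch (hypothetical names) of the three-state analysis
// specializeUndef performs on AutogradAdd above: an operand statically known
// to be an undefined (i.e. zero) gradient lets the add collapse to the other
// side; two defined operands become a normal add; anything else stays put.
#include <optional>

enum class GradState { Defined, Undefined, Unknown };

struct AddSpecialization {
  GradState result;
  std::optional<int> surviving_operand; // 0 or 1 when the add collapses
};

AddSpecialization specializeAddSketch(GradState a, GradState b) {
  if (a == GradState::Undefined)
    return {b, 1}; // Undef + b == b
  if (b == GradState::Undefined)
    return {a, 0}; // a + Undef == a
  if (a == GradState::Defined && b == GradState::Defined)
    return {GradState::Defined, std::nullopt}; // emit a normal add
  return {GradState::Unknown, std::nullopt}; // keep the AutogradAdd
}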
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
-
+namespace torch {
+namespace jit {
// propagate undefined information through a gradient graph and
// remove grad_of blocks if present.
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
-TORCH_API void specializeUndef(Graph & g);
+TORCH_API void specializeUndef(Graph& g);
-}}
+} // namespace jit
+} // namespace torch
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/script/compiler.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>> ToBatch::batch_operator_table;
+std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
+ ToBatch::batch_operator_table;
-std::shared_ptr<Graph> ToBatch::getBatchOperator(const std::string& name, int64_t num_inputs){
- if(batch_operator_table.find(name) == batch_operator_table.end()){
- throw std::runtime_error("function " + name + " is not supported in batched tensor yet");
+std::shared_ptr<Graph> ToBatch::getBatchOperator(
+ const std::string& name,
+ int64_t num_inputs) {
+ if (batch_operator_table.find(name) == batch_operator_table.end()) {
+ throw std::runtime_error(
+ "function " + name + " is not supported in batched tensor yet");
}
auto ops = batch_operator_table.at(name);
- if(num_inputs == -1) // default function
+ if (num_inputs == -1) // default function
return ops[0];
- for(auto op : ops){
- if(size_t(num_inputs) == op->inputs().size())
+ for (auto op : ops) {
+ if (size_t(num_inputs) == op->inputs().size())
return op;
}
- throw std::runtime_error("function " + name + " with " + std::to_string(num_inputs) + " inputs is not supported in batched tensor yet");
+ throw std::runtime_error(
+ "function " + name + " with " + std::to_string(num_inputs) +
+ " inputs is not supported in batched tensor yet");
}
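// A standalone sketch (hypothetical names) of the overload-table lookup
// above: each name maps to several candidate graphs, selected by input
// arity, with -1 picking the default (first registered) candidate.
#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

struct BatchImplSketch { size_t arity; }; // stands in for a Graph

const BatchImplSketch& pickByAritySketch(
    const std::unordered_map<std::string, std::vector<BatchImplSketch>>& table,
    const std::string& name,
    int64_t num_inputs) {
  auto it = table.find(name);
  if (it == table.end())
    throw std::runtime_error("function " + name + " is not supported");
  if (num_inputs == -1) // default overload
    return it->second.front();
  for (const auto& op : it->second)
    if (op.arity == size_t(num_inputs))
      return op;
  throw std::runtime_error("no overload of " + name + " matches the arity");
}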
-std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+std::vector<Value*> inlineUnpackedCallTo(
+ Graph& g,
+ Graph& callee,
+ ArrayRef<Value*> inputs) {
return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
}
// replace aten operator node with BatchTensor operator graph
-void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
+void ToBatch::visitAten(Node* n, Block* block, Block* res_block) {
auto res_graph = res_block->owningGraph();
auto func_name = std::string(n->kind().toUnqualString());
std::vector<Value*> new_inputs;
- for(Value *input : n->inputs()){
- if(rn_env.find(input) == rn_env.end()){ // non-tensor input
+ for (Value* input : n->inputs()) {
+ if (rn_env.find(input) == rn_env.end()) { // non-tensor input
auto new_input = batch_map.at(input);
new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
- }
- else{ // batched tensor input
+ } else { // batched tensor input
new_inputs.push_back(rn_env.at(input));
}
}
// transform scalar to tensor before pass to batch operator script
- for (auto& input : new_inputs) {
- if(input->type() == IntType::get() || input->type() == FloatType::get() || input->type() == BoolType::get()){
+ for (auto& input : new_inputs) {
+ if (input->type() == IntType::get() || input->type() == FloatType::get() ||
+ input->type() == BoolType::get()) {
auto to_tensor_node = res_graph->createNumToTensor(input);
res_graph->insertNode(to_tensor_node);
input = to_tensor_node->output();
}
auto batch_graph = getBatchOperator(func_name, new_inputs.size());
- auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
+ auto outputs =
+ inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
- // Assume all outputs from inlined operator implementation are in the triple form batched tensor or just a single non-tensor.
+ // Assume all outputs from the inlined operator implementation are either in
+ // the triple form of a batched tensor or just a single non-tensor.
if (outputs.size() == 1) {
- // if previous output is scalar, transform new output back to scalar from dynamic
+ // if previous output is scalar, transform new output back to scalar from
+ // dynamic
TypePtr orig_type = n->outputs()[0]->type();
if (!orig_type->isSubtypeOf(outputs[0]->type())) {
Symbol op;
} else if (orig_type == BoolType::get()) {
op = prim::Bool;
} else {
- throw std::runtime_error("NYI: scalar types other than int, float, and bool are not supported yet");
+ throw std::runtime_error(
+ "NYI: scalar types other than int, float, and bool are not supported yet");
}
- rn_env[n->outputs()[0]] = res_graph->insert(op, { outputs[0] });
+ rn_env[n->outputs()[0]] = res_graph->insert(op, {outputs[0]});
} else {
rn_env[n->outputs()[0]] = outputs[0];
}
} else {
- for(size_t i = 0; i < n->outputs().size(); i++){
+ for (size_t i = 0; i < n->outputs().size(); i++) {
auto output = n->outputs()[i];
- batch_map[output] = std::vector<Value*>(outputs.begin() + i * EXP_BTENSOR_SIZE, outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
+ batch_map[output] = std::vector<Value*>(
+ outputs.begin() + i * EXP_BTENSOR_SIZE,
+ outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
}
}
}
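// A standalone sketch of the regrouping at the end of visitAten above: every
// logical tensor output is expanded to EXP_BTENSOR_SIZE consecutive values
// ({data, mask, dims}), so output i owns the flat slice [3*i, 3*i + 3).
// Assumes the flat vector's length is a multiple of three; names are
// hypothetical.
#include <cstddef>
#include <vector>

template <typename T>
std::vector<std::vector<T>> regroupBatchedOutputsSketch(
    const std::vector<T>& flat) {
  constexpr size_t kExpBTensorSize = 3; // {data, mask, dims}
  std::vector<std::vector<T>> grouped;
  for (size_t i = 0; i * kExpBTensorSize < flat.size(); ++i) {
    grouped.emplace_back(
        flat.begin() + i * kExpBTensorSize,
        flat.begin() + (i + 1) * kExpBTensorSize);
  }
  return grouped;
}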
// clone prim::Constant to new graph
// batching transformation is applied to the output of prim::NumToTensor.
-// If there is a prim::NumToTensor following prim::Constant, it will be finally transformed to BatchTensor.
-void ToBatch::visitConstant(Node* n, Block* block, Block* res_block){
+// If there is a prim::NumToTensor following prim::Constant, it will be finally
+// transformed to BatchTensor.
+void ToBatch::visitConstant(Node* n, Block* block, Block* res_block) {
auto res_graph = res_block->owningGraph();
auto* r_node = res_graph->createClone(n, rn_fn);
res_block->appendNode(r_node);
}
// change return tensor to expanded batched tensor, eg: {data, mask, dims}
-void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block){
+void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block) {
auto res_graph = res_block->owningGraph();
auto* r_node = res_graph->createClone(n, rn_fn);
res_block->appendNode(r_node);
- auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs());
+ auto outputs = inlineUnpackedCallTo(
+ *res_block->owningGraph(),
+ *getBatchOperator("batch_from_scalar_tensor"),
+ r_node->outputs());
batch_map[n->output()] = outputs;
}
// clone prim::TensorToNum to new graph
-void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block){
+void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block) {
auto res_graph = res_block->owningGraph();
- if(rn_env.find(n->input()) == rn_env.end()){
+ if (rn_env.find(n->input()) == rn_env.end()) {
rn_env[n->input()] = batch_map.at(n->input())[0];
}
auto* r_node = res_graph->createClone(n, rn_fn);
}
// clone prim::ListConstruct to new graph
-void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block){
+void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block) {
auto res_graph = res_block->owningGraph();
- if(n->inputs()[0]->type() == DynamicType::get()){ // TensorList: expand directly
+ if (n->inputs()[0]->type() ==
+ DynamicType::get()) { // TensorList: expand directly
std::vector<Value*> inputs;
- for(Value* input: n->inputs()) {
+ for (Value* input : n->inputs()) {
auto res = batch_map.at(input);
inputs.insert(inputs.end(), res.begin(), res.end());
}
batch_map[n->output()] = inputs;
- }
- else { // ScalarList: transform to tensor, then transform back
- for(Value* input : n->inputs()) {
- if(rn_env.find(input) == rn_env.end()){
+ } else { // ScalarList: transform to tensor, then transform back
+ for (Value* input : n->inputs()) {
+ if (rn_env.find(input) == rn_env.end()) {
rn_env[input] = batch_map.at(input)[0];
}
}
auto* r_node = res_graph->createClone(n, rn_fn);
res_block->appendNode(r_node);
// transform int[] to tensor
- auto to_tensor_node = res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
+ auto to_tensor_node =
+ res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
to_tensor_node->addInput(r_node->output());
res_block->appendNode(to_tensor_node);
rn_env[n->output()] = to_tensor_node->output();
}
}
+// clang-format off
// prim::If transformation:
// elif is not supported
//
// %res_dims : Dynamic = aten::__or__(%dims.1, %dims)
// return (%res_data, %res_mask, %res_dims);
// }
-void ToBatch::visitIf(Node* n, Block* block, Block* res_block){
+// clang-format on
+void ToBatch::visitIf(Node* n, Block* block, Block* res_block) {
toBatch(n->blocks()[0], res_block);
toBatch(n->blocks()[1], res_block);
// combine results from two if paths
- for(size_t i = 0; i < n->outputs().size(); i++){
+ for (size_t i = 0; i < n->outputs().size(); i++) {
std::vector<Value*> inputs;
- if(batch_map.find(n->input()) == batch_map.end()){ // cond is scalar
+ if (batch_map.find(n->input()) == batch_map.end()) { // cond is scalar
inputs.push_back(rn_env.at(n->input()));
- }
- else{ // cond is tensor
+ } else { // cond is tensor
auto cond = batch_map.at(n->input());
inputs.insert(inputs.end(), cond.begin(), cond.end());
}
inputs.insert(inputs.end(), if_output.begin(), if_output.end());
auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
inputs.insert(inputs.end(), else_output.begin(), else_output.end());
- auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs);
+ auto outputs = inlineUnpackedCallTo(
+ *res_block->owningGraph(),
+ *getBatchOperator("where", inputs.size()),
+ inputs);
batch_map[n->outputs()[i]] = outputs;
}
}
+// clang-format off
// prim::Loop transformation:
//
// transformation example:
// }
// return (%a, %60, %61);
// }
-void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
+// clang-format on
+void ToBatch::visitLoop(Node* n, Block* block, Block* res_block) {
auto res_graph = res_block->owningGraph();
// bool cond_is_tensor indicates whether cond is tensor
// cond_is_tensor = false, eg: for loop, n->inputs()[1] = byte()
// cond_is_tensor = true, eg: in some while loop, cond is a batched tensor,
- // we need to add expanded cond to the inputs of loop node and block,
- // and compute cond_any as cond for while loop
+ // we need to add expanded cond to the inputs of
+ // loop node and block, and compute cond_any as
+ // cond for while loop
bool cond_is_tensor = (batch_map.find(n->inputs()[1]) != batch_map.end());
// create prim::Loop node for res_block
// type of cond in loop should be int type
- if(rn_env.at(n->inputs()[0])->type() != IntType::get()){
- rn_env[n->inputs()[0]] = res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])});
+ if (rn_env.at(n->inputs()[0])->type() != IntType::get()) {
+ rn_env[n->inputs()[0]] =
+ res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])});
}
- if(cond_is_tensor){
+ if (cond_is_tensor) {
auto cond = batch_map.at(n->inputs()[1]);
- auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
- rn_env[n->inputs()[1]] =res_graph->insert(prim::Bool, {cond_any[0]});
+ auto cond_any = inlineUnpackedCallTo(
+ *res_block->owningGraph(), *getBatchOperator("any"), cond);
+ rn_env[n->inputs()[1]] = res_graph->insert(prim::Bool, {cond_any[0]});
}
- for(size_t i = 2; i < n->inputs().size(); i++){
+ for (size_t i = 2; i < n->inputs().size(); i++) {
auto input = n->inputs()[i];
rn_env[input] = batch_map.at(input)[0];
}
auto* r_node = res_graph->createClone(n, rn_fn, /*copy_blocks=*/false);
// change inputs of prim::Loop
- if(cond_is_tensor){
- for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ if (cond_is_tensor) {
+ for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
auto cond = batch_map.at(n->inputs()[1]);
r_node->insertInput(i + 2, cond[i]);
}
}
- for(size_t i = 2; i < n->inputs().size(); i++){
- for(size_t j = 1; j < EXP_BTENSOR_SIZE; j++){
- r_node->insertInput((i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 + j, batch_map.at(n->inputs()[i])[j]);
+ for (size_t i = 2; i < n->inputs().size(); i++) {
+ for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
+ r_node->insertInput(
+ (i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 +
+ j,
+ batch_map.at(n->inputs()[i])[j]);
}
}
res_block->appendNode(r_node);
// create block for Loop node in res_block
- // if cond is tensor: first 4 inputs of block: cond_any, cond_data, cond_mask, cond_dims
+ // if cond is tensor: first 4 inputs of block: cond_any, cond_data,
+ // cond_mask, cond_dims
// if cond is not tensor: first 1 input of block: cond
auto loop_block = r_node->addBlock();
loop_block->addInput("loop_num");
loop_block->inputs()[0]->setType(IntType::get());
rn_env[n->blocks()[0]->inputs()[0]] = loop_block->inputs()[0];
- if(cond_is_tensor){
- for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ if (cond_is_tensor) {
+ for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
loop_block->addInput("cond_" + EXP_BTENSOR_NAME[i]);
}
}
- for(size_t i = 1; i < n->blocks()[0]->inputs().size(); i++){
+ for (size_t i = 1; i < n->blocks()[0]->inputs().size(); i++) {
auto input = n->blocks()[0]->inputs()[i];
auto name = input->uniqueName();
- for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
loop_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
}
- batch_map[input] = std::vector<Value*>(loop_block->inputs().slice((i - 1) * EXP_BTENSOR_SIZE + 1 + EXP_BTENSOR_SIZE * cond_is_tensor, EXP_BTENSOR_SIZE).vec());
+ batch_map[input] =
+ std::vector<Value*>(loop_block->inputs()
+ .slice(
+ (i - 1) * EXP_BTENSOR_SIZE + 1 +
+ EXP_BTENSOR_SIZE * cond_is_tensor,
+ EXP_BTENSOR_SIZE)
+ .vec());
}
toBatch(n->blocks()[0], loop_block);
WithInsertPoint guard(loop_block);
// use where operator to update variables and add to outputs
- for(size_t i = 0; i < n->outputs().size(); i++){
+ for (size_t i = 0; i < n->outputs().size(); i++) {
std::vector<Value*> inputs, outputs;
- if(cond_is_tensor){
- for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ if (cond_is_tensor) {
+ for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
inputs.push_back(loop_block->inputs()[j + 1]);
}
auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
inputs.insert(inputs.end(), data.begin(), data.end());
- for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
- inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
+ for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
+ inputs.push_back(
+ loop_block
+ ->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
}
- outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs);
- }
- else{
- for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ outputs = inlineUnpackedCallTo(
+ *res_block->owningGraph(), *getBatchOperator("where"), inputs);
+ } else {
+ for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + 1]);
}
auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
inputs.insert(inputs.end(), data.begin(), data.end());
- outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs);
+ outputs = inlineUnpackedCallTo(
+ *res_block->owningGraph(), *getBatchOperator("update"), inputs);
}
batch_map[n->outputs()[i]] = outputs;
- for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
loop_block->registerOutput(outputs[j]);
}
}
// update loop conditions
- if(cond_is_tensor){
+ if (cond_is_tensor) {
auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
- auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
+ auto cond_any = inlineUnpackedCallTo(
+ *res_block->owningGraph(), *getBatchOperator("any"), cond);
auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]});
- loop_block->insertOutput(0, to_bool_output);
- for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ loop_block->insertOutput(0, to_bool_output);
+ for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
loop_block->insertOutput(i + 1, cond[i]);
}
- }
- else{
+ } else {
auto cond = rn_env.at(n->blocks()[0]->outputs()[0]);
loop_block->insertOutput(0, cond);
}
// change outputs of prim::Loop
auto size = r_node->outputs().size();
- for(size_t i = 0; i < size; i++){
- for(size_t j = 1; j < EXP_BTENSOR_SIZE; j++){
+ for (size_t i = 0; i < size; i++) {
+ for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
r_node->insertOutput(i * EXP_BTENSOR_SIZE + j);
}
- batch_map[n->outputs()[i]] = r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
+ batch_map[n->outputs()[i]] =
+ r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
}
// add cond to outputs of loop node
- if(cond_is_tensor){
- for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ if (cond_is_tensor) {
+ for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
r_node->insertOutput(i);
}
}
void ToBatch::toBatch(Block* block, Block* res_block) {
WithInsertPoint guard(res_block);
- // change inputs of block - expand tensor to batchtensor eg: (data, mask, dims)
- // eg: a -> a_data, a_mask, a_dims
- // for block in prim::Loop, register inputs separately to deal with cond
- if(!block->owningNode() || block->owningNode()->kind() != prim::Loop){
+ // change inputs of block - expand each tensor to a batchtensor, eg:
+ // (data, mask, dims): a -> a_data, a_mask, a_dims
+ // for block in prim::Loop, register inputs separately to deal with cond
+ if (!block->owningNode() || block->owningNode()->kind() != prim::Loop) {
auto size = block->inputs().size();
- for(size_t i = 0; i < size; i++){
+ for (size_t i = 0; i < size; i++) {
auto input = block->inputs()[i];
auto name = input->uniqueName();
- for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
res_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
}
- batch_map[input] = std::vector<Value*>(res_block->inputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec());
+ batch_map[input] =
+ std::vector<Value*>(res_block->inputs()
+ .slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE)
+ .vec());
}
}
for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
auto n = *it;
- if(n->kind().is_aten()){
+ if (n->kind().is_aten()) {
visitAten(n, block, res_block);
- }
- else if(n->kind().is_prim()){
- switch(n->kind()){
+ } else if (n->kind().is_prim()) {
+ switch (n->kind()) {
case prim::Constant:
case prim::None:
visitConstant(n, block, res_block);
visitLoop(n, block, res_block);
break;
default:
- throw std::runtime_error("NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
+ throw std::runtime_error(
+ "NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
}
- }
- else{
- throw std::runtime_error("NYI: node that is not aten or prim kind is not supported yet");
+ } else {
+ throw std::runtime_error(
+ "NYI: node that is not aten or prim kind is not supported yet");
}
}
// change outputs of block - expand tensor to batchtensor(data, mask, dims)
- // for block in prim::Loop, register outputs separately to deal with cond and cond_any
- // for block in prim::If, register outputs separately by combining outputs from two paths and return
- if(!block->owningNode() || (block->owningNode()->kind() != prim::Loop && block->owningNode()->kind() != prim::If)) {
- for(Value* output : block->outputs()){
+ // for block in prim::Loop, register outputs separately to deal with cond and
+ // cond_any
+ //
+ // for block in prim::If, register outputs separately by combining
+ // outputs from two paths and return
+ if (!block->owningNode() ||
+ (block->owningNode()->kind() != prim::Loop &&
+ block->owningNode()->kind() != prim::If)) {
+ for (Value* output : block->outputs()) {
auto r_output = batch_map.at(output);
- for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
res_block->registerOutput(r_output[i]);
}
}
graph = graph->copy();
auto outs = createTupleUnpack(graph->outputs().at(0));
graph->eraseOutput(0);
- for(auto o : outs)
+ for (auto o : outs)
graph->registerOutput(o);
EliminateDeadCode(graph->block());
}
ToBatch to_batch;
to_batch.toBatch(graph->block(), res_graph->block());
- // methods should only have a single output, so we pack everything into a tuple
- auto tup = res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
+ // methods should only have a single output, so we pack everything into a
+ // tuple
+ auto tup =
+ res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
while (res_graph->outputs().size() > 0)
res_graph->eraseOutput(res_graph->outputs().size() - 1);
res_graph->registerOutput(tup->output());
void initRegisterBatchOpsBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("to_batch_graph", to_batch_graph);
- m.def("register_batch_operator", [](std::string name, std::shared_ptr<Graph> graph){
- ToBatch::batch_operator_table[name].push_back(graph);
- });
+ m.def(
+ "register_batch_operator",
+ [](std::string name, std::shared_ptr<Graph> graph) {
+ ToBatch::batch_operator_table[name].push_back(graph);
+ });
}
-}} // namespace torch.jit
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/pybind.h>
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/pybind.h>
#include <ATen/ATen.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
class ToBatch {
-private:
- // number of tensors to represent a expanded BatchTensor. {data, mask, dims} for now.
+ private:
+ // number of tensors to represent an expanded BatchTensor. {data, mask, dims}
+ // for now.
const size_t EXP_BTENSOR_SIZE = 3;
const std::vector<std::string> EXP_BTENSOR_NAME = {"data", "mask", "dims"};
// mapping from tensor in original graph to {data, mask, dims} in new graph
std::unordered_map<Value*, std::vector<Value*>> batch_map;
- // mapping from input in original graph to new input in new graph - used in createClone
+ // mapping from input in original graph to new input in new graph - used in
+ // createClone
std::unordered_map<Value*, Value*> rn_env;
- std::function<Value*(Value*)> rn_fn = [this](Value* v) { return rn_env.at(v); };
+ std::function<Value*(Value*)> rn_fn = [this](Value* v) {
+ return rn_env.at(v);
+ };
-private:
- std::shared_ptr<Graph> getBatchOperator(const std::string& name, int64_t input_num = -1);
+ private:
+ std::shared_ptr<Graph> getBatchOperator(
+ const std::string& name,
+ int64_t input_num = -1);
void visitAten(Node* n, Block* block, Block* res_block);
void visitConstant(Node* n, Block* block, Block* res_block);
void visitNumToTensor(Node* n, Block* block, Block* res_block);
void visitIf(Node* n, Block* block, Block* res_block);
void visitLoop(Node* n, Block* block, Block* res_block);
-public:
- static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>> batch_operator_table;
+ public:
+ static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
+ batch_operator_table;
TORCH_API void toBatch(Block* block, Block* res_block);
};
TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph);
TORCH_API void initRegisterBatchOpsBindings(PyObject* module);
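+// Hypothetical Python-side usage of the two entry points above (module path
+// assumed):
+//   torch._C.register_batch_operator("add", batched_add_graph)
+//   batched = torch._C.to_batch_graph(fn.graph)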
-}}
+} // namespace jit
+} // namespace torch
}
struct AliasAndIValue {
- AliasAndIValue(
- c10::optional<at::AliasInfo> aliasInfo,
- IValue iValue)
+ AliasAndIValue(c10::optional<at::AliasInfo> aliasInfo, IValue iValue)
: aliasInfo(std::move(aliasInfo)), iValue(std::move(iValue)) {}
const c10::optional<at::AliasInfo> aliasInfo;
#include <torch/csrc/python_headers.h>
-#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/jit/ivalue.h>
#include <torch/csrc/jit/pybind_utils.h>
+#include <torch/csrc/jit/tracer.h>
+#include <torch/csrc/utils/pybind.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;
-namespace pybind11 { namespace detail {
+namespace pybind11 {
+namespace detail {
-template <> struct type_caster<torch::jit::IValue> {
-public:
+template <>
+struct type_caster<torch::jit::IValue> {
+ public:
PYBIND11_TYPE_CASTER(torch::jit::IValue, _("IValue"));
bool load(handle src, bool) {
}
}
- static handle cast(torch::jit::IValue src, return_value_policy /* policy */, handle /* parent */) {
+ static handle cast(
+ torch::jit::IValue src,
+ return_value_policy /* policy */,
+ handle /* parent */) {
return torch::jit::toPyObject(std::move(src)).release();
}
};
-template <> struct type_caster<torch::jit::Symbol> {
-public:
+template <>
+struct type_caster<torch::jit::Symbol> {
+ public:
PYBIND11_TYPE_CASTER(torch::jit::Symbol, _("Symbol"));
bool load(handle src, bool) {
return true;
}
- static handle cast(torch::jit::Symbol src, return_value_policy /* policy */, handle /* parent */) {
- return py::cast(std::string(src.toQualString()), return_value_policy::copy).release();
+ static handle cast(
+ torch::jit::Symbol src,
+ return_value_policy /* policy */,
+ handle /* parent */) {
+ return py::cast(std::string(src.toQualString()), return_value_policy::copy)
+ .release();
}
};
-template <> struct type_caster<torch::jit::AttributeKind> {
-public:
+template <>
+struct type_caster<torch::jit::AttributeKind> {
+ public:
PYBIND11_TYPE_CASTER(torch::jit::AttributeKind, _("AttributeKind"));
bool load(handle src, bool) {
return false;
}
- static handle cast(torch::jit::AttributeKind src, return_value_policy /* policy */, handle /* parent */) {
- return py::cast(std::string(torch::jit::toString(src)), return_value_policy::copy).release();
+ static handle cast(
+ torch::jit::AttributeKind src,
+ return_value_policy /* policy */,
+ handle /* parent */) {
+ return py::cast(
+ std::string(torch::jit::toString(src)),
+ return_value_policy::copy)
+ .release();
}
};
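+// With the casters above, bindings can traffic in these types directly; a
+// minimal sketch (hypothetical binding, not part of this file):
+//   m.def("echo", [](torch::jit::IValue v) { return v; });
+// On the way out, Symbol and AttributeKind are handed to Python as their
+// qualified / printable string forms via cast().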
// See https://github.com/pybind/pybind11/issues/637
-using ListCasterBase = pybind11::detail::list_caster<std::vector<torch::jit::Node *>, torch::jit::Node *>;
-template<> struct type_caster<std::vector<torch::jit::Node *>> : ListCasterBase {
- static handle cast(const std::vector<torch::jit::Node *> &src, return_value_policy, handle parent) {
- return ListCasterBase::cast(src, return_value_policy::reference, parent);
- }
- static handle cast(const std::vector<torch::jit::Node *> *src, return_value_policy pol, handle parent) {
- return cast(*src, pol, parent);
- }
+using ListCasterBase = pybind11::detail::
+ list_caster<std::vector<torch::jit::Node*>, torch::jit::Node*>;
+template <>
+struct type_caster<std::vector<torch::jit::Node*>> : ListCasterBase {
+ static handle cast(
+ const std::vector<torch::jit::Node*>& src,
+ return_value_policy,
+ handle parent) {
+ return ListCasterBase::cast(src, return_value_policy::reference, parent);
+ }
+ static handle cast(
+ const std::vector<torch::jit::Node*>* src,
+ return_value_policy pol,
+ handle parent) {
+ return cast(*src, pol, parent);
+ }
};
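+// i.e. a std::vector<Node*> returned from a binding reaches Python as a list
+// of non-owning references; the Graph retains ownership, consistent with the
+// py::nodelete holders used for Node/Value/Block in the IR bindings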
-}} // namespace pybind11::detail
+} // namespace detail
+} // namespace pybind11
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-static inline py::tuple tuple_tail(const py::tuple & tup) {
+static inline py::tuple tuple_tail(const py::tuple& tup) {
py::tuple r(tup.size() - 1);
- for(size_t i = 1; i < tup.size(); i++) {
- r[i-1] = tup[i];
+ for (size_t i = 1; i < tup.size(); i++) {
+ r[i - 1] = tup[i];
}
return r;
}
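+// e.g. tuple_tail((a, b, c)) == (b, c) -- handy for dropping a leading
+// self-style argument from a call tuple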
-}}
+} // namespace jit
+} // namespace torch
#pragma once
+#include <torch/csrc/Device.h>
#include <torch/csrc/jit/function_schema.h>
#include <torch/csrc/jit/ivalue.h>
-#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/stack.h>
#include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/auto_gil.h>
-#include <torch/csrc/Device.h>
+#include <torch/csrc/utils/pybind.h>
#include <c10/util/Exception.h>
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace detail {
// error reporting: when reporting user-caused errors, these functions should
// that is confusing to display to the end user since it always reports
// locations in libtorch code rather than user code.
-inline void findErrorInKwargs(
- const FunctionSchema& schema,
- py::kwargs kwargs) {
+inline void findErrorInKwargs(const FunctionSchema& schema, py::kwargs kwargs) {
const auto& arguments = schema.arguments();
// First check if any of the kwargs are unknown, i.e. don't match the name of
// any argument in the schema.
for (const auto& kwarg : kwargs) {
const auto key = py::cast<std::string>(kwarg.first);
- if(!std::count_if(
+ if (!std::count_if(
arguments.begin(),
arguments.end(),
- [&key](const Argument& argument) { return argument.name() == key; })) {
+ [&key](const Argument& argument) {
+ return argument.name() == key;
+ })) {
throw std::runtime_error(c10::str(
"Unknown keyword argument '",
key,
}
return Tuple::create(s);
} else {
- AT_ERROR("Only tensors and (possibly nested) tuples of tensors are supported "
- "as inputs or outputs of traced functions");
+ AT_ERROR(
+ "Only tensors and (possibly nested) tuples of tensors are supported "
+ "as inputs or outputs of traced functions");
}
}
return toIValue(inputs).toTuple()->elements();
}
-inline IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N = c10::nullopt);
+inline IValue toIValue(
+ py::handle obj,
+ const TypePtr& type,
+ c10::optional<int32_t> N = c10::nullopt);
inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
std::vector<IValue> elems;
- for(auto elem : obj) {
+ for (auto elem : obj) {
elems.push_back(toIValue(elem, elem_type));
}
return List<IValue>::create(std::move(elems));
}
-inline IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
- switch (type->kind()) {
- case TypeKind::DynamicType:
- case TypeKind::TensorType:
- case TypeKind::UndefinedTensorType:
- case TypeKind::CompleteTensorType: {
- auto var = py::cast<autograd::Variable>(obj);
- if (var.is_sparse()) {
- AT_ERROR("sparse tensors not supported");
- }
- return var;
- }
- case TypeKind::FloatType:
- return py::cast<double>(obj);
- case TypeKind::IntType:
- return py::cast<int64_t>(obj);
- case TypeKind::NoneType:
- if(obj != Py_None)
- throw py::cast_error();
-
- return {};
- case TypeKind::BoolType:
- return py::cast<bool>(obj);
- case TypeKind::TupleType: {
- if(!PyTuple_Check(obj.ptr()))
- throw py::cast_error(); // note: the py::cast does not throw cast_error
- // because it attempts to iterate a non-tuple
- py::tuple tuple = py::cast<py::tuple>(obj);
- size_t tuple_size = tuple.size();
- const auto & elem_types = type->cast<TupleType>()->elements();
- if (elem_types.size() != tuple_size) {
- throw py::cast_error();
- }
- std::vector<IValue> values;
- values.reserve(tuple_size);
- for (size_t i = 0; i < tuple_size; ++i) {
- values.push_back(toIValue(tuple[i], elem_types[i]));
- }
- return Tuple::create(std::move(values));
+inline IValue toIValue(
+ py::handle obj,
+ const TypePtr& type,
+ c10::optional<int32_t> N) {
+ switch (type->kind()) {
+ case TypeKind::DynamicType:
+ case TypeKind::TensorType:
+ case TypeKind::UndefinedTensorType:
+ case TypeKind::CompleteTensorType: {
+ auto var = py::cast<autograd::Variable>(obj);
+ if (var.is_sparse()) {
+ AT_ERROR("sparse tensors not supported");
}
- case TypeKind::StringType:
- return ConstantString::create(py::cast<std::string>(obj));
- case TypeKind::DeviceObjType: {
- auto device = reinterpret_cast<THPDevice*>(obj.ptr());
- return device->device;
+ return var;
+ }
+ case TypeKind::FloatType:
+ return py::cast<double>(obj);
+ case TypeKind::IntType:
+ return py::cast<int64_t>(obj);
+ case TypeKind::NoneType:
+ if (obj != Py_None)
+ throw py::cast_error();
+
+ return {};
+ case TypeKind::BoolType:
+ return py::cast<bool>(obj);
+ case TypeKind::TupleType: {
+ if (!PyTuple_Check(obj.ptr()))
+ throw py::cast_error(); // note: the py::cast does not throw cast_error
+ // because it attempts to iterate a non-tuple
+ py::tuple tuple = py::cast<py::tuple>(obj);
+ size_t tuple_size = tuple.size();
+ const auto& elem_types = type->cast<TupleType>()->elements();
+ if (elem_types.size() != tuple_size) {
+ throw py::cast_error();
}
- case TypeKind::ListType: {
- const auto& elem_type = type->expect<ListType>()->getElementType();
- switch(elem_type->kind()) {
- //allows single int/float to be broadcasted to a fixed size list
- case TypeKind::IntType:
- if (!N || !py::isinstance<py::int_>(obj)) {
- return py::cast<std::vector<int64_t>>(obj);
- } else {
- double value = py::cast<int64_t>(obj);
- std::vector<double> repeated(*N, value);
- return repeated;
- }
- case TypeKind::FloatType:
- if (!N || !py::isinstance<py::float_>(obj)) {
- return py::cast<std::vector<double>>(obj);
- } else {
- double value = py::cast<double>(obj);
- std::vector<double> repeated(*N, value);
- return repeated;
- }
- case TypeKind::TensorType:
- case TypeKind::DynamicType:
- return py::cast<std::vector<at::Tensor>>(obj);
- default:
- return createGenericList(obj, elem_type);
- }
+ std::vector<IValue> values;
+ values.reserve(tuple_size);
+ for (size_t i = 0; i < tuple_size; ++i) {
+ values.push_back(toIValue(tuple[i], elem_types[i]));
}
- case TypeKind::OptionalType: {
- const auto& elem_type = type->expect<OptionalType>()->getElementType();
- // check if it's a none obj since optional accepts NoneType
- if (obj == Py_None) {
- if(elem_type->isSubtypeOf(DynamicType::get())) {
- // return undefined tensor for Optional[Tensor]
- return at::Tensor();
+ return Tuple::create(std::move(values));
+ }
+ case TypeKind::StringType:
+ return ConstantString::create(py::cast<std::string>(obj));
+ case TypeKind::DeviceObjType: {
+ auto device = reinterpret_cast<THPDevice*>(obj.ptr());
+ return device->device;
+ }
+ case TypeKind::ListType: {
+ const auto& elem_type = type->expect<ListType>()->getElementType();
+ switch (elem_type->kind()) {
+        // allows a single int/float to be broadcast to a fixed-size list
+ case TypeKind::IntType:
+ if (!N || !py::isinstance<py::int_>(obj)) {
+ return py::cast<std::vector<int64_t>>(obj);
+ } else {
+          // broadcast as int64_t, not double, so the IValue is an int list
+          int64_t value = py::cast<int64_t>(obj);
+          std::vector<int64_t> repeated(*N, value);
+          return repeated;
}
- else {
- // for other optional types, return an IValue() to denote a None
- return {};
+ case TypeKind::FloatType:
+ if (!N || !py::isinstance<py::float_>(obj)) {
+ return py::cast<std::vector<double>>(obj);
+ } else {
+ double value = py::cast<double>(obj);
+ std::vector<double> repeated(*N, value);
+ return repeated;
}
+ case TypeKind::TensorType:
+ case TypeKind::DynamicType:
+ return py::cast<std::vector<at::Tensor>>(obj);
+ default:
+ return createGenericList(obj, elem_type);
+ }
+ }
+ case TypeKind::OptionalType: {
+ const auto& elem_type = type->expect<OptionalType>()->getElementType();
+      // check if it is a None obj, since Optional accepts NoneType
+ if (obj == Py_None) {
+ if (elem_type->isSubtypeOf(DynamicType::get())) {
+ // return undefined tensor for Optional[Tensor]
+ return at::Tensor();
+ } else {
+ // for other optional types, return an IValue() to denote a None
+ return {};
}
- return toIValue(obj, type->expect<OptionalType>()->getElementType());
}
- case TypeKind::NumberType:
- case TypeKind::GeneratorType:
- case TypeKind::VarType:
- case TypeKind::FutureType:
- break;
+ return toIValue(obj, type->expect<OptionalType>()->getElementType());
}
- AT_ERROR("Missing cases in toIValue for type: ", type->str(), "! File a bug report.");
+ case TypeKind::NumberType:
+ case TypeKind::GeneratorType:
+ case TypeKind::VarType:
+ case TypeKind::FutureType:
+ break;
+ }
+ AT_ERROR(
+ "Missing cases in toIValue for type: ",
+ type->str(),
+ "! File a bug report.");
}
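+// e.g. (sketch) with type List[int] and N = 4, a bare Python int 3 converts
+// to the list [3, 3, 3, 3]; a Python tuple must match the TupleType arity or
+// a py::cast_error is thrown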
inline IValue argumentToIValue(
}
}
-inline IValue returnToIValue(
- const TypePtr& type,
- py::handle object) {
+inline IValue returnToIValue(const TypePtr& type, py::handle object) {
try {
return toIValue(object, type);
} catch (const py::cast_error& error) {
return py::cast(ivalue.toTensorListRef());
} else if (ivalue.isGenericList()) {
auto list = ivalue.toGenericList();
- const auto & elements = list->elements();
- py::list t { elements.size() };
+ const auto& elements = list->elements();
+ py::list t{elements.size()};
for (size_t i = 0; i < elements.size(); ++i) {
t[i] = toPyObject(IValue{elements[i]});
}
return t;
} else if (ivalue.isTuple()) {
auto tuple = ivalue.toTuple();
- const auto & elements = tuple->elements();
- py::tuple t { elements.size() };
+ const auto& elements = tuple->elements();
+ py::tuple t{elements.size()};
for (size_t i = 0; i < elements.size(); ++i) {
t[i] = toPyObject(IValue{elements[i]});
}
struct VISIBILITY_HIDDEN tuple_slice {
/*implicit*/ tuple_slice(py::tuple tup_)
- : tup(std::move(tup_)), b(0), e(tup.size()) {}
+ : tup(std::move(tup_)), b(0), e(tup.size()) {}
tuple_slice(py::tuple tup_, int64_t b_)
- : tup(std::move(tup_)), b(b_), e(tup.size()) {}
+ : tup(std::move(tup_)), b(b_), e(tup.size()) {}
tuple_slice(py::tuple tup_, int64_t b_, int64_t e_)
- : tup(std::move(tup_)), b(b_), e(e_) {}
+ : tup(std::move(tup_)), b(b_), e(e_) {}
py::detail::tuple_iterator begin() const {
return {tup, b};
}
py::detail::tuple_accessor operator[](size_t index) const {
return {tup, b + index};
}
-private:
+
+ private:
py::tuple tup;
int64_t b;
int64_t e;
const FunctionSchema& schema,
const tuple_slice& args,
const py::kwargs& kwargs = py::kwargs()) {
- if(args.size() + kwargs.size() > schema.arguments().size()) {
+ if (args.size() + kwargs.size() > schema.arguments().size()) {
throw std::runtime_error(c10::str(
- schema.name(), "() expected at most ", schema.arguments().size(),
+ schema.name(),
+ "() expected at most ",
+ schema.arguments().size(),
" argument(s) but received ",
- args.size() + kwargs.size(), " argument(s). Declaration: ", schema));
+ args.size() + kwargs.size(),
+ " argument(s). Declaration: ",
+ schema));
}
Stack stack;
stack.reserve(schema.arguments().size());
}
// TODO: Remove once we clean up the GraphExecutor usage.
-inline Stack evilDeprecatedBadCreateStackDoNotUse(const py::tuple& tuple, at::ArrayRef<Value*> inputs, size_t reserve_extra_space = 0) {
+inline Stack evilDeprecatedBadCreateStackDoNotUse(
+ const py::tuple& tuple,
+ at::ArrayRef<Value*> inputs,
+ size_t reserve_extra_space = 0) {
if (tuple.size() != inputs.size()) {
- AT_ERROR("expected " + std::to_string(inputs.size()) +
- " inputs, but got " + std::to_string(tuple.size()));
+ AT_ERROR(
+ "expected " + std::to_string(inputs.size()) + " inputs, but got " +
+ std::to_string(tuple.size()));
}
Stack result;
result.reserve(tuple.size() + reserve_extra_space);
inline py::object invokeScriptMethodFromPython(
script::Method& method,
- tuple_slice args, py::kwargs kwargs) {
- auto stack = createStackForSchema(method.getSchema(), std::move(args), std::move(kwargs));
+ tuple_slice args,
+ py::kwargs kwargs) {
+ auto stack = createStackForSchema(
+ method.getSchema(), std::move(args), std::move(kwargs));
{
AutoNoGIL no_gil_guard;
method.run(stack);
return createPyObjectForStack(std::move(stack));
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#include <torch/csrc/autograd/grad_mode.h>
-namespace torch { namespace jit { namespace python {
+namespace torch {
+namespace jit {
+namespace python {
using namespace torch::autograd;
using namespace at;
// Alphabet used to describe structure of inputs/outputs (D for desc)
namespace D {
-static constexpr char ListOpen = '[';
-static constexpr char ListClose = ']';
-static constexpr char TupleOpen = '(';
-static constexpr char TupleClose = ')';
-static constexpr char Variable = 'v';
+static constexpr char ListOpen = '[';
+static constexpr char ListClose = ']';
+static constexpr char TupleOpen = '(';
+static constexpr char TupleClose = ')';
+static constexpr char Variable = 'v';
} // namespace D
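+// e.g. an input ((x, y), [z]) of three Variables is described by the
+// structure string "((vv)[v])" plus one metadata entry per 'v'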
namespace {
-template<typename T>
+template <typename T>
py::object cast_handle_sequence(std::vector<py::handle> objs) {
auto num_objs = objs.size();
- T sequence { num_objs };
+ T sequence{num_objs};
for (size_t i = 0; i < num_objs; ++i)
sequence[i] = py::reinterpret_borrow<py::object>(objs[i]);
return sequence;
}
void flatten_rec(PyObject* obj, ParsedArgs& args) {
- auto & structure = args.desc.structure;
+ auto& structure = args.desc.structure;
if (PyTuple_Check(obj)) {
structure.push_back(D::TupleOpen);
for (auto item : py::reinterpret_borrow<py::tuple>(obj))
args.desc.metadata.emplace_back(var);
args.desc.structure.push_back(D::Variable);
} else {
- std::string msg = "Only tuples, lists and Variables supported as JIT inputs, but got ";
+ std::string msg =
+ "Only tuples, lists and Variables supported as JIT inputs, but got ";
msg += THPUtils_typename(obj);
throw std::runtime_error(msg);
}
namespace {
-template<typename T>
+template <typename T>
py::object cast_sequence(std::vector<py::object> objs) {
auto num_objs = objs.size();
- T sequence { num_objs };
+ T sequence{num_objs};
for (size_t i = 0; i < num_objs; ++i)
sequence[i] = std::move(objs[i]);
return sequence;
}
-py::object unflatten_rec(ArrayRef<Variable>::iterator& var_it,
- ArrayRef<Variable>::iterator& var_it_end,
- std::string::const_iterator& desc_it) {
+py::object unflatten_rec(
+ ArrayRef<Variable>::iterator& var_it,
+ ArrayRef<Variable>::iterator& var_it_end,
+ std::string::const_iterator& desc_it) {
char type = *desc_it++;
if (type == D::TupleOpen) {
std::vector<py::object> objs;
return output.release().ptr();
}
-}}} // namespace torch::jit::python
+} // namespace python
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/pybind.h>
#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/pybind.h>
#include <torch/csrc/utils/hash.h>
#include <ATen/ATen.h>
+#include <functional>
#include <tuple>
#include <vector>
-#include <functional>
-namespace torch { namespace jit { namespace python {
+namespace torch {
+namespace jit {
+namespace python {
struct IODescriptor {
struct VariableMetadata {
VariableMetadata(const autograd::Variable& var)
- : sizes(var.sizes().vec())
- , type(var.type().scalarType())
- , device(var.device())
- , requires_grad(var.requires_grad()) {}
+ : sizes(var.sizes().vec()),
+ type(var.type().scalarType()),
+ device(var.device()),
+ requires_grad(var.requires_grad()) {}
bool operator==(const VariableMetadata& o) const {
- return std::tie( device, requires_grad, type, sizes) ==
- std::tie(o.device, o.requires_grad, o.type, o.sizes);
+ return std::tie(device, requires_grad, type, sizes) ==
+ std::tie(o.device, o.requires_grad, o.type, o.sizes);
}
static size_t hash(const VariableMetadata& m) {
};
bool operator==(const IODescriptor& o) const {
- return std::tie( structure, metadata, grad_enabled) ==
- std::tie(o.structure, o.metadata, o.grad_enabled);
+ return std::tie(structure, metadata, grad_enabled) ==
+ std::tie(o.structure, o.metadata, o.grad_enabled);
}
static size_t hash(const IODescriptor& o) {
void extend(const autograd::variable_list& list) {
metadata.reserve(metadata.size() + list.size());
- for (auto & var : list)
+ for (auto& var : list)
metadata.emplace_back(var);
}
bool grad_enabled = false;
};
-static inline std::ostream& operator<<(std::ostream& out, const IODescriptor::VariableMetadata& meta) {
+static inline std::ostream& operator<<(
+ std::ostream& out,
+ const IODescriptor::VariableMetadata& meta) {
at::Device meta_device = meta.device;
- auto & t = at::getNonVariableType(meta_device.is_cpu() ? at::Backend::CPU : at::Backend::CUDA, meta.type);
+ auto& t = at::getNonVariableType(
+ meta_device.is_cpu() ? at::Backend::CPU : at::Backend::CUDA, meta.type);
out << t << "(requires_grad=" << meta.requires_grad;
if (meta_device.is_cuda()) {
out << ", device=" << meta_device.index();
}
out << ") {";
- for(size_t i = 0; i < meta.sizes.size(); ++i) {
- if(i > 0)
+ for (size_t i = 0; i < meta.sizes.size(); ++i) {
+ if (i > 0)
out << ", ";
out << meta.sizes[i];
}
return out;
}
-static inline std::ostream& operator<<(std::ostream & out, const IODescriptor & desc) {
+static inline std::ostream& operator<<(
+ std::ostream& out,
+ const IODescriptor& desc) {
out << desc.structure << "\n";
out << " with grad_enabled=" << desc.grad_enabled << "\n";
- for(size_t i = 0; i < desc.metadata.size(); ++i) {
+ for (size_t i = 0; i < desc.metadata.size(); ++i) {
out << " with v" << i << " having type " << desc.metadata[i] << "\n";
}
return out;
IODescriptor desc;
void extend(const autograd::variable_list& list) {
- if (list.empty()) return;
+ if (list.empty())
+ return;
vars.reserve(vars.size() + list.size());
- for (auto & var : list)
+ for (auto& var : list)
vars.emplace_back(var);
desc.extend(list);
}
};
-
ParsedArgs flatten(py::handle obj);
-PyObject* unflatten(at::ArrayRef<autograd::Variable> vars,
- const IODescriptor& structure);
+PyObject* unflatten(
+ at::ArrayRef<autograd::Variable> vars,
+ const IODescriptor& structure);
-}}} // namespace torch::jit::python
+} // namespace python
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/python_headers.h>
#include <torch/csrc/jit/interpreter.h>
+#include <torch/csrc/python_headers.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <typeinfo>
+#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_engine.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/pybind.h>
#include <torch/csrc/utils/auto_gil.h>
-#include <torch/csrc/Exceptions.h>
namespace py = pybind11;
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
AutoGIL gil;
const PythonOp* op = static_cast<const PythonOp*>(op_);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- const py::function func = py::reinterpret_borrow<const py::function>(py::handle(const_cast<PythonOp*>(op)->pyobj.get()));
+ const py::function func = py::reinterpret_borrow<const py::function>(
+ py::handle(const_cast<PythonOp*>(op)->pyobj.get()));
size_t num_inputs = 0;
- for(auto arg_type : op->cconv) {
- if(arg_type == 'd')
+ for (auto arg_type : op->cconv) {
+ if (arg_type == 'd')
num_inputs++;
}
JIT_ASSERT(op->outputs().size() == 1);
- return [=](Stack & stack) {
+ return [=](Stack& stack) {
AutoGIL gil;
py::tuple py_inputs(op->cconv.size());
size_t i = 0;
for (auto arg_type : op->cconv) {
if (arg_type == 'c') {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- py_inputs[i] = py::reinterpret_borrow<const py::object>(const_cast<PythonOp*>(op)->scalar_args[next_scalar++].get());
+ py_inputs[i] = py::reinterpret_borrow<const py::object>(
+ const_cast<PythonOp*>(op)->scalar_args[next_scalar++].get());
} else if (arg_type == 'd') {
- py_inputs[i] = toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
+ py_inputs[i] =
+ toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
next_tensor++;
}
i++;
try {
py::object py_output(func(*py_inputs));
stack.push_back(returnToIValue(op->output()->type(), py_output));
- } catch (py::error_already_set & e) {
+ } catch (py::error_already_set& e) {
throw std::runtime_error(e.what());
}
return 0;
};
}
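+// The calling convention string cconv interleaves the two argument kinds:
+// 'c' takes the next captured scalar from scalar_args, 'd' takes the next
+// dynamic input from the interpreter stack; e.g. cconv "dcd" produces the
+// call func(tensor0, scalar0, tensor1).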
+RegisterOperators reg({Operator(prim::PythonOp, createPythonOperation)});
-RegisterOperators reg({
- Operator(prim::PythonOp, createPythonOperation)
-});
-
-}}} // torch::jit::anon
+} // namespace
+} // namespace jit
+} // namespace torch
#include <torch/csrc/python_headers.h>
+#include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/pybind.h>
#include <torch/csrc/jit/python_tracer.h>
-#include <torch/csrc/utils/pybind.h>
-#include <torch/csrc/jit/export.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/passes/python_print.h>
-#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/utils/auto_gil.h>
+#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_strings.h>
-
#include <iostream>
#include <sstream>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
using c10::Type;
return py::str(v);
}
-std::ostream& printPyObject(std::ostream & out, const THPObjectPtr& obj) {
+std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
AutoGIL gil;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
}
}
-// execute a Python function, used for Ops we can't optimize but that we want to optimize around
+// execute a Python function, used for Ops we can't optimize but that we want to
+// optimize around
struct ConcretePythonOp : public PythonOp {
- ConcretePythonOp(Graph * graph)
- : PythonOp(graph) {}
- std::string name() const override {
- AutoGIL gil;
- if(auto autograd = autogradFunction()) {
- return getPythonName(autograd->get());
- } else {
- return getPythonName(pyobj.get());
- }
- }
- void cloneFrom(Node * other_) override {
- Node::cloneFrom(other_);
- auto other = other_->cast<PythonOp>();
- this->cconv = other->cconv;
- Py_INCREF(other->pyobj.get());
- this->pyobj = THPObjectPtr(other->pyobj.get());
- for(auto & sa : other->scalar_args) {
- Py_INCREF(sa.get());
- this->scalar_args.emplace_back(sa.get());
- }
- }
- Node * allocNewInstance(Graph * g) override {
- return new ConcretePythonOp(g);
- }
- // recover the autograd.Function instance, if this PythonOp's function
- // was originally SomeFunction.apply
- // used in ONNX for discovering symbolics
- c10::optional<THPObjectPtr> autogradFunction() const override {
- AutoGIL gil;
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- py::handle obj = const_cast<PyObject*>(pyobj.get());
-
- auto r = py::getattr(obj, "__self__", py::none());
- if(r.is_none())
- return c10::nullopt;
+ ConcretePythonOp(Graph* graph) : PythonOp(graph) {}
+ std::string name() const override {
+ AutoGIL gil;
+ if (auto autograd = autogradFunction()) {
+ return getPythonName(autograd->get());
+ } else {
+ return getPythonName(pyobj.get());
+ }
+ }
+ void cloneFrom(Node* other_) override {
+ Node::cloneFrom(other_);
+ auto other = other_->cast<PythonOp>();
+ this->cconv = other->cconv;
+ Py_INCREF(other->pyobj.get());
+ this->pyobj = THPObjectPtr(other->pyobj.get());
+ for (auto& sa : other->scalar_args) {
+ Py_INCREF(sa.get());
+ this->scalar_args.emplace_back(sa.get());
+ }
+ }
+ Node* allocNewInstance(Graph* g) override {
+ return new ConcretePythonOp(g);
+ }
+ // recover the autograd.Function instance, if this PythonOp's function
+ // was originally SomeFunction.apply
+ // used in ONNX for discovering symbolics
+ c10::optional<THPObjectPtr> autogradFunction() const override {
+ AutoGIL gil;
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ py::handle obj = const_cast<PyObject*>(pyobj.get());
- auto apply = py::getattr(r, "apply", py::none());
- if(apply.is_none())
- return c10::nullopt;
+ auto r = py::getattr(obj, "__self__", py::none());
+ if (r.is_none())
+ return c10::nullopt;
- auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE);
- if(PyErr_Occurred())
- throw py::error_already_set();
- if(c)
- return c10::nullopt;
+ auto apply = py::getattr(r, "apply", py::none());
+ if (apply.is_none())
+ return c10::nullopt;
- return THPObjectPtr(r.release().ptr());
- }
+ auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE);
+ if (PyErr_Occurred())
+ throw py::error_already_set();
+ if (c)
+ return c10::nullopt;
- void writeScalars(std::ostream& out) const override {
- out << "(";
- int i = 0;
- for (auto& scalar : scalar_args) {
- if (i++ > 0)
- out << ", ";
- printPyObject(out, scalar);
- }
- out << ")";
- }
+ return THPObjectPtr(r.release().ptr());
+ }
+ void writeScalars(std::ostream& out) const override {
+ out << "(";
+ int i = 0;
+ for (auto& scalar : scalar_args) {
+ if (i++ > 0)
+ out << ", ";
+ printPyObject(out, scalar);
+ }
+ out << ")";
+ }
};
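+// e.g. if the traced callable was MyFunction.apply, pyobj holds that bound
+// apply; its __self__ is MyFunction, and the rich-compare check above
+// verifies MyFunction.apply is the same object before returning MyFunction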
PythonOp* pythonAllocPythonOp(Graph* g) {
return new ConcretePythonOp(g);
}
-void initPythonIRBindings(PyObject * module_) {
+void initPythonIRBindings(PyObject* module_) {
setAllocPythonOp(pythonAllocPythonOp);
auto m = py::handle(module_).cast<py::module>();
- #define GS(name) \
- def(#name,&Graph :: name)
- py::class_<Graph,std::shared_ptr<Graph>>(m,"Graph")
- .def(py::init<>())
- .def("__repr__",[](Graph & g) {
- std::stringstream ss;
- ss << g;
- return ss.str();
- })
- .def("propagate_shapes", [](std::shared_ptr<Graph> g, std::vector<at::Tensor> inputs, bool with_grad) {
- setInputTypes(*g, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
- PropagateInputShapes(g);
- })
- .def("_export_onnx", [](const std::shared_ptr<Graph> g, const std::vector<at::Tensor>& initializers,
- int64_t onnx_opset_version, bool defer_weight_export,
- ::torch::onnx::OperatorExportTypes operator_export_type) {
- std::string graph;
- RawDataExportMap export_map;
- std::tie(graph, export_map) = export_onnx(
- g, initializers, onnx_opset_version, defer_weight_export, operator_export_type);
- std::unordered_map<std::string, py::bytes> python_serialized_export_map;
- for (auto& kv : export_map) {
- auto t = kv.second;
- size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
- // TODO: this is an unecessary copy. In theory we can directly return
- // the map from identifier to Tensor, but we need some API in Python
- // to get raw `bytes` containing the raw tensor data.
- python_serialized_export_map[kv.first] = py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
- }
- return std::make_tuple(py::bytes(graph), python_serialized_export_map);
- }, py::arg("initializers"),
- py::arg("onnx_opset_version")=0,
- py::arg("defer_weight_export")=false,
- py::arg("operator_export_type")=::torch::onnx::OperatorExportTypes::ONNX)
- .def("_pretty_print_onnx", [](const std::shared_ptr<Graph> g,
- const std::vector<at::Tensor>& initializers,
- int64_t onnx_opset_version, bool defer_weight_export,
- ::torch::onnx::OperatorExportTypes operator_export_type,
- bool google_printer) {
- return pretty_print_onnx(
- g, initializers, onnx_opset_version, defer_weight_export, operator_export_type,
- google_printer);
- }, py::arg("initializers"),
- py::arg("onnx_opset_version")=0,
- py::arg("defer_weight_export")=false,
- py::arg("operator_export_type")=::torch::onnx::OperatorExportTypes::ONNX,
- py::arg("google_printer")=false)
- .def("inputs",[](Graph &g) {
- return py::make_iterator(g.inputs().begin(), g.inputs().end());
- })
- .def("outputs",[](Graph &g) {
- return py::make_iterator(g.outputs().begin(), g.outputs().end());
- })
- // TODO: Iterator invalidation might make this hazardous
- .def("nodes",[](Graph &g) {
- return py::make_iterator(g.nodes().begin(), g.nodes().end());
- })
- .def("addInput",[](Graph &g) { return g.addInput(); })
- .def("copy",[](Graph &g) {
- return g.copy();
- })
- .GS(eraseInput)
- .GS(registerOutput)
- .def("create",[](Graph & g, const char * str) {
- return g.create(Symbol::fromQualString(str));
- })
- .def("create",[](Graph & g, const char * str, size_t noutputs) {
- return g.create(Symbol::fromQualString(str), noutputs);
- })
- .def("create",[](Graph & g, const char * str, const std::vector<Value*> & inputs) {
- return g.create(Symbol::fromQualString(str),inputs);
- })
- .def("create",[](Graph & g, const char * str, const std::vector<Value*> & inputs, size_t noutputs) {
- return g.create(Symbol::fromQualString(str),inputs, noutputs);
- })
- .def("param_node", [](Graph &g) {
- return g.block()->param_node();
- })
- .def("return_node", [](Graph &g) {
- return g.block()->return_node();
- })
- .def("pretty_print", [](Graph &g) {
- std::ostringstream oss;
- g.prettyPrint(oss);
- return oss.str();
- })
- .GS(createFusionGroup)
- .def("createClone",[](Graph & g, Node * n, py::object fn) {
- return g.createClone(n, [&](Value * e) {
- return fn(e).cast<Value*>();
- });
- })
- .GS(appendNode)
- .GS(prependNode)
- .GS(lint)
- .GS(insertNode)
- ;
- #undef GS
+#define GS(name) def(#name, &Graph::name)
+ py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
+ .def(py::init<>())
+ .def(
+ "__repr__",
+ [](Graph& g) {
+ std::stringstream ss;
+ ss << g;
+ return ss.str();
+ })
+ .def(
+ "propagate_shapes",
+ [](std::shared_ptr<Graph> g,
+ std::vector<at::Tensor> inputs,
+ bool with_grad) {
+ setInputTypes(
+ *g,
+ ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
+ PropagateInputShapes(g);
+ })
+ .def(
+ "_export_onnx",
+ [](const std::shared_ptr<Graph> g,
+ const std::vector<at::Tensor>& initializers,
+ int64_t onnx_opset_version,
+ bool defer_weight_export,
+ ::torch::onnx::OperatorExportTypes operator_export_type) {
+ std::string graph;
+ RawDataExportMap export_map;
+ std::tie(graph, export_map) = export_onnx(
+ g,
+ initializers,
+ onnx_opset_version,
+ defer_weight_export,
+ operator_export_type);
+ std::unordered_map<std::string, py::bytes>
+ python_serialized_export_map;
+ for (auto& kv : export_map) {
+ auto t = kv.second;
+ size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
+              // TODO: this is an unnecessary copy. In theory we can directly
+ // return the map from identifier to Tensor, but we need some API
+ // in Python to get raw `bytes` containing the raw tensor data.
+ python_serialized_export_map[kv.first] =
+ py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
+ }
+ return std::make_tuple(
+ py::bytes(graph), python_serialized_export_map);
+ },
+ py::arg("initializers"),
+ py::arg("onnx_opset_version") = 0,
+ py::arg("defer_weight_export") = false,
+ py::arg("operator_export_type") =
+ ::torch::onnx::OperatorExportTypes::ONNX)
+ .def(
+ "_pretty_print_onnx",
+ [](const std::shared_ptr<Graph> g,
+ const std::vector<at::Tensor>& initializers,
+ int64_t onnx_opset_version,
+ bool defer_weight_export,
+ ::torch::onnx::OperatorExportTypes operator_export_type,
+ bool google_printer) {
+ return pretty_print_onnx(
+ g,
+ initializers,
+ onnx_opset_version,
+ defer_weight_export,
+ operator_export_type,
+ google_printer);
+ },
+ py::arg("initializers"),
+ py::arg("onnx_opset_version") = 0,
+ py::arg("defer_weight_export") = false,
+ py::arg("operator_export_type") =
+ ::torch::onnx::OperatorExportTypes::ONNX,
+ py::arg("google_printer") = false)
+ .def(
+ "inputs",
+ [](Graph& g) {
+ return py::make_iterator(g.inputs().begin(), g.inputs().end());
+ })
+ .def(
+ "outputs",
+ [](Graph& g) {
+ return py::make_iterator(g.outputs().begin(), g.outputs().end());
+ })
+ // TODO: Iterator invalidation might make this hazardous
+ .def(
+ "nodes",
+ [](Graph& g) {
+ return py::make_iterator(g.nodes().begin(), g.nodes().end());
+ })
+ .def("addInput", [](Graph& g) { return g.addInput(); })
+ .def("copy", [](Graph& g) { return g.copy(); })
+ .GS(eraseInput)
+ .GS(registerOutput)
+ .def(
+ "create",
+ [](Graph& g, const char* str) {
+ return g.create(Symbol::fromQualString(str));
+ })
+ .def(
+ "create",
+ [](Graph& g, const char* str, size_t noutputs) {
+ return g.create(Symbol::fromQualString(str), noutputs);
+ })
+ .def(
+ "create",
+ [](Graph& g, const char* str, const std::vector<Value*>& inputs) {
+ return g.create(Symbol::fromQualString(str), inputs);
+ })
+ .def(
+ "create",
+ [](Graph& g,
+ const char* str,
+ const std::vector<Value*>& inputs,
+ size_t noutputs) {
+ return g.create(Symbol::fromQualString(str), inputs, noutputs);
+ })
+ .def("param_node", [](Graph& g) { return g.block()->param_node(); })
+ .def("return_node", [](Graph& g) { return g.block()->return_node(); })
+ .def(
+ "pretty_print",
+ [](Graph& g) {
+ std::ostringstream oss;
+ g.prettyPrint(oss);
+ return oss.str();
+ })
+ .GS(createFusionGroup)
+ .def(
+ "createClone",
+ [](Graph& g, Node* n, py::object fn) {
+ return g.createClone(
+ n, [&](Value* e) { return fn(e).cast<Value*>(); });
+ })
+ .GS(appendNode)
+ .GS(prependNode)
+ .GS(lint)
+ .GS(insertNode);
+#undef GS
- #define VS(name) \
- def(#name,&Value :: name)
- py::class_<Value,std::unique_ptr<Value, py::nodelete>>(m,"Value")
- .def("__repr__",[](Value & n) {
- std::stringstream ss;
- ss << n.uniqueName() << " defined in (" << *n.node() << ")";
- return ss.str();
- })
- .VS(type)
- .VS(setType)
- .VS(inferTypeFrom)
- // skip owningGraph because it returns a raw pointer to a otherwise
- // std::shared_ptr stored graph object, and would cause a double free
- .VS(unique)
- .VS(uniqueName)
- .VS(setUniqueName)
- .VS(offset)
- .VS(uses)
- .VS(replaceAllUsesWith)
- .def("node",[](Value &v) { return v.node(); })
- .def("setTypeAs", [](Value * node, Value * other) {
- node->setType(other->type());
- return node;
- })
- .VS(copyMetadata)
- .VS(isTensor)
- ;
+#define VS(name) def(#name, &Value::name)
+ py::class_<Value, std::unique_ptr<Value, py::nodelete>>(m, "Value")
+ .def(
+ "__repr__",
+ [](Value& n) {
+ std::stringstream ss;
+ ss << n.uniqueName() << " defined in (" << *n.node() << ")";
+ return ss.str();
+ })
+ .VS(type)
+ .VS(setType)
+ .VS(inferTypeFrom)
+      // skip owningGraph because it returns a raw pointer to an otherwise
+      // std::shared_ptr-stored graph object, and would cause a double free
+ .VS(unique)
+ .VS(uniqueName)
+ .VS(setUniqueName)
+ .VS(offset)
+ .VS(uses)
+ .VS(replaceAllUsesWith)
+ .def("node", [](Value& v) { return v.node(); })
+ .def(
+ "setTypeAs",
+ [](Value* node, Value* other) {
+ node->setType(other->type());
+ return node;
+ })
+ .VS(copyMetadata)
+ .VS(isTensor);
- #undef VS
+#undef VS
py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
- .def("nodes",[](Block &b) {
- return py::make_iterator(b.nodes().begin(), b.nodes().end());
- });
+ .def("nodes", [](Block& b) {
+ return py::make_iterator(b.nodes().begin(), b.nodes().end());
+ });
- #define NS(name) \
- def(#name,&Node :: name)
- py::class_<Node,std::unique_ptr<Node, py::nodelete>>(m,"Node")
- .def("__repr__",[](Node & n) {
- std::stringstream ss;
- ss << n;
- return ss.str();
- })
- .def("getSourceLocation", [](Node & n) -> py::object {
- std::stringstream ss;
- if (auto sl = n.getSourceLocation()) {
- sl->highlight(ss);
- return py::str(ss.str());
- } else {
- return py::none();
- }
- })
- .def("hasMultipleOutputs",[](Node&n) {
- return n.outputs().size() > 1;
- })
- .def("outputsSize",[](Node &n) {
- return n.outputs().size();
- })
- .NS(kind)
- .def("inputs",[](Node &n) {
- return py::make_iterator(n.inputs().begin(), n.inputs().end());
- })
- .def("outputs",[](Node &n) {
- return py::make_iterator(n.outputs().begin(), n.outputs().end());
- })
- .def("output", [](Node &n) {
- return n.output();
- })
- .NS(addInput)
- .NS(replaceInput)
- .NS(replaceInputWith)
- .NS(replaceAllUsesWith)
- .NS(insertBefore)
- .NS(insertAfter)
- .NS(moveAfter)
- .NS(moveBefore)
- .NS(removeInput)
- .NS(removeAllInputs)
- .NS(destroy)
- .NS(hasUses)
- .NS(eraseOutput)
- .NS(addOutput)
- .NS(scopeName)
- .NS(isNondeterministic)
- .def("blocks", [](Node& n) {
- return py::make_iterator(n.blocks().begin(), n.blocks().end());
- })
- .NS(addBlock)
+#define NS(name) def(#name, &Node::name)
+ py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
+ .def(
+ "__repr__",
+ [](Node& n) {
+ std::stringstream ss;
+ ss << n;
+ return ss.str();
+ })
+ .def(
+ "getSourceLocation",
+ [](Node& n) -> py::object {
+ std::stringstream ss;
+ if (auto sl = n.getSourceLocation()) {
+ sl->highlight(ss);
+ return py::str(ss.str());
+ } else {
+ return py::none();
+ }
+ })
+ .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
+ .def("outputsSize", [](Node& n) { return n.outputs().size(); })
+ .NS(kind)
+ .def(
+ "inputs",
+ [](Node& n) {
+ return py::make_iterator(n.inputs().begin(), n.inputs().end());
+ })
+ .def(
+ "outputs",
+ [](Node& n) {
+ return py::make_iterator(n.outputs().begin(), n.outputs().end());
+ })
+ .def("output", [](Node& n) { return n.output(); })
+ .NS(addInput)
+ .NS(replaceInput)
+ .NS(replaceInputWith)
+ .NS(replaceAllUsesWith)
+ .NS(insertBefore)
+ .NS(insertAfter)
+ .NS(moveAfter)
+ .NS(moveBefore)
+ .NS(removeInput)
+ .NS(removeAllInputs)
+ .NS(destroy)
+ .NS(hasUses)
+ .NS(eraseOutput)
+ .NS(addOutput)
+ .NS(scopeName)
+ .NS(isNondeterministic)
+ .def(
+ "blocks",
+ [](Node& n) {
+ return py::make_iterator(n.blocks().begin(), n.blocks().end());
+ })
+ .NS(addBlock)
-#define AS(name) def(#name,&Attributes<Node> :: name)
- // methods from Attributes
- .AS(copyAttributes)
- .AS(hasAttributes)
+#define AS(name) def(#name, &Attributes<Node>::name)
+ // methods from Attributes
+ .AS(copyAttributes)
+ .AS(hasAttributes)
#undef AS
-#define AS(name) def(#name,&Attributes<Node> :: name ## S)
- // The default method names take Symbol, but the string conversion for
- // Symbol you to qualify with attr::. This is not very user friendly
- // for attributes, so expose the string variants instead.
- .AS(hasAttribute)
- .AS(kindOf)
- .AS(removeAttribute)
- .AS(attributeNames)
+#define AS(name) def(#name, &Attributes<Node>::name##S)
+      // The default method names take Symbol, but the string conversion for
+      // Symbol requires you to qualify with attr::. This is not very user
+      // friendly for attributes, so expose the string variants instead.
+ .AS(hasAttribute)
+ .AS(kindOf)
+ .AS(removeAttribute)
+ .AS(attributeNames)
#undef AS
-#define CREATE_ACCESSOR(Kind,method) \
- def(#method "_",[](Node & n, const char * name, Kind##Attr::ValueType v) { \
- return n . method ## _(Symbol::attr(name), std::move(v)); \
- }) \
- .def(#method, [](Node & n, const char * name) { \
- return n.method(Symbol::attr(name)); \
- })
- .CREATE_ACCESSOR(Float,f)
- .CREATE_ACCESSOR(Floats,fs)
- .CREATE_ACCESSOR(String,s)
- .CREATE_ACCESSOR(Strings,ss)
- .CREATE_ACCESSOR(Int,i)
- .CREATE_ACCESSOR(Ints,is)
- .CREATE_ACCESSOR(Graph,g)
- .CREATE_ACCESSOR(Graphs,gs)
+#define CREATE_ACCESSOR(Kind, method) \
+ def(#method "_", \
+ [](Node& n, const char* name, Kind##Attr::ValueType v) { \
+ return n.method##_(Symbol::attr(name), std::move(v)); \
+ }) \
+ .def(#method, [](Node& n, const char* name) { \
+ return n.method(Symbol::attr(name)); \
+ })
+ .CREATE_ACCESSOR(Float, f)
+ .CREATE_ACCESSOR(Floats, fs)
+ .CREATE_ACCESSOR(String, s)
+ .CREATE_ACCESSOR(Strings, ss)
+ .CREATE_ACCESSOR(Int, i)
+ .CREATE_ACCESSOR(Ints, is)
+ .CREATE_ACCESSOR(Graph, g)
+ .CREATE_ACCESSOR(Graphs, gs)
#undef CREATE_ACCESSOR
- // Tensor (t_) -- manually written to unwrap the variable into a tensor.
- .def("t_",[](Node & n, const char * name, torch::autograd::Variable v) {
- return n.t_(Symbol::attr(name), std::move(v.data()));
- })
- .def("t", [](Node & n, const char * name) {
- return torch::autograd::make_variable(n.t(Symbol::attr(name)), /*requires_grad=*/false);
- })
- // Tensors (ts_) -- manually written to unwrap variables into tensors.
- .def("ts_",[](Node & n, const char * name, std::vector<torch::autograd::Variable> vs) {
- std::vector<at::Tensor> tensors;
- tensors.reserve(vs.size());
- for (auto& variable : vs) {
- tensors.push_back(std::move(variable.data()));
- }
- return n.ts_(Symbol::attr(name), std::move(tensors));
- })
- .def("ts", [](Node & n, const char * name) {
- auto tensors = n.ts(Symbol::attr(name));
- std::vector<torch::autograd::Variable> variables;
- variables.reserve(tensors.size());
- for (auto& tensor : tensors) {
- variables.push_back(torch::autograd::make_variable(
- std::move(tensor), /*requires_grad=*/false));
- }
- return variables;
- })
- .def("z_",[](Node & n, const char * name, at::Tensor v) {
- return n.t_(Symbol::attr(name), autograd::Variable(v.view({})).data());
- })
- .def("z",[](Node & n, const char * name) {
- return n.t(Symbol::attr(name));
- })
- .def("zs_",[](Node & n, const char * name, TensorsAttr::ValueType v) {
- for (auto& i : v) {
- i = autograd::Variable(i.view({})).data();
+ // Tensor (t_) -- manually written to unwrap the variable into a tensor.
+ .def(
+ "t_",
+ [](Node& n, const char* name, torch::autograd::Variable v) {
+ return n.t_(Symbol::attr(name), std::move(v.data()));
+ })
+ .def(
+ "t",
+ [](Node& n, const char* name) {
+ return torch::autograd::make_variable(
+ n.t(Symbol::attr(name)), /*requires_grad=*/false);
+ })
+ // Tensors (ts_) -- manually written to unwrap variables into tensors.
+ .def(
+ "ts_",
+ [](Node& n,
+ const char* name,
+ std::vector<torch::autograd::Variable> vs) {
+ std::vector<at::Tensor> tensors;
+ tensors.reserve(vs.size());
+ for (auto& variable : vs) {
+ tensors.push_back(std::move(variable.data()));
+ }
+ return n.ts_(Symbol::attr(name), std::move(tensors));
+ })
+ .def(
+ "ts",
+ [](Node& n, const char* name) {
+ auto tensors = n.ts(Symbol::attr(name));
+ std::vector<torch::autograd::Variable> variables;
+ variables.reserve(tensors.size());
+ for (auto& tensor : tensors) {
+ variables.push_back(torch::autograd::make_variable(
+ std::move(tensor), /*requires_grad=*/false));
+ }
+ return variables;
+ })
+ .def(
+ "z_",
+ [](Node& n, const char* name, at::Tensor v) {
+ return n.t_(
+ Symbol::attr(name), autograd::Variable(v.view({})).data());
+ })
+ .def(
+ "z",
+ [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
+ .def(
+ "zs_",
+ [](Node& n, const char* name, TensorsAttr::ValueType v) {
+ for (auto& i : v) {
+ i = autograd::Variable(i.view({})).data();
+ }
+ return n.ts_(Symbol::attr(name), std::move(v));
+ })
+ .def(
+ "zs",
+ [](Node& n, const char* name) { return n.ts(Symbol::attr(name)); })
+ .def(
+ "pyobj",
+ [](Node& n) {
+ return py::handle(n.expect<PythonOp>()->pyobj.get())
+ .cast<py::object>();
+ })
+ .def("cconv", [](Node& n) { return n.expect<PythonOp>()->cconv; })
+ .def("pyname", [](Node& n) { return n.expect<PythonOp>()->name(); })
+ .def("scalar_args", [](Node& n) {
+ auto op = n.expect<PythonOp>();
+ auto scalars = py::list();
+ auto append = scalars.attr("append");
+ for (auto& arg : op->scalar_args) {
+ append(py::handle(arg.get()));
}
- return n.ts_(Symbol::attr(name), std::move(v));
- })
- .def("zs",[](Node & n, const char * name) {
- return n.ts(Symbol::attr(name));
- })
- .def("pyobj",[](Node & n) {
- return py::handle(n.expect<PythonOp>()->pyobj.get()).cast<py::object>();
- })
- .def("cconv",[](Node & n) {
- return n.expect<PythonOp>()->cconv;
- })
- .def("pyname",[](Node & n) {
- return n.expect<PythonOp>()->name();
- })
- .def("scalar_args",[](Node & n) {
- auto op = n.expect<PythonOp>();
- auto scalars = py::list();
- auto append = scalars.attr("append");
- for(auto & arg : op->scalar_args) {
- append(py::handle(arg.get()));
- }
- return scalars;
- })
- ;
+ return scalars;
+ });
using ::c10::Type;
- py::class_<Type,std::shared_ptr<Type>>(m,"Type")
- .def("__repr__",[](Type & t) {
- return t.python_str();
- })
- .def("str",[](Type & t) {
- std::ostringstream s;
- s << t;
- return s.str();
- })
- .def("kind",[](const Type& t) {
- return typeKindToString(t.kind());
- })
- .def("sizes",[](Type& t) {
- return t.expect<CompleteTensorType>()->sizes();
- })
- .def("strides",[](Type& t) {
- return t.expect<CompleteTensorType>()->strides();
- })
- .def("contiguous",[](Type& t) {
- return std::static_pointer_cast<Type>(t.expect<CompleteTensorType>()->contiguous());
- })
- .def("scalarType",[](Type& t) {
- return toString(t.expect<TensorType>()->scalarType());
- })
- .def("__eq__", [](std::shared_ptr<Type>& self, std::shared_ptr<Type>& other) {
- return *self == *other;
- })
- .def("isSubtypeOf", [](std::shared_ptr<Type>& self, std::shared_ptr<Type> other) {
- return self->isSubtypeOf(other);
- });
+ py::class_<Type, std::shared_ptr<Type>>(m, "Type")
+ .def("__repr__", [](Type& t) { return t.python_str(); })
+ .def(
+ "str",
+ [](Type& t) {
+ std::ostringstream s;
+ s << t;
+ return s.str();
+ })
+ .def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
+ .def(
+ "sizes",
+ [](Type& t) { return t.expect<CompleteTensorType>()->sizes(); })
+ .def(
+ "strides",
+ [](Type& t) { return t.expect<CompleteTensorType>()->strides(); })
+ .def(
+ "contiguous",
+ [](Type& t) {
+ return std::static_pointer_cast<Type>(
+ t.expect<CompleteTensorType>()->contiguous());
+ })
+ .def(
+ "scalarType",
+ [](Type& t) {
+ return toString(t.expect<TensorType>()->scalarType());
+ })
+ .def(
+ "__eq__",
+ [](std::shared_ptr<Type>& self, std::shared_ptr<Type>& other) {
+ return *self == *other;
+ })
+ .def(
+ "isSubtypeOf",
+ [](std::shared_ptr<Type>& self, std::shared_ptr<Type> other) {
+ return self->isSubtypeOf(other);
+ });
py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType")
- .def_static("get", &NumberType::get);
+ .def_static("get", &NumberType::get);
py::class_<IntType, Type, std::shared_ptr<IntType>>(m, "IntType")
- .def_static("get", &IntType::get);
+ .def_static("get", &IntType::get);
py::class_<FloatType, Type, std::shared_ptr<FloatType>>(m, "FloatType")
- .def_static("get", &FloatType::get);
+ .def_static("get", &FloatType::get);
py::class_<DynamicType, Type, std::shared_ptr<DynamicType>>(m, "DynamicType")
- .def_static("get", &DynamicType::get);
+ .def_static("get", &DynamicType::get);
py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType")
- .def_static("get", &BoolType::get);
+ .def_static("get", &BoolType::get);
py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType")
- .def(py::init([](std::vector<TypePtr> a){ return TupleType::create(a); }))
- .def("elements", [](TupleType &self){
- std::vector<TypePtr> types;
- for (const auto& type : self.elements()) {
- types.push_back(type);
- }
- return types;
- });
+ .def(
+ py::init([](std::vector<TypePtr> a) { return TupleType::create(a); }))
+ .def("elements", [](TupleType& self) {
+ std::vector<TypePtr> types;
+ for (const auto& type : self.elements()) {
+ types.push_back(type);
+ }
+ return types;
+ });
py::class_<ListType, Type, std::shared_ptr<ListType>>(m, "ListType")
- .def_static("ofInts", &ListType::ofInts)
- .def_static("ofTensors", &ListType::ofTensors)
- .def("getElementType", &ListType::getElementType);
+ .def_static("ofInts", &ListType::ofInts)
+ .def_static("ofTensors", &ListType::ofTensors)
+ .def("getElementType", &ListType::getElementType);
- py::class_<Use>(m,"Use")
- .def_readonly("user",&Use::user)
- .def_readonly("offset",&Use::offset);
+ py::class_<Use>(m, "Use")
+ .def_readonly("user", &Use::user)
+ .def_readonly("offset", &Use::offset);
}
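+// Rough Python-side sketch of the bindings above (torch._C module path and
+// symbol names assumed):
+//   g = torch._C.Graph()
+//   a, b = g.addInput(), g.addInput()
+//   n = g.create("aten::add", [a, b])
+//   g.appendNode(n)
+//   print(g.__repr__())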
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
void initPythonIRBindings(PyObject* module);
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/python_headers.h>
-#include <torch/csrc/jit/python_tracer.h>
-#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/jit/export.h>
-#include <torch/csrc/jit/pybind.h>
-#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/pybind.h>
+#include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/tracer.h>
+#include <torch/csrc/utils/python_strings.h>
#include <c10/util/Exception.h>
using namespace torch::jit;
using namespace torch::jit::tracer;
-
-namespace torch { namespace jit { namespace tracer {
-
+namespace torch {
+namespace jit {
+namespace tracer {
// Python interpreter retrieval routine adapted from
// https://stackoverflow.com/a/8706144
std::string getPythonInterpreterStackTrace() {
std::stringstream stack_trace;
AutoGIL gil;
- PyFrameObject *frame = PyEval_GetFrame();
+ PyFrameObject* frame = PyEval_GetFrame();
while (nullptr != frame) {
int line = PyCode_Addr2Line(frame->f_code, frame->f_lasti);
std::string filename = THPUtils_unpackString(frame->f_code->co_filename);
const c10::optional<size_t>& num_real_inputs) {
size_t num_func_inputs = num_real_inputs.value_or(trace_inputs.size());
auto enter_info = tracer::enter(std::move(trace_inputs));
- getTracingState()->lookup_var_name_fn = [var_name_lookup_fn](const Variable& var) -> std::string {
+ getTracingState()->lookup_var_name_fn =
+ [var_name_lookup_fn](const Variable& var) -> std::string {
AutoGIL ag;
return py::cast<std::string>(var_name_lookup_fn(var));
};
getTracingState()->force_outplace = force_outplace;
try {
-
py::tuple py_inputs(num_func_inputs);
- for(size_t i = 0; i < num_func_inputs; ++i) {
+ for (size_t i = 0; i < num_func_inputs; ++i) {
py_inputs[i] = py::cast(enter_info.second[i]);
}
auto out = func(*py_inputs);
if (out.ptr() == Py_None) {
- AT_ERROR("The traced function didn't return any values! Side-effects are not "
- "captured in traces, so it would be a no-op.");
+ AT_ERROR(
+ "The traced function didn't return any values! Side-effects are not "
+ "captured in traces, so it would be a no-op.");
}
tracer::exit({toIValue(out)});
auto graph = enter_info.first->graph;
}
}
-Node* preRecordPythonTrace(THPObjectPtr pyobj,
- const std::string& arg_types,
- at::ArrayRef<Variable> inputs,
- pyobj_list scalar_args) {
+Node* preRecordPythonTrace(
+ THPObjectPtr pyobj,
+ const std::string& arg_types,
+ at::ArrayRef<Variable> inputs,
+ pyobj_list scalar_args) {
THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply"));
- if(!apply) {
+ if (!apply) {
throw python_error();
}
- auto & graph = getTracingState()->graph;
+ auto& graph = getTracingState()->graph;
Node* n = graph->createPythonOp(
std::move(apply), arg_types, std::move(scalar_args));
recordSourceLocation(n);
- for (const Variable & input : inputs) {
+ for (const Variable& input : inputs) {
n->addInput(getValueTrace(input));
}
}
void pythonRecordSourceLocation(Node* n) {
- auto sl = std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
+ auto sl =
+ std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
n->setSourceLocation(sl);
}
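+// (each traced node thus carries the Python frame that created it, so later
+// graph errors can highlight user code rather than libtorch internals)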
setRecordSourceLocation(pythonRecordSourceLocation);
auto m = py::handle(module).cast<py::module>();
- py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
- // NB: no constructor; you have to get it from C++ code
- .def("__repr__", [](const TracingState& s) {
- std::ostringstream ss;
- ss << "<TracingState " << (const void*)&s << ">";
- return ss.str();
- })
- .def("__str__", [](const TracingState& s) -> std::string {
- std::ostringstream ss;
- ss << *s.graph;
- return ss.str();
- })
- .def("push_scope", [](TracingState& s, const std::string& scope_name) {
- s.graph->push_scope(scope_name);
- })
- .def("pop_scope", [](TracingState& s) {
- s.graph->pop_scope();
- })
- .def("set_graph", [](TracingState& s, std::shared_ptr<Graph> g) {
- s.graph = g;
- })
- .def("graph", [](TracingState& s) {
- return s.graph;
- });
-
- m.def("_tracer_warn_use_python", []() {
- tracer::setWarn(pythonWarn);
- });
+ py::class_<TracingState, std::shared_ptr<TracingState>>(
+ m, "TracingState", py::dynamic_attr())
+ // NB: no constructor; you have to get it from C++ code
+ .def(
+ "__repr__",
+ [](const TracingState& s) {
+ std::ostringstream ss;
+ ss << "<TracingState " << (const void*)&s << ">";
+ return ss.str();
+ })
+ .def(
+ "__str__",
+ [](const TracingState& s) -> std::string {
+ std::ostringstream ss;
+ ss << *s.graph;
+ return ss.str();
+ })
+ .def(
+ "push_scope",
+ [](TracingState& s, const std::string& scope_name) {
+ s.graph->push_scope(scope_name);
+ })
+ .def("pop_scope", [](TracingState& s) { s.graph->pop_scope(); })
+ .def(
+ "set_graph",
+ [](TracingState& s, std::shared_ptr<Graph> g) { s.graph = g; })
+ .def("graph", [](TracingState& s) { return s.graph; });
+
+ m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
m.def("_tracer_enter", [](py::args trace_inputs) {
return tracer::enter(toStack(trace_inputs));
});
m.def("_tracer_exit", [](py::tuple var_outputs) {
tracer::exit(toStack(var_outputs));
});
- m.def("_tracer_abandon", []() {
- tracer::abandon();
- });
- m.def("_get_tracing_state", []() {
- return getTracingState();
- });
+ m.def("_tracer_abandon", []() { tracer::abandon(); });
+ m.def("_get_tracing_state", []() { return getTracingState(); });
m.def("_set_tracing_state", [](std::shared_ptr<TracingState> state) {
return setTracingState(state);
});
m.def("_tracer_set_get_unique_name_fn", [](py::function func) {
const auto& tracing_state = getTracingState();
JIT_ASSERT(tracing_state);
- tracing_state->lookup_var_name_fn = [func](const Variable& var) -> std::string {
+ tracing_state->lookup_var_name_fn =
+ [func](const Variable& var) -> std::string {
AutoGIL ag;
return py::cast<std::string>(func(var));
};
});
}
-}}} // namespace torch::jit::tracing
+} // namespace tracer
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/python_headers.h>
#include <torch/csrc/jit/tracer.h>
+#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <memory>
#include <string>
-namespace torch { namespace jit { namespace tracer {
-void initPythonTracerBindings(PyObject *module);
-
+namespace torch {
+namespace jit {
+namespace tracer {
+void initPythonTracerBindings(PyObject* module);
std::string getPythonInterpreterStackTrace();
Node* preRecordPythonTrace(
bool force_outplace,
const c10::optional<size_t>& num_real_inputs = c10::nullopt);
} // namespace tracer
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <ATen/ExpandUtils.h>
// and if the dest is an int the source must be integral type
void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
if (autograd::as_variable_ref(t).requires_grad()) {
- throw std::runtime_error("Cannot input a tensor that requires grad as a scalar argument");
+ throw std::runtime_error(
+ "Cannot input a tensor that requires grad as a scalar argument");
}
if (t.sizes().size() != 0) {
- throw std::runtime_error("Cannot input a tensor of dimension other than 0 as a scalar argument");
+ throw std::runtime_error(
+ "Cannot input a tensor of dimension other than 0 as a scalar argument");
}
- if (toInt && !isIntegralType(autograd::as_variable_ref(t).data().type().scalarType())) {
+ if (toInt &&
+ !isIntegralType(
+ autograd::as_variable_ref(t).data().type().scalarType())) {
std::stringstream ss;
- ss << "Cannot input a tensor of type " << t.type().scalarType() << " as an integral argument";
+ ss << "Cannot input a tensor of type " << t.type().scalarType()
+ << " as an integral argument";
throw std::runtime_error(ss.str());
}
}
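+// For illustration: a 0-dim integer tensor that does not require grad passes
+// all three checks above; a requires-grad tensor, a tensor with dimension
+// other than 0, or (when toInt is true) a floating-point tensor throws.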
Operator(
"prim::Float(Tensor a) -> float",
[](const Node* node) -> Operation {
- return [](Stack& stack) {
- at::Tensor a;
- pop(stack, a);
- push(stack, a.item<double>());
- return 0;
- };
+ return [](Stack& stack) {
+ at::Tensor a;
+ pop(stack, a);
+ push(stack, a.item<double>());
+ return 0;
+ };
}),
Operator(
"prim::ImplicitTensorToNum(Tensor a) -> Scalar",
[](const Node* node) -> Operation {
- if(node->output()->type() == IntType::get()) {
+ if (node->output()->type() == IntType::get()) {
return [](Stack& stack) {
at::Tensor a;
pop(stack, a);
- checkImplicitTensorToNum(a, /*to int*/true);
+ checkImplicitTensorToNum(a, /*to int*/ true);
push(stack, a.item<int64_t>());
return 0;
};
return [](Stack& stack) {
at::Tensor a;
pop(stack, a);
- checkImplicitTensorToNum(a, /*to int*/false);
+ checkImplicitTensorToNum(a, /*to int*/ false);
push(stack, a.item<double>());
return 0;
};
return [](Stack& stack) {
bool b;
pop(stack, b);
- push(
- stack,
- autograd::make_variable(at::scalar_to_tensor(b)));
+ push(stack, autograd::make_variable(at::scalar_to_tensor(b)));
return 0;
};
}),
};
}),
Operator(
- prim::None,
- [](const Node* node) {
- return [](Stack& stack) {
- stack.emplace_back(IValue());
- return 0;
- };
- }),
+ prim::None,
+ [](const Node* node) {
+ return [](Stack& stack) {
+ stack.emplace_back(IValue());
+ return 0;
+ };
+ }),
Operator(
prim::Print,
[](const Node* node) {
std::vector<int64_t> size;
size.reserve(8);
for (size_t i = 0; i < num_inputs; ++i) {
- size = at::infer_size(size, peek(stack, i, num_inputs).toIntList()->elements());
+ size = at::infer_size(
+ size, peek(stack, i, num_inputs).toIntList()->elements());
}
drop(stack, num_inputs);
push(stack, std::move(size));
return [raw_dim, chunks](Stack& stack) {
Shared<IntList> sizes_l;
pop(stack, sizes_l);
- const auto & shape = sizes_l->elements();
+ const auto& shape = sizes_l->elements();
std::vector<int64_t> regular_shape = shape;
std::vector<int64_t> last_shape = shape;
int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
- AT_CHECK(dim < regular_shape.size(), "Dimension out of range for chunk");
+ AT_CHECK(
+ dim < regular_shape.size(), "Dimension out of range for chunk");
int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
regular_shape[dim] = split_size;
if (shape[dim] % chunks == 0) {
last_shape[dim] = split_size;
} else {
- int64_t num_splits = std::max<int64_t>((shape[dim] + split_size - 1) / split_size, 1);
- last_shape[dim] = split_size - (split_size * num_splits - shape[dim]);
+ int64_t num_splits = std::max<int64_t>(
+ (shape[dim] + split_size - 1) / split_size, 1);
+ last_shape[dim] =
+ split_size - (split_size * num_splits - shape[dim]);
JIT_ASSERT(last_shape[dim] >= 0);
}
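+            // For illustration, the arithmetic above with shape[dim] == 10
+            // and chunks == 4: split_size = (10 + 3) / 4 = 3, num_splits = 4,
+            // and last_shape[dim] = 3 - (3 * 4 - 10) = 1, i.e. pieces of
+            // sizes {3, 3, 3, 1}.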
push(stack, std::move(regular_shape));
};
}),
Operator(
- FunctionSchema("aten::warn", {Argument("message", StringType::get()), Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)}, {}),
+ FunctionSchema(
+ "aten::warn",
+ {Argument("message", StringType::get()),
+ Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)},
+ {}),
[](const Node* node) {
return [](Stack& stack) {
drop(stack, 1);
size_t num_elems = node->outputs().size();
return [=](Stack& stack) {
auto t = pop(stack).toTuple();
- const auto & elems = t->elements();
+ const auto& elems = t->elements();
if (elems.size() != num_elems) {
- AT_ERROR("Expected a tuple of ", num_elems, " elements, but got ", elems.size());
+ AT_ERROR(
+ "Expected a tuple of ",
+ num_elems,
+ " elements, but got ",
+ elems.size());
}
stack.insert(stack.end(), elems.begin(), elems.end());
return 0;
int64_t end_ind = node->i(attr::end);
return [=](Stack& stack) {
auto t = pop(stack).toTuple();
- const auto & elems = t->elements();
+ const auto& elems = t->elements();
std::vector<IValue> output_elems;
for (int64_t i = beg_ind; i < end_ind; ++i) {
output_elems.emplace_back(elems.at(i));
};
}),
Operator(
- prim::TupleIndex,
- [](const Node* node) {
- auto index = node->i(attr::index);
- return [=](Stack& stack) {
- auto tup = pop(stack).toTuple();
- const auto & elems = tup->elements();
- // index is normalized to be positive at compile time
- stack.emplace_back(elems.at(index));
- return 0;
- };
- }),
+ prim::TupleIndex,
+ [](const Node* node) {
+ auto index = node->i(attr::index);
+ return [=](Stack& stack) {
+ auto tup = pop(stack).toTuple();
+ const auto& elems = tup->elements();
+ // index is normalized to be positive at compile time
+ stack.emplace_back(elems.at(index));
+ return 0;
+ };
+ }),
Operator(
prim::TupleConstruct,
[](const Node* node) {
size_t num_inputs = node->inputs().size();
return [=](Stack& stack) {
- std::vector<IValue> elems {
- std::make_move_iterator(stack.end() - num_inputs),
- std::make_move_iterator(stack.end())
- };
+ std::vector<IValue> elems{
+ std::make_move_iterator(stack.end() - num_inputs),
+ std::make_move_iterator(stack.end())};
drop(stack, num_inputs);
push(stack, Tuple::create(std::move(elems)));
return 0;
[](const Node* node) {
int64_t chunks = node->i(attr::chunks);
int64_t dim = node->i(attr::dim);
- auto outputs_used = fmap(node->outputs(), [](const Value *v) { return v->uses().size() > 0; });
+ auto outputs_used = fmap(node->outputs(), [](const Value* v) {
+ return v->uses().size() > 0;
+ });
return [=](Stack& stack) {
autograd::profiler::RecordFunction record("chunk");
at::Tensor t;
pop(stack, t);
auto result = at::chunk(t, chunks, dim);
- stack.insert(stack.end(), std::make_move_iterator(result.begin()),
- std::make_move_iterator(result.end()));
+ stack.insert(
+ stack.end(),
+ std::make_move_iterator(result.begin()),
+ std::make_move_iterator(result.end()));
// NB: Chunk can sometimes return a smaller number of outputs.
int64_t num_results = result.size();
if (num_results != chunks) {
if (num_results > chunks) {
- JIT_ASSERTM(num_results == chunks,
- "Expected chunk to return ", chunks, " outputs, but got ", num_results);
+ JIT_ASSERTM(
+ num_results == chunks,
+ "Expected chunk to return ",
+ chunks,
+ " outputs, but got ",
+ num_results);
}
for (int64_t i = num_results; i < chunks; ++i) {
- AT_CHECK(!outputs_used[i],
- "Expected chunk to return at least ", chunks, " outputs, but got only ", num_results);
- // We know that the output is unused, so it's ok to push anything on the stack.
+ AT_CHECK(
+ !outputs_used[i],
+ "Expected chunk to return at least ",
+ chunks,
+ " outputs, but got only ",
+ num_results);
+ // We know that the output is unused, so it's ok to push
+ // anything on the stack.
stack.emplace_back();
}
}
if (lt->getElementType() == IntType::get()) {
return [=](Stack& stack) {
auto ilist = pop(stack);
- const auto & list = ilist.toIntList()->elements();
- AT_CHECK(list.size() == num_outputs,
- "Expected ", num_outputs, " elements in a list but found ", list.size());
+ const auto& list = ilist.toIntList()->elements();
+ AT_CHECK(
+ list.size() == num_outputs,
+ "Expected ",
+ num_outputs,
+ " elements in a list but found ",
+ list.size());
stack.insert(stack.end(), list.begin(), list.end());
return 0;
};
} else if (lt->getElementType() == FloatType::get()) {
return [=](Stack& stack) {
auto ilist = pop(stack);
- const auto & list = ilist.toDoubleList()->elements();
- AT_CHECK(list.size() == num_outputs,
- "Expected ", num_outputs, " elements in a list but found ", list.size());
+ const auto& list = ilist.toDoubleList()->elements();
+ AT_CHECK(
+ list.size() == num_outputs,
+ "Expected ",
+ num_outputs,
+ " elements in a list but found ",
+ list.size());
stack.insert(stack.end(), list.begin(), list.end());
return 0;
};
} else if (lt->getElementType() == DynamicType::get()) {
return [=](Stack& stack) {
auto ilist = pop(stack);
- const auto & list = ilist.toTensorList()->elements();
- AT_CHECK(list.size() == num_outputs,
- "Expected ", num_outputs, " elements in a list but found ", list.size());
+ const auto& list = ilist.toTensorList()->elements();
+ AT_CHECK(
+ list.size() == num_outputs,
+ "Expected ",
+ num_outputs,
+ " elements in a list but found ",
+ list.size());
stack.insert(stack.end(), list.begin(), list.end());
return 0;
};
[](const Node* node) -> Operation {
const auto num_inputs = node->inputs().size();
ListTypePtr lt = node->output()->type()->expect<ListType>();
- if(IntType::get() == lt->getElementType()) {
+ if (IntType::get() == lt->getElementType()) {
return [=](Stack& stack) {
auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
- std::vector<int64_t> vals = fmap(inputs, [](const IValue& v) {
- return v.toInt();
- });
+ std::vector<int64_t> vals =
+ fmap(inputs, [](const IValue& v) { return v.toInt(); });
drop(stack, num_inputs);
push(stack, std::move(vals));
return 0;
};
- } else if(FloatType::get() == lt->getElementType()) {
+ } else if (FloatType::get() == lt->getElementType()) {
return [=](Stack& stack) {
auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
- std::vector<double> vals = fmap(inputs, [](const IValue& v) {
- return v.toDouble();
- });
+ std::vector<double> vals =
+ fmap(inputs, [](const IValue& v) { return v.toDouble(); });
drop(stack, num_inputs);
push(stack, std::move(vals));
return 0;
};
}
}),
- Operator("aten::_unwrap_optional(t? optional) -> t",
- [](const Node* node) -> Operation {
- return [=](Stack& stack) {
- auto val = pop(stack);
- JIT_ASSERTM(!val.isNone(), "Unwrapping null optional");
- push(stack, val);
- return 0;
- };
- }),
+ Operator(
+ "aten::_unwrap_optional(t? optional) -> t",
+ [](const Node* node) -> Operation {
+ return [=](Stack& stack) {
+ auto val = pop(stack);
+ JIT_ASSERTM(!val.isNone(), "Unwrapping null optional");
+ push(stack, val);
+ return 0;
+ };
+ }),
Operator(
prim::fork,
[](const Node* node) {
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
Operator( \
#aten_op "(int a, int b) -> " #int_result, \
- [](const Node* node) { \
+ [](const Node* node) { \
return [=](Stack& stack) { \
int64_t a, b; \
pop(stack, a, b); \
return 0; \
}; \
}), \
- Operator( \
- #aten_op "(float a, float b) -> " #float_result, [](const Node* node) { \
- return [=](Stack& stack) { \
- double a, b; \
- pop(stack, a, b); \
- push(stack, float_op); \
- return 0; \
- }; \
- }),
+ Operator( \
+ #aten_op "(float a, float b) -> " #float_result, \
+ [](const Node* node) { \
+ return [=](Stack& stack) { \
+ double a, b; \
+ pop(stack, a, b); \
+ push(stack, float_op); \
+ return 0; \
+ }; \
+ })
#define DEFINE_INT_FLOAT_OP(aten_op, op, result) \
Operator( \
- #aten_op "(int a, float b) -> " #result, [](const Node* node) { \
+ #aten_op "(int a, float b) -> " #result, \
+ [](const Node* node) { \
return [=](Stack& stack) { \
int64_t a; \
double b; \
return 0; \
}; \
}), \
- Operator( \
- #aten_op "(float a, int b) -> " #result, [](const Node* node) { \
+ Operator(#aten_op "(float a, int b) -> " #result, [](const Node* node) { \
return [=](Stack& stack) { \
double a; \
int64_t b; \
push(stack, op); \
return 0; \
}; \
- }),
+ })
-
-#define DEFINE_INT_OP(aten_op, op) \
+#define DEFINE_INT_OP(aten_op, op) \
Operator(#aten_op "(int a, int b) -> int", [](const Node* node) { \
- return [=](Stack& stack) { \
- int64_t a, b; \
- pop(stack, a, b); \
- push(stack, op); /* NOLINT(hicpp-signed-bitwise) */ \
- return 0; \
- }; \
- }),
-
-#define DEFINE_BINARY_OP(aten_op, op) \
- DEFINE_GENERIC_OP(aten_op, op, op, int, float) \
- DEFINE_INT_FLOAT_OP(aten_op, op, float)
-#define DEFINE_COMPARISON_OP(aten_op, op) \
- DEFINE_GENERIC_OP(aten_op, op, op, bool, bool) \
- DEFINE_INT_FLOAT_OP(aten_op, op, bool)
-#define DEFINE_BOOL_OP(aten_op, op) \
+ return [=](Stack& stack) { \
+ int64_t a, b; \
+ pop(stack, a, b); \
+ push(stack, op); /* NOLINT(hicpp-signed-bitwise) */ \
+ return 0; \
+ }; \
+ })
+
+#define DEFINE_BINARY_OP(aten_op, op) \
+ DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
+ DEFINE_INT_FLOAT_OP(aten_op, op, float)
+#define DEFINE_COMPARISON_OP(aten_op, op) \
+ DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
+ DEFINE_INT_FLOAT_OP(aten_op, op, bool)
+#define DEFINE_BOOL_OP(aten_op, op) \
Operator(#aten_op "(bool a, bool b) -> bool", [](const Node* node) { \
- return [=](Stack& stack) { \
- bool a, b; \
- pop(stack, a, b); \
- push(stack, op); \
- return 0; \
- }; \
- }),
+ return [=](Stack& stack) { \
+ bool a, b; \
+ pop(stack, a, b); \
+ push(stack, op); \
+ return 0; \
+ }; \
+ })
// Convert a Python index (which may be negative) into an index usable for a
// C++ container
RegisterOperators reg2({
-#define DEFINE_STRING_OP(op_name, string_op, result) \
-Operator( \
- #op_name "(str a, str b) ->" #result, \
- [](const Node* node) { \
- return [=](Stack& stack) { \
- auto b = pop(stack).toStringRef(); \
- auto a = pop(stack).toStringRef(); \
- push(stack, string_op); \
- return 0; \
- }; \
- }),
-
- DEFINE_STRING_OP(aten::eq, a == b, bool)
- DEFINE_STRING_OP(aten::ne, a != b, bool)
- DEFINE_STRING_OP(aten::add, a + b, str)
+#define DEFINE_STRING_OP(op_name, string_op, result) \
+ Operator(#op_name "(str a, str b) ->" #result, [](const Node* node) { \
+ return [=](Stack& stack) { \
+ auto b = pop(stack).toStringRef(); \
+ auto a = pop(stack).toStringRef(); \
+ push(stack, string_op); \
+ return 0; \
+ }; \
+ })
+
+ DEFINE_STRING_OP(aten::eq, a == b, bool),
+ DEFINE_STRING_OP(aten::ne, a != b, bool),
+ DEFINE_STRING_OP(aten::add, a + b, str),
#undef DEFINE_STRING_OP
// tensor length op (size of 1st dimension)
Operator(
- "aten::len(Tensor t) -> int",
- [](Stack& stack) {
- at::Tensor t = pop(stack).toTensor();
- if (t.dim() == 0) {
- AT_ERROR("len() of a 0-d tensor");
- }
- push(stack, t.sizes()[0]);
- return 0;
- }
- ),
+ "aten::len(Tensor t) -> int",
+ [](Stack& stack) {
+ at::Tensor t = pop(stack).toTensor();
+ if (t.dim() == 0) {
+ AT_ERROR("len() of a 0-d tensor");
+ }
+ push(stack, t.sizes()[0]);
+ return 0;
+ }),
Operator(
"aten::append(Tensor[](a!) self, Tensor(c) el) -> Tensor[](a!)",
listAppend<Shared<TensorList>, at::Tensor>),
- Operator("aten::select(Tensor[](a) list, int idx) -> Tensor(*)", listSelect<Shared<TensorList>>),
- Operator("aten::_set_item(Tensor[](a!) l, int idx, Tensor el) -> Tensor[](a!)", listSetItem<Shared<TensorList>, at::Tensor>),
-
- // Mutable ops for lists containing immutable types.
-#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \
- Operator("aten::select(" decl_type "[] a, int b) -> " decl_type, listSelect<Shared<c_type>>), \
- Operator( \
- "aten::append(" decl_type "[](a!) self, " decl_type " el) -> " decl_type "[](a!)", \
- listAppend<Shared<c_type>, c_type::ElemType>), \
- Operator("aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type " el) -> " decl_type"[](a!)", listSetItem<Shared<c_type>, c_type::ElemType>), \
-
- CREATE_IMMUTABLE_LIST_OPS("int", IntList)
- CREATE_IMMUTABLE_LIST_OPS("float", DoubleList)
- CREATE_IMMUTABLE_LIST_OPS("t", GenericList)
-
-#define CREATE_LIST_OPS(decl_type, c_type) \
- Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
- Operator("aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type "[]", listAdd<Shared<c_type>, c_type::ElemType>), \
- Operator( \
- "aten::slice(" decl_type "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type "[]", \
- listSlice<Shared<c_type>, c_type::ElemType>), \
-
-
- CREATE_LIST_OPS("int", IntList)
- CREATE_LIST_OPS("float", DoubleList)
- CREATE_LIST_OPS("Tensor", TensorList)
- CREATE_LIST_OPS("t", GenericList)
+ Operator(
+ "aten::select(Tensor[](a) list, int idx) -> Tensor(*)",
+ listSelect<Shared<TensorList>>),
+ Operator(
+ "aten::_set_item(Tensor[](a!) l, int idx, Tensor el) -> Tensor[](a!)",
+ listSetItem<Shared<TensorList>, at::Tensor>),
+
+// Mutable ops for lists containing immutable types.
+#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \
+ Operator( \
+ "aten::select(" decl_type "[] a, int b) -> " decl_type, \
+ listSelect<Shared<c_type>>), \
+ Operator( \
+ "aten::append(" decl_type "[](a!) self, " decl_type \
+ " el) -> " decl_type "[](a!)", \
+ listAppend<Shared<c_type>, c_type::ElemType>), \
+ Operator( \
+ "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
+ " el) -> " decl_type "[](a!)", \
+ listSetItem<Shared<c_type>, c_type::ElemType>)
+
+ CREATE_IMMUTABLE_LIST_OPS("int", IntList),
+ CREATE_IMMUTABLE_LIST_OPS("float", DoubleList),
+ CREATE_IMMUTABLE_LIST_OPS("t", GenericList),
+
+#define CREATE_LIST_OPS(decl_type, c_type) \
+ Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
+ Operator( \
+ "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \
+ "[]", \
+ listAdd<Shared<c_type>, c_type::ElemType>), \
+ Operator( \
+ "aten::slice(" decl_type \
+ "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
+ "[]", \
+ listSlice<Shared<c_type>, c_type::ElemType>)
+
+ CREATE_LIST_OPS("int", IntList),
+ CREATE_LIST_OPS("float", DoubleList),
+ CREATE_LIST_OPS("Tensor", TensorList),
+ CREATE_LIST_OPS("t", GenericList),
#undef CREATE_LIST_OPS
-
Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>),
- Operator("aten::eq(float[] a, float[] b) -> bool", listEq<Shared<DoubleList>>),
- Operator("aten::eq(Tensor[] a, Tensor[] b) -> bool", listEq<Shared<TensorList>>),
+ Operator(
+ "aten::eq(float[] a, float[] b) -> bool",
+ listEq<Shared<DoubleList>>),
+ Operator(
+ "aten::eq(Tensor[] a, Tensor[] b) -> bool",
+ listEq<Shared<TensorList>>),
Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>),
- Operator("aten::ne(float[] a, float[] b) -> bool", listNe<Shared<DoubleList>>),
- Operator("aten::ne(Tensor[] a, Tensor[] b) -> bool", listNe<Shared<TensorList>>),
-
-
-#define CREATE_COPY_OP(other_type, c_type) \
- Operator( \
- "aten::copy_(Tensor(a!) self, " #other_type \
- " other) -> Tensor(a!)", \
- [](const Node* node) { \
- return [=](Stack& stack) { \
- at::Tensor t; \
- c_type other; \
- pop(stack, t, other); \
- std::move(t) = other; /* NOLINT(bugprone-use-after-move) */ \
+ Operator(
+ "aten::ne(float[] a, float[] b) -> bool",
+ listNe<Shared<DoubleList>>),
+ Operator(
+ "aten::ne(Tensor[] a, Tensor[] b) -> bool",
+ listNe<Shared<TensorList>>),
+
+#define CREATE_COPY_OP(other_type, c_type) \
+ Operator( \
+ "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
+ [](const Node* node) { \
+ return [=](Stack& stack) { \
+ at::Tensor t; \
+ c_type other; \
+ pop(stack, t, other); \
+ std::move(t) = other; /* NOLINT(bugprone-use-after-move) */ \
push(stack, std::move(t)); /* NOLINT(bugprone-use-after-move) */ \
- return 0; \
- }; \
- }),
+ return 0; \
+ }; \
+ })
- CREATE_COPY_OP(Tensor, at::Tensor)
- CREATE_COPY_OP(int, int64_t)
- CREATE_COPY_OP(float, double)
+ CREATE_COPY_OP(Tensor, at::Tensor),
+ CREATE_COPY_OP(int, int64_t),
+ CREATE_COPY_OP(float, double),
#undef CREATE_COPY_OP
- DEFINE_BINARY_OP(aten::add, a + b)
- DEFINE_BINARY_OP(aten::sub, a - b)
- DEFINE_BINARY_OP(aten::mul, a * b)
- DEFINE_BINARY_OP(aten::pow, static_cast<decltype(a)>(pow(a, b)))
-
- // Pass in two ops for handling int and float separately as % in C++ only works for int
- // The modulus calculation is different between C++ and Python (on negative), we preserve
- // the python behavior as it's more common and match python syntax, hence the conversion.
- DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), int, float)
- DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float)
-
+ DEFINE_BINARY_OP(aten::add, a + b),
+ DEFINE_BINARY_OP(aten::sub, a - b),
+    DEFINE_BINARY_OP(aten::mul, a * b),
+ DEFINE_BINARY_OP(aten::pow, static_cast<decltype(a)>(pow(a, b))),
+
+    // Pass in two ops for handling int and float separately, as % in C++ only
+    // works for int. The modulus calculation differs between C++ and Python
+    // (for negative operands); we preserve the Python behavior since it is
+    // more common and matches Python syntax, hence the conversion.
+ DEFINE_GENERIC_OP(
+ aten::remainder,
+ (b + (a % b)) % b,
+ fmod((b + fmod(a, b)), b),
+ int,
+ float),
+ DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float),
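+    // For illustration: with a = -7, b = 3, C++'s % gives -7 % 3 == -1 (sign
+    // of the dividend) while Python gives -7 % 3 == 2 (sign of the divisor);
+    // the expression above recovers the Python result:
+    // (3 + (-7 % 3)) % 3 == 2, and fmod(3 + fmod(-7, 3), 3) == 2.0.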
// in c++ int division rounds to the integer closer to 0, in python floordiv
// rounds to lower integer
- DEFINE_GENERIC_OP(aten::floordiv,
- static_cast<int64_t>(std::floor(static_cast<double>(a) / static_cast<double>(b))),
- std::floor(a / b), int, float)
- DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float)
-
- //only used in loop unrolling, not exposed to end users
- DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b)
-
- DEFINE_INT_OP(aten::__and__, a & b)
- DEFINE_INT_OP(aten::__or__, a | b)
- DEFINE_INT_OP(aten::__xor__, a ^ b)
+ DEFINE_GENERIC_OP(
+ aten::floordiv,
+ static_cast<int64_t>(
+ std::floor(static_cast<double>(a) / static_cast<double>(b))),
+ std::floor(a / b),
+ int,
+ float),
+ DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float),
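+    // For illustration: C++ gives -7 / 2 == -3 (rounds toward zero) while
+    // Python's floordiv gives -7 // 2 == -4; std::floor(-7.0 / 2.0) == -4.0
+    // matches the Python result.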
+
+ // only used in loop unrolling, not exposed to end users
+ DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b),
+
+    DEFINE_INT_OP(aten::__and__, a & b),
+ DEFINE_INT_OP(aten::__or__, a | b),
+ DEFINE_INT_OP(aten::__xor__, a ^ b),
// NB: This is the python truediv operation
- Operator("aten::div(int a, int b) -> float",
+ Operator(
+ "aten::div(int a, int b) -> float",
[](const Node* node) {
return [=](Stack& stack) {
int64_t a, b;
return 0;
};
}),
- Operator("aten::div(float a, float b) -> float",
+ Operator(
+ "aten::div(float a, float b) -> float",
[](const Node* node) {
return [=](Stack& stack) {
double a, b;
};
}),
- Operator("aten::floor(float a) -> int",
+ Operator(
+ "aten::floor(float a) -> int",
[](const Node* node) {
return [=](Stack& stack) {
double a;
};
}),
- DEFINE_COMPARISON_OP(aten::ne, a != b)
- DEFINE_COMPARISON_OP(aten::eq, a == b)
- DEFINE_COMPARISON_OP(aten::lt, a < b)
- DEFINE_COMPARISON_OP(aten::gt, a > b)
- DEFINE_COMPARISON_OP(aten::le, a <= b)
- DEFINE_COMPARISON_OP(aten::ge, a >= b)
+ DEFINE_COMPARISON_OP(aten::ne, a != b),
+ DEFINE_COMPARISON_OP(aten::eq, a == b),
+ DEFINE_COMPARISON_OP(aten::lt, a < b),
+ DEFINE_COMPARISON_OP(aten::gt, a > b),
+ DEFINE_COMPARISON_OP(aten::le, a <= b),
+ DEFINE_COMPARISON_OP(aten::ge, a >= b),
- DEFINE_BOOL_OP(aten::__and__, a && b)
- DEFINE_BOOL_OP(aten::__or__, a || b)
- DEFINE_BOOL_OP(aten::__xor__, a != b)
+    DEFINE_BOOL_OP(aten::__and__, a && b),
+ DEFINE_BOOL_OP(aten::__or__, a || b),
+ DEFINE_BOOL_OP(aten::__xor__, a != b),
Operator(
"aten::neg(int self) -> int",
pop(stack, t);
std::vector<int64_t> elems;
elems.reserve(t.size(0));
- for(int i = 0; i < t.size(0); i++){
+ for (int i = 0; i < t.size(0); i++) {
elems.push_back(*t[i].data<int32_t>());
}
push(stack, jit::IntList::create(elems));
pop(stack, l);
auto t = torch::empty(
{static_cast<int64_t>(l.size())}, at::dtype(at::kInt));
- for(size_t i = 0; i < l.size(); i++){
+ for (size_t i = 0; i < l.size(); i++) {
t[i] = l[i];
}
push(stack, t);
}),
});
-
// checking one of size & scale_factor is set
// if scale_factor is a double list check that it's len == dim
// reference: _check_size_scale_factor in torch/nn/functional.py
-void _check_size_factor(size_t dim, const IValue& size, const IValue& scale_factor) {
+void _check_size_factor(
+ size_t dim,
+ const IValue& size,
+ const IValue& scale_factor) {
if (size.isNone() && scale_factor.isNone()) {
throw std::runtime_error("either size or scale_factor should be defined");
}
if (!size.isNone() && !scale_factor.isNone()) {
- throw std::runtime_error("only one of size or scale_factor should be defined");
+ throw std::runtime_error(
+ "only one of size or scale_factor should be defined");
}
if (scale_factor.isDoubleList()) {
auto scale_len = scale_factor.toDoubleListRef().size();
if (scale_len != dim) {
std::stringstream str;
str << "scale_factor shape must match input shape. Input is " << dim
- << "D, scale_factor size is " << scale_len;
- throw std::runtime_error("only one of size or scale_factor should be defined");
+ << "D, scale_factor size is " << scale_len;
+      throw std::runtime_error(str.str());
}
}
}
// reference: _output_size in torch/nn/functional.py
// size can be none, int or intlist
// scale_factors can be none, float, or floatlist
-std::vector<int64_t> _output_size(const at::Tensor& input, size_t dim, const IValue& size, const IValue& scale_factors) {
+std::vector<int64_t> _output_size(
+ const at::Tensor& input,
+ size_t dim,
+ const IValue& size,
+ const IValue& scale_factors) {
if (!size.isNone()) {
if (size.isInt()) {
std::vector<int64_t> repeated(dim, size.toInt());
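+      // For illustration: a scalar size is broadcast over the spatial dims,
+      // e.g. dim == 2 with size == 7 yields {7, 7}.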
c10::optional<bool> align_corners) {
if ((mode == "nearest" || mode == "area")) {
if (align_corners != c10::nullopt) {
- throw std::runtime_error("align_corners option can only be set with the "
- "interpolating modes: linear | bilinear | bicubic | trilinear");
+ throw std::runtime_error(
+ "align_corners option can only be set with the "
+ "interpolating modes: linear | bilinear | bicubic | trilinear");
}
} else {
if (align_corners == c10::nullopt) {
- AT_WARN("Default upsampling behavior when mode=", mode, " is changed "
- "to align_corners=False since 0.4.0. Please specify align_corners=True "
- "if the old behavior is desired. See the documentation of nn.Upsample for details");
+ AT_WARN(
+ "Default upsampling behavior when mode=",
+ mode,
+ " is changed "
+ "to align_corners=False since 0.4.0. Please specify align_corners=True "
+ "if the old behavior is desired. See the documentation of nn.Upsample for details");
align_corners = false;
}
}
auto input_dim = input.dim();
if (input_dim == 3 && mode == "nearest")
- return at::upsample_nearest1d(input, _output_size(input, 1, size, scale_factors));
+ return at::upsample_nearest1d(
+ input, _output_size(input, 1, size, scale_factors));
if (input_dim == 4 && mode == "nearest")
- return at::upsample_nearest2d(input, _output_size(input, 2, size, scale_factors));
+ return at::upsample_nearest2d(
+ input, _output_size(input, 2, size, scale_factors));
if (input_dim == 5 && mode == "nearest")
- return at::upsample_nearest3d(input, _output_size(input, 3, size, scale_factors));
+ return at::upsample_nearest3d(
+ input, _output_size(input, 3, size, scale_factors));
if (input_dim == 3 && mode == "area")
- return at::adaptive_avg_pool1d(input, _output_size(input, 1, size, scale_factors));
+ return at::adaptive_avg_pool1d(
+ input, _output_size(input, 1, size, scale_factors));
if (input_dim == 4 && mode == "area")
- return at::adaptive_avg_pool2d(input, _output_size(input, 2, size, scale_factors));
+ return at::adaptive_avg_pool2d(
+ input, _output_size(input, 2, size, scale_factors));
if (input_dim == 5 && mode == "area")
- return at::adaptive_avg_pool3d(input, _output_size(input, 3, size, scale_factors));
+ return at::adaptive_avg_pool3d(
+ input, _output_size(input, 3, size, scale_factors));
if (input_dim == 3 && mode == "linear")
- return at::upsample_linear1d(input, _output_size(input, 1, size, scale_factors), *align_corners);
+ return at::upsample_linear1d(
+ input, _output_size(input, 1, size, scale_factors), *align_corners);
if (input_dim == 3 && mode == "bilinear")
throw std::runtime_error("Got 3D input, but bilinear mode needs 4D input");
if (input_dim == 3 && mode == "bicubic")
if (input_dim == 4 && mode == "linear")
throw std::runtime_error("Got 4D input, but linear mode needs 3D input");
if (input_dim == 4 && mode == "bilinear")
- return at::upsample_bilinear2d(input, _output_size(input, 2, size, scale_factors), *align_corners);
+ return at::upsample_bilinear2d(
+ input, _output_size(input, 2, size, scale_factors), *align_corners);
if (input_dim == 4 && mode == "bicubic")
- return at::upsample_bicubic2d(input, _output_size(input, 2, size, scale_factors), *align_corners);
+ return at::upsample_bicubic2d(
+ input, _output_size(input, 2, size, scale_factors), *align_corners);
if (input_dim == 4 && mode == "trilinear")
throw std::runtime_error("Got 4D input, but trilinear mode needs 5D input");
if (input_dim == 5 && mode == "linear")
if (input_dim == 5 && mode == "bicubic")
throw std::runtime_error("Got 5D input, but bicubic mode needs 4D input");
if (input_dim == 5 && mode == "trilinear")
- return at::upsample_trilinear3d(input, _output_size(input, 3, size, scale_factors), *align_corners);
-
- AT_ERROR("Input Error: Only 3D, 4D and 5D input Tensors supported",
- " (got ", input_dim, "D) for the modes: nearest | linear | bilinear | trilinear",
- " (got ", mode, ") ");
+ return at::upsample_trilinear3d(
+ input, _output_size(input, 3, size, scale_factors), *align_corners);
+
+ AT_ERROR(
+ "Input Error: Only 3D, 4D and 5D input Tensors supported",
+ " (got ",
+ input_dim,
+ "D) for the modes: nearest | linear | bilinear | trilinear",
+ " (got ",
+ mode,
+ ") ");
}
Operation interpolate_op(const Node* n) {
std::string mode;
IValue align_corners;
pop(stack, input, size, scale_factors, mode, align_corners);
- at::Tensor res = interpolate(input, size, scale_factors, mode, align_corners.toOptional<bool>());
+ at::Tensor res = interpolate(
+ input, size, scale_factors, mode, align_corners.toOptional<bool>());
push(stack, res);
return 0;
};
return IValue();
} else {
std::stringstream ss;
- ss << "Expecting optional int or int list arg for scale factor, got" << int_ivalue;
+ ss << "Expecting optional int or int list arg for scale factor, got"
+ << int_ivalue;
throw std::runtime_error(ss.str());
}
return scale_factor_double;
IValue size;
IValue scale_factor_int;
pop(stack, input, size, scale_factor_int);
- IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
- at::Tensor res = interpolate(input, size, scale_factor_double, "nearest", c10::nullopt);
+ IValue scale_factor_double =
+ convert_scale_factor_to_double(scale_factor_int);
+ at::Tensor res =
+ interpolate(input, size, scale_factor_double, "nearest", c10::nullopt);
push(stack, res);
return 0;
};
std::string mode;
IValue align_corners;
pop(stack, input, size, scale_factor_int, mode, align_corners);
- IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
- at::Tensor res = interpolate(input, size, scale_factor_double, mode, align_corners.toOptional<bool>());
+ IValue scale_factor_double =
+ convert_scale_factor_to_double(scale_factor_int);
+ at::Tensor res = interpolate(
+ input,
+ size,
+ scale_factor_double,
+ mode,
+ align_corners.toOptional<bool>());
push(stack, res);
return 0;
};
IValue size;
IValue scale_factor_int;
pop(stack, input, size, scale_factor_int);
- IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
- at::Tensor res = interpolate(input, size, scale_factor_double, "bilinear", true);
+ IValue scale_factor_double =
+ convert_scale_factor_to_double(scale_factor_int);
+ at::Tensor res =
+ interpolate(input, size, scale_factor_double, "bilinear", true);
push(stack, res);
return 0;
};
}
-
RegisterOperators reg3({
- Operator(
- "aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
- interpolate_op),
- Operator(
- "aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
- interpolate_op),
- Operator(
- "aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
- interpolate_op),
- Operator(
- "aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
- interpolate_op),
-
- Operator(
- "aten::__upsample_nearest(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
- upsample_nearest_op),
- Operator(
- "aten::__upsample_nearest(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
- upsample_nearest_op),
-
- Operator(
- "aten::__upsample(Tensor input, int? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
- upsample_op),
- Operator(
- "aten::__upsample(Tensor input, int[]? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
- upsample_op),
-
-
- Operator(
- "aten::__upsample_bilinear(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
- upsample_bilinear_op),
- Operator(
- "aten::__upsample_bilinear(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
- upsample_bilinear_op),
- Operator(
- "aten::__upsample_bilinear(Tensor input, int? size = None, int[]? scale_factor = None) -> Tensor",
- upsample_bilinear_op),
- Operator(
- "aten::__upsample_bilinear(Tensor input, int[]? size = None, int[]? scale_factor = None) -> Tensor",
- upsample_bilinear_op),
+ Operator(
+ "aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+ interpolate_op),
+ Operator(
+ "aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+ interpolate_op),
+ Operator(
+ "aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+ interpolate_op),
+ Operator(
+ "aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+ interpolate_op),
-});
+ Operator(
+ "aten::__upsample_nearest(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
+ upsample_nearest_op),
+ Operator(
+ "aten::__upsample_nearest(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
+ upsample_nearest_op),
+
+ Operator(
+ "aten::__upsample(Tensor input, int? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+ upsample_op),
+ Operator(
+ "aten::__upsample(Tensor input, int[]? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+ upsample_op),
+
+ Operator(
+ "aten::__upsample_bilinear(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
+ upsample_bilinear_op),
+ Operator(
+ "aten::__upsample_bilinear(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
+ upsample_bilinear_op),
+ Operator(
+ "aten::__upsample_bilinear(Tensor input, int? size = None, int[]? scale_factor = None) -> Tensor",
+ upsample_bilinear_op),
+ Operator(
+ "aten::__upsample_bilinear(Tensor input, int[]? size = None, int[]? scale_factor = None) -> Tensor",
+ upsample_bilinear_op),
+});
at::Tensor leaky_relu(const at::Tensor& tensor, double scalar) {
return at::leaky_relu(tensor, scalar);
static auto reg4 =
torch::jit::RegisterOperators()
- .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor", &leaky_relu)
+ .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor",
+ &leaky_relu)
.op("_test::cat(Tensor[] inputs) -> Tensor", &cat);
-}}} // torch::jit::anon
+} // namespace
+} // namespace jit
+} // namespace torch
+#include <ATen/ExpandUtils.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/operator.h>
+
#include <torch/csrc/api/include/torch/utils.h>
-#include <ATen/ExpandUtils.h>
-#include <sstream>
#include <regex>
+#include <sstream>
namespace torch {
namespace jit {
}),
Operator(
"aten::Size(int[] sizes) -> int[]",
- [](Stack& stack) {
- return 0;
- }),
+ [](Stack& stack) { return 0; }),
Operator(
"aten::size(Tensor self) -> int[]",
[](Stack& stack) {
auto args = last(stack, num_inputs - 1);
std::stringstream ss;
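+        // For illustration: the loop below substitutes stack arguments for
+        // each "{}" in order, so format "x={}, y={}" with args (1, 2)
+        // produces "x=1, y=2"; running out of arguments hits the AT_ERROR
+        // below.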
- for(size_t begin = 0, used_args = 0; true; ++used_args) {
+ for (size_t begin = 0, used_args = 0; true; ++used_args) {
size_t loc = format.find("{}", begin);
- if(loc == std::string::npos) {
+ if (loc == std::string::npos) {
ss << format.substr(begin);
break;
}
ss << format.substr(begin, loc - begin);
- if(used_args >= args.size()) {
+ if (used_args >= args.size()) {
AT_ERROR("Too few arguments for format string: ", format);
}
ss << args[used_args];
};
}),
Operator(
- "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
- [](const Node* node) {
- return [](Stack& stack) {
- at::Tensor weight;
- at::Tensor input;
- double max_norm;
- double norm_type;
- pop(stack, weight, input, max_norm, norm_type);
+ "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
+ [](const Node* node) {
+ return [](Stack& stack) {
+ at::Tensor weight;
+ at::Tensor input;
+ double max_norm;
+ double norm_type;
+ pop(stack, weight, input, max_norm, norm_type);
- // TODO: remove when script supports setting grad mode
- torch::NoGradGuard no_grad;
+ // TODO: remove when script supports setting grad mode
+ torch::NoGradGuard no_grad;
- at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
- push(stack, result);
+ at::Tensor result =
+ at::embedding_renorm_(weight, input, max_norm, norm_type);
+ push(stack, result);
- return 0;
- };
- }),
+ return 0;
+ };
+ }),
Operator(
- "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
- [](const Node* node) {
- return [](Stack& stack) {
- // Everything is a list at the point this is used, so don't do anything
- drop(stack, 3);
- return 0;
- };
- }),
+ "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
+ [](const Node* node) {
+ return [](Stack& stack) {
+ // Everything is a list at the point this is used, so don't do
+ // anything
+ drop(stack, 3);
+ return 0;
+ };
+ }),
});
}
#pragma once
#include <functional>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
class ResourceGuard {
std::function<void()> _destructor;
bool _released;
-public:
+ public:
ResourceGuard(std::function<void()> destructor)
- : _destructor(std::move(destructor))
- , _released(false) {}
+ : _destructor(std::move(destructor)), _released(false) {}
~ResourceGuard() {
- if (!_released) _destructor();
+ if (!_released)
+ _destructor();
}
void release() {
}
};
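+// Illustrative usage sketch (assuming release(), whose body is elided in this
+// hunk, sets _released so the destructor skips the cleanup):
+//
+//   ResourceGuard guard([] { /* e.g. free a temporary buffer */ });
+//   ...                      // work that may throw
+//   guard.release();         // success: don't run the cleanup on scope exit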
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir.h>
-
-#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/operator.h>
+#include <algorithm>
#include <iostream>
-#include <unordered_map>
-#include <unordered_set>
#include <set>
-#include <stack>
#include <sstream>
-#include <algorithm>
+#include <stack>
#include <string>
+#include <unordered_map>
+#include <unordered_set>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
ScopePtr Scope::push(Symbol name) {
return c10::make_intrusive<Scope>(intrusive_from_this(), name);
return out;
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <c10/macros/Macros.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/interned_strings.h>
#include <memory>
using ScopePtr = c10::intrusive_ptr<Scope>;
struct TORCH_API Scope : public c10::intrusive_ptr_target {
-private:
+ private:
ScopePtr parent_;
Symbol name_;
ScopePtr intrusive_from_this() {
// to account for this ownership
return c10::intrusive_ptr<Scope>::reclaim(this);
}
-public:
+
+ public:
Scope() {
name_ = Symbol::scope("");
}
return name_;
}
- std::string namesFromRoot(const std::string& separator="/") const;
+ std::string namesFromRoot(const std::string& separator = "/") const;
};
} // namespace jit
-#include <torch/csrc/jit/script/builtin_functions.h>
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/code_template.h>
+#include <torch/csrc/jit/script/builtin_functions.h>
-namespace torch { namespace jit { namespace script {
+namespace torch {
+namespace jit {
+namespace script {
auto scalar_operators_source = CodeTemplate(
-R"SCRIPT(
+ R"SCRIPT(
def mul(a : ${Scalar}, b : Tensor) -> Tensor:
return b * a
def add(a : ${Scalar}, b : Tensor) -> Tensor:
)SCRIPT");
auto _ntuple_ops = CodeTemplate(
-R"SCRIPT(
+ R"SCRIPT(
def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
return x
)SCRIPT");
struct BuiltinFunctionRegistry {
-
const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
const static std::vector<Method*> empty;
// when initializing the builtin function library, we will re-enter
// getAllBuiltinFunctionsFor since it is called in the compiler to
- // lookup builtins and initializing the builtin functions calls the compiler.
- // To avoid deadlocking, we use a recursive mutex (same thread can re-lock,
- // the mutex without waiting), and report no loaded builtins during init.
+    // look up builtins, and initializing the builtin functions calls the
+    // compiler. To avoid deadlocking, we use a recursive mutex (the same
+    // thread can re-lock the mutex without waiting), and report no loaded
+    // builtins during init.
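+    // For illustration of the re-entrancy: loadBuiltinFunctions() compiles
+    // the builtin sources, the compiler resolves names by calling
+    // getAllBuiltinFunctionsFor() again on the same thread, the recursive
+    // mutex re-locks without blocking, state == INTIIALIZING holds, and the
+    // empty list is returned.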
std::lock_guard<std::recursive_mutex> guard(mutex);
- if(state == INTIIALIZING) {
+ if (state == INTIIALIZING) {
return empty;
} else if (state == UNINITIALIZED) {
state = INTIIALIZING;
}
JIT_ASSERT(state == INITIALIZED);
auto it = builtins_by_name.find(name);
- if(it == builtins_by_name.end())
+ if (it == builtins_by_name.end())
return empty;
return it->second;
}
-private:
+
+ private:
void loadSource(const std::string& source) {
auto module = std::make_shared<script::Module>();
defineMethodsInModule(
module, source, script::nativeResolver, /*self=*/nullptr);
modules.push_back(module);
for (auto& method : module->get_methods()) {
- builtins_by_name[Symbol::fromQualString("aten::" + method.key())].push_back(
- method->get());
+ builtins_by_name[Symbol::fromQualString("aten::" + method.key())]
+ .push_back(method->get());
}
}
void loadBuiltinFunctions() {
- for(auto scalar : {"float", "int"}) {
+ for (auto scalar : {"float", "int"}) {
TemplateEnv env;
env.s("Scalar", scalar);
loadSource(scalar_operators_source.format(env));
using str_pair = std::pair<std::string, std::string>;
const std::vector<str_pair> name_len = {
- str_pair("single", "1"),
- str_pair("pair", "2"),
- str_pair("triple", "3"),
- str_pair("quadruple", "4"),
+ str_pair("single", "1"),
+ str_pair("pair", "2"),
+ str_pair("triple", "3"),
+ str_pair("quadruple", "4"),
};
- for(auto scalar: {"float", "int"}) {
- for (auto pair: name_len) {
+ for (auto scalar : {"float", "int"}) {
+ for (auto pair : name_len) {
TemplateEnv env;
env.s("Scalar", scalar);
env.s("name", pair.first);
}
}
}
- enum {UNINITIALIZED, INTIIALIZING, INITIALIZED} state = UNINITIALIZED;
+ enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
std::recursive_mutex mutex;
std::vector<std::shared_ptr<Module>> modules;
std::unordered_map<Symbol, std::vector<Method*>> builtins_by_name;
return registry.getAllBuiltinFunctionsFor(name);
}
-}}}
+} // namespace script
+} // namespace jit
+} // namespace torch
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/script/module.h>
-namespace torch { namespace jit { namespace script {
-
+namespace torch {
+namespace jit {
+namespace script {
TORCH_API const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name);
-
-
-}}}
+} // namespace script
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/script/compiler.h>
-#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/script/final_returns.h>
-#include <torch/csrc/jit/passes/lower_tuples.h>
-#include <torch/csrc/jit/script/type_parser.h>
-#include <torch/csrc/jit/passes/constant_pooling.h>
-#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/final_returns.h>
#include <torch/csrc/jit/script/parser.h>
-#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/script/schema_matching.h>
+#include <torch/csrc/jit/script/type_parser.h>
#include <torch/csrc/utils/object_ptr.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/constants.h>
using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
static Value* asSimple(const SugaredValuePtr& value) {
- if(SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
+ if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
return sv->getValue();
}
return nullptr;
}
// Auxiliary data structure for desugaring variable binding into our always
-// explicitly scoped language as we descend down
-// nested control structures in the frontend (which themselves don't introduce
-// scopes)
+// explicitly scoped language as we descend into nested control structures in
+// the frontend (which themselves don't introduce scopes)
//
// The algorithm is roughly as follows:
// 1) While emitting a block within a control operator, add inputs and outputs
// the IR API, but for now we choose to pessimistically create inputs and
// delete unnecessary ones later with replaceAllUsesWith().
struct Environment {
- Environment(Method & method, Resolver resolver, Block* b, std::shared_ptr<Environment> next = nullptr)
- : method(method), resolver(std::move(resolver)), b(b), next(std::move(next)) {}
+ Environment(
+ Method& method,
+ Resolver resolver,
+ Block* b,
+ std::shared_ptr<Environment> next = nullptr)
+ : method(method),
+ resolver(std::move(resolver)),
+ b(b),
+ next(std::move(next)) {}
- Method & method;
+ Method& method;
Resolver resolver;
std::vector<std::string> captured_inputs;
std::unordered_map<std::string, std::string> error_messages;
// set type error in the lowest environment. if the variable is used after an
// error has been set, then we will use the more informative error message
- void setVariableTypeError(const std::string& name, const std::string &msg) {
+ void setVariableTypeError(const std::string& name, const std::string& msg) {
auto runner = this;
while (runner->next) {
runner = runner->next.get();
SugaredValuePtr findInAnyFrame(const std::string& name) {
for (auto runner = this; runner; runner = runner->next.get()) {
- if(auto r = runner->findInThisFrame(name)) {
+ if (auto r = runner->findInThisFrame(name)) {
return r;
}
}
// this ensures consistency of the order of loop-carried dependencies
// even when the use in the loop is in a different order
size_t insert_pos = 0;
- while (insert_pos < captured_inputs.size() && name > captured_inputs[insert_pos]) {
+ while (insert_pos < captured_inputs.size() &&
+ name > captured_inputs[insert_pos]) {
insert_pos++;
}
captured_inputs.insert(captured_inputs.begin() + insert_pos, name);
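+    // For illustration: capturing "b" and then "a" still yields
+    // captured_inputs == {"a", "b"}, so loop-carried block inputs are ordered
+    // by name rather than by first use inside the loop.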
// Create the input
const size_t loop_carried_block_inputs_offset = 1;
- Value* new_input = b->insertInput(loop_carried_block_inputs_offset + insert_pos)
- ->setType(orig->type());
+ Value* new_input =
+ b->insertInput(loop_carried_block_inputs_offset + insert_pos)
+ ->setType(orig->type());
// Associate this name with this value
auto sv = std::make_shared<SimpleValue>(new_input);
return sv;
}
- SugaredValuePtr createCapturedInputIfNeeded(const SourceRange& loc, const std::string& ident) {
+ SugaredValuePtr createCapturedInputIfNeeded(
+ const SourceRange& loc,
+ const std::string& ident) {
auto in_frame = findInThisFrame(ident);
if (in_frame) {
return in_frame;
}
// recursively handles the case where parent blocks are also loops
- auto from_parent = next ? next->createCapturedInputIfNeeded(loc, ident) : nullptr;
+ auto from_parent =
+ next ? next->createCapturedInputIfNeeded(loc, ident) : nullptr;
// recursively create the captured input if it is the loop block
if (from_parent && getBlockOwningKind() == prim::Loop) {
setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
}
- void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) {
+ void setSugaredVar(
+ const SourceRange& loc,
+ const std::string& name,
+ SugaredValuePtr value) {
Value* as_simple_value = asSimple(value);
if (as_simple_value && !as_simple_value->hasUniqueName() &&
meaningfulName(name) &&
- // note: if the value wasn't defined in this block, we might be giving a name
- // only used inside this block to a value outside of this. this is not
- // normally helpful for debugging and causes import/export jitter.
+ // note: if the value wasn't defined in this block, we might be giving a
+ // name only used inside this block to a value outside of this. this is
+ // not normally helpful for debugging and causes import/export jitter.
as_simple_value->node()->owningBlock() == block()) {
as_simple_value->setUniqueName(name);
}
// a = ..
// requires 'a' to be first-class in the graph since its value depends on
// control flow
- if(auto parent = findInParentFrame(name)) {
- if(!as_simple_value) {
- throw ErrorReport(loc) << "Cannot re-assign '" << name << "' to a value of type " << value->kind() <<
- " because " << name << " is not a first-class value. Only reassignments to first-class values are allowed";
+ if (auto parent = findInParentFrame(name)) {
+ if (!as_simple_value) {
+ throw ErrorReport(loc)
+ << "Cannot re-assign '" << name << "' to a value of type "
+ << value->kind() << " because " << name
+ << " is not a first-class value. Only reassignments to first-class values are allowed";
}
Value* simple_parent = asSimple(parent);
- if(!simple_parent) {
- throw ErrorReport(loc) << "Cannot re-assign '" << name << "' because it has type " << value->kind() <<
- " and " << name << " is not a first-class value. Only reassignments to first-class values are allowed";
+ if (!simple_parent) {
+ throw ErrorReport(loc)
+ << "Cannot re-assign '" << name << "' because it has type "
+ << value->kind() << " and " << name
+ << " is not a first-class value. Only reassignments to first-class values are allowed";
}
if (!as_simple_value->type()->isSubtypeOf(
unshapedType(simple_parent->type()))) {
value_table[name] = std::move(value);
}
- SugaredValuePtr getSugaredVar(const Ident& ident, bool required=true) {
+ SugaredValuePtr getSugaredVar(const Ident& ident, bool required = true) {
return getSugaredVar(ident.name(), ident.range());
}
Value* getVar(const Ident& ident) {
return getSugaredVar(ident)->asValue(ident.range(), method);
}
- SugaredValuePtr getSugaredVar(const std::string& ident, const SourceRange& range, bool required=true) {
+ SugaredValuePtr getSugaredVar(
+ const std::string& ident,
+ const SourceRange& range,
+ bool required = true) {
auto retval = createCapturedInputIfNeeded(range, ident);
- if(!retval) {
+ if (!retval) {
static std::unordered_map<std::string, SugaredValuePtr> globals = {
- {"print", std::make_shared<PrintValue>()},
- {"float", std::make_shared<CastValue>(FloatType::get(), prim::Float)},
- {"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
- {"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
- {"getattr", std::make_shared<GetAttrValue>()},
- {"isinstance", std::make_shared<IsInstanceValue>()},
- // todo(zach): remove when we can correctly export torch.full via ONNX
- // or we have implicit conversion that can convert numbers to tensors
- {"_to_tensor", std::make_shared<CastValue>(DynamicType::get(), prim::NumToTensor)},
- {"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
+ {"print", std::make_shared<PrintValue>()},
+ {"float", std::make_shared<CastValue>(FloatType::get(), prim::Float)},
+ {"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
+ {"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
+ {"getattr", std::make_shared<GetAttrValue>()},
+ {"isinstance", std::make_shared<IsInstanceValue>()},
+ // todo(zach): remove when we can correctly export torch.full via ONNX
+ // or we have implicit conversion that can convert numbers to tensors
+ {"_to_tensor",
+ std::make_shared<CastValue>(DynamicType::get(), prim::NumToTensor)},
+ {"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
};
auto it = globals.find(ident);
- if(it != globals.end())
+ if (it != globals.end())
retval = it->second;
}
- if(!retval) {
+ if (!retval) {
retval = resolver(ident, method, range);
}
// captured_inputs: lcd0, lcd1, ...
JIT_ASSERT(b->inputs().size() == b->outputs().size());
JIT_ASSERT(b->inputs().size() == captured_inputs.size() + 1);
- for(size_t i = b->inputs().size() - 1; i > 0; i--) {
+ for (size_t i = b->inputs().size() - 1; i > 0; i--) {
// nothing changed along this loop
- if(b->inputs()[i] == b->outputs()[i]) {
+ if (b->inputs()[i] == b->outputs()[i]) {
auto name = captured_inputs[i - 1];
Value* orig = findInParentFrame(name)->asValue(loc, method);
b->inputs()[i]->replaceAllUsesWith(orig);
}
std::vector<std::string> definedVariables() {
std::vector<std::string> result;
- for(auto & kv : value_table) {
+ for (auto& kv : value_table) {
result.push_back(kv.first);
}
return result;
}
-private:
+
+ private:
ValueTable value_table;
};
-template<class T>
-static Value* materializeConstant(T val, Graph& graph,
- const SourceRange& r, std::unordered_map<T, Value*>& map) {
+template <class T>
+static Value* materializeConstant(
+ T val,
+ Graph& graph,
+ const SourceRange& r,
+ std::unordered_map<T, Value*>& map) {
auto existing_constant = map.find(val);
if (existing_constant != map.end()) {
return existing_constant->second;
}
static Value* ensureInt(const SourceRange& range, Value* v) {
- if(!v->type()->isSubtypeOf(IntType::get())) {
- throw ErrorReport(range) << "expected a int but found a "
- << v->type()->str();
+ if (!v->type()->isSubtypeOf(IntType::get())) {
+ throw ErrorReport(range)
+ << "expected a int but found a " << v->type()->str();
}
return v;
}
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
- return std::make_shared<SimpleValue>(emitBuiltinCall(
- loc, *m.graph(), symbol, self, inputs, attributes, true));
+ return std::make_shared<SimpleValue>(
+ emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true));
}
inline bool isSupportedListElementType(const TypePtr& type) {
Resolver resolver_,
const SugaredValuePtr& self,
Method& method) // method being constructed
- : method(method)
- , graph(method.graph())
- , resolver(std::move(resolver_))
- , environment_stack(nullptr) {
+ : method(method),
+ graph(method.graph()),
+ resolver(std::move(resolver_)),
+ environment_stack(nullptr) {
JIT_ASSERT(resolver);
pushFrame(graph->block(), /*starts_def=*/true);
- // Type annotations exclude explicitly typing the "self" parameter, so in the
- // case that this is a method with self we expect one fewer parameter annotation
- // than the number of parameters this Def takes.
+ // Type annotations exclude explicitly typing the "self" parameter, so in
+ // the case that this is a method with self we expect one fewer parameter
+ // annotation than the number of parameters this Def takes.
if (self && def.decl().params().size() == 0) {
- throw ErrorReport(def.decl().params().range()) << "methods must have a self argument";
+ throw ErrorReport(def.decl().params().range())
+ << "methods must have a self argument";
}
method.setSchema(emitDef(def, self, graph->block()));
runCleanupPasses(graph);
}
-private:
+ private:
Method& method;
std::shared_ptr<Graph> graph;
Resolver resolver;
std::shared_ptr<Environment> environment_stack;
std::vector<DefContext> def_stack_;
- void pushFrame(Block * b, bool starts_def=false) {
+ void pushFrame(Block* b, bool starts_def = false) {
if (starts_def) {
def_stack_.emplace_back();
}
- environment_stack = std::make_shared<Environment>(method, resolver, b, environment_stack);
+ environment_stack =
+ std::make_shared<Environment>(method, resolver, b, environment_stack);
}
- std::shared_ptr<Environment> popFrame(bool ends_def=false) {
+ std::shared_ptr<Environment> popFrame(bool ends_def = false) {
auto old_frame = environment_stack;
environment_stack = environment_stack->next;
- if(ends_def) {
+ if (ends_def) {
def_stack_.pop_back();
}
return old_frame;
ConstantPooling(to_clean);
}
- FunctionSchema emitDef(const Def& def, const SugaredValuePtr& self, Block* block) {
+ FunctionSchema emitDef(
+ const Def& def,
+ const SugaredValuePtr& self,
+ Block* block) {
auto schema = extractSchemaFromDef(def, self);
if (schema.returns().size() == 1) {
def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
}
- std::vector<Argument> arguments = emitFormalArguments(def, self, schema, block);
-
+ std::vector<Argument> arguments =
+ emitFormalArguments(def, self, schema, block);
// body
auto stmts_list = moveAllReturnsToEnd(def.statements());
return {def.name().name(), std::move(arguments), std::move(returns)};
}
- std::vector<IValue> evaluateDefaults(const SourceRange& r, const std::vector<Expr>& default_types, const std::vector<Expr>& default_exprs) {
+ std::vector<IValue> evaluateDefaults(
+ const SourceRange& r,
+ const std::vector<Expr>& default_types,
+ const std::vector<Expr>& default_exprs) {
std::vector<IValue> default_values;
if (default_exprs.empty())
return default_values;
// To evaluate the default expressions, we create a graph with no inputs,
// and whose returns are the default values we need.
- // We then run constant prop on this graph and check the results are constant.
- // This approach avoids having to have separate handling of default arguments
- // from standard expressions by piecing together existing machinery for
- // graph generation, constant propgation, and constant extraction.
+ // We then run constant prop on this graph and check the results are
+ // constant. This approach avoids having to have separate handling of
+ // default arguments from standard expressions by piecing together existing
+ // machinery for graph generation, constant propagation, and constant
+ // extraction.
auto tuple_type = Subscript::create(
r,
Var::create(r, Ident::create(r, "Tuple")),
List<Expr>::create(r, default_types));
- auto blank_decl =
- Decl::create(r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
+ auto blank_decl = Decl::create(
+ r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
- auto tuple_expr = TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
+ auto tuple_expr =
+ TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
auto ret = Return::create(r, tuple_expr);
auto def = Def::create(
r,
return stack.at(0).toTuple()->elements();
}
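+ // As a sketch (hypothetical source, for illustration): for
+ //   def f(x: int = 2 + 3): ...
+ // we build a graph equivalent to
+ //   def defaults() -> Tuple[int]:
+ //     return (2 + 3,)
+ // run constant propagation on it, and read the constant tuple (5,) back
+ // out as the default IValues.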
- std::vector<Argument> parseArgsFromDecl(const Decl& decl, const SugaredValuePtr& self) {
+ std::vector<Argument> parseArgsFromDecl(
+ const Decl& decl,
+ const SugaredValuePtr& self) {
auto params_begin = decl.params().begin();
auto params_end = decl.params().end();
if (self)
default_exprs.emplace_back(def.get());
}
}
- auto default_values = evaluateDefaults(decl.range(), default_types, default_exprs);
+ auto default_values =
+ evaluateDefaults(decl.range(), default_types, default_exprs);
auto defaults_it = default_values.begin();
for (auto it = params_begin; it != params_end; ++it) {
TypePtr type;
c10::optional<int32_t> N;
- //BroadcastList list can only appear at the argument level
+ // BroadcastList list can only appear at the argument level
if (auto maybe_broad_list = parseBroadcastList(decl_arg.type())) {
type = maybe_broad_list->first;
N = maybe_broad_list->second;
std::vector<Argument> parseReturnFromDecl(const Decl& decl) {
// we represent no annotation on a return type as having no values in the
// schema's return() list
- // in emitReturn we take the actual return value to be the value of the return
- // statement if no one was provided here
- if(!decl.return_type().present())
+ // in emitReturn we take the actual return value to be the value of the
+ // return statement if none was provided here
+ if (!decl.return_type().present())
return {};
if (parseBroadcastList(decl.return_type().get()))
- throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type";
+ throw ErrorReport(decl.return_type().range())
+ << "Broadcastable lists cannot appear as a return type";
auto parsed_type = parseTypeFromExpr(decl.return_type().get());
return {Argument(
"",
/*default_value =*/c10::nullopt,
/*kwarg_only =*/false)};
}
- FunctionSchema extractSchemaFromDef(const Def &def, const SugaredValuePtr& self) {
- auto name = def.name().name();
- std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
- std::vector<Argument> returns = parseReturnFromDecl(def.decl());
- return FunctionSchema(name, std::move(args), std::move(returns), false, false);
+ FunctionSchema extractSchemaFromDef(
+ const Def& def,
+ const SugaredValuePtr& self) {
+ auto name = def.name().name();
+ std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
+ std::vector<Argument> returns = parseReturnFromDecl(def.decl());
+ return FunctionSchema(
+ name, std::move(args), std::move(returns), false, false);
}
- std::vector<Argument> emitFormalArguments(const Def& def, const SugaredValuePtr& self, const FunctionSchema& schema, Block* block) {
+ std::vector<Argument> emitFormalArguments(
+ const Def& def,
+ const SugaredValuePtr& self,
+ const FunctionSchema& schema,
+ Block* block) {
std::vector<Argument> arguments; // for schema
// inputs
auto it = def.decl().params().begin();
auto end = def.decl().params().end();
- auto expected_annotation_size = self ? def.decl().params().size() - 1 : def.decl().params().size();
+ auto expected_annotation_size =
+ self ? def.decl().params().size() - 1 : def.decl().params().size();
if (schema.arguments().size() != expected_annotation_size) {
- throw ErrorReport(def.decl().params().range()) << "Number of type annotations for"
- << " function parameters (" << schema.arguments().size() << ")"
- << " does not match the number of parameters on the function ("
- << expected_annotation_size << ")!";
+ throw ErrorReport(def.decl().params().range())
+ << "Number of type annotations for"
+ << " function parameters (" << schema.arguments().size() << ")"
+ << " does not match the number of parameters on the function ("
+ << expected_annotation_size << ")!";
}
- if(self) {
+ if (self) {
JIT_ASSERT(it != end);
environment_stack->setSugaredVar(def.range(), (*it).ident().name(), self);
++it;
}
size_t arg_annotation_idx = 0;
- for(;it != end; ++it) {
+ for (; it != end; ++it) {
auto& name = (*it).ident().name();
// Add the input to the graph
- Value *new_input = block->addInput();
+ Value* new_input = block->addInput();
if (meaningfulName(name)) {
new_input->setUniqueName(name);
}
return arguments;
}
- Argument emitOutput(const SourceRange& range, const FunctionSchema& schema, Block* block) {
+ Argument emitOutput(
+ const SourceRange& range,
+ const FunctionSchema& schema,
+ Block* block) {
// rewrites ensure there is always a return statement in the program
JIT_ASSERT(def_stack_.back().merged_return_type_);
// outputs
return emitStatements(statements.begin(), statements.end());
}
std::pair<std::shared_ptr<Graph>, Value*> lambdaLift(Block* block) {
- auto subgraph = std::make_shared<Graph>();
- // note: type is set later on pack_context and context when we know it
- Node* pack_context = graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
- Value* context = subgraph->addInput("context");
- // cannot use createTupleUnpack because the type is not known yet
- Node* unpack_context = subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
-
- std::unordered_map<Value*, Value*> captures;
- auto env = [&](Value* v) -> Value* {
- auto it = captures.find(v);
- if (it != captures.end()) {
- return it->second;
- }
- pack_context->addInput(v);
- Value* r = unpack_context->addOutput()->copyMetadata(v);
- captures[v] = r;
- return r;
- };
- subgraph->block()->cloneFrom(block, env);
- auto context_type = TupleType::create(
- fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
- pack_context->output()->setType(context_type);
- context->setType(context_type);
- return std::make_pair(std::move(subgraph), pack_context->output());
+ auto subgraph = std::make_shared<Graph>();
+ // note: type is set later on pack_context and context when we know it
+ Node* pack_context =
+ graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
+ Value* context = subgraph->addInput("context");
+ // cannot use createTupleUnpack because the type is not known yet
+ Node* unpack_context =
+ subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
+
+ std::unordered_map<Value*, Value*> captures;
+ auto env = [&](Value* v) -> Value* {
+ auto it = captures.find(v);
+ if (it != captures.end()) {
+ return it->second;
+ }
+ pack_context->addInput(v);
+ Value* r = unpack_context->addOutput()->copyMetadata(v);
+ captures[v] = r;
+ return r;
+ };
+ subgraph->block()->cloneFrom(block, env);
+ auto context_type = TupleType::create(
+ fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
+ pack_context->output()->setType(context_type);
+ context->setType(context_type);
+ return std::make_pair(std::move(subgraph), pack_context->output());
}
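+ // Schematically, lifting a block that captures %v0 and %v1 produces:
+ //   outer graph:  %context = prim::TupleConstruct(%v0, %v1)
+ //   subgraph:     %c0, %c1 = prim::TupleUnpack(%context)
+ // where %c0/%c1 stand in for the captured values inside the subgraph.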
// XXX - right now closures are used _only_ for defining gradients internally
// There are several unfinished aspects that make them unusable generally
- // 1. We do not have a type, ivalue, operator to represent prim::Function, so closure_node has type None
+ // 1. We do not have a type, ivalue, operator to represent prim::Function, so
+ // closure_node has type None
// and any graphs that contain it cannot be run
- // 2. There is no export logic for it yet, so it cannot be exported/python_printed
- // 3. There is nothing preventing the assignment of already existing variables inside the closures
+ // 2. There is no export logic for it yet, so it cannot be
+ // exported/python_printed
+ // 3. There is nothing preventing the assignment of already existing variables
+ // inside the closures
// the changes to those variables will just get forgotten.
// 4. There is no parsing support in frontend.py, this is intentional since it
// prevents people from accidentally using this feature.
void emitClosure(const Def& def) {
Node* closure_node = graph->insertNode(graph->create(prim::Function, 1));
- closure_node->output()->setType(NoneType::get()); //it is not a real thing yet, so just say the type is none.
+ closure_node->output()->setType(
+ NoneType::get()); // it is not a real thing yet, so just say the type is
+ // none.
Block* block = closure_node->addBlock();
{
WithInsertPoint guard(block);
pushFrame(block, /*starts_def=*/true);
- emitDef(def, nullptr, block); //ignore schema return, we just wont use it for now since we never create a Method for the closure
+ emitDef(
+ def,
+ nullptr,
+ block); // ignore schema return, we just won't use it for now since we
+ // never create a Method for the closure
popFrame(/*ends_def=*/true);
}
std::shared_ptr<Graph> subgraph;
runCleanupPasses(subgraph);
closure_node->eraseBlock(0);
closure_node->g_(attr::Subgraph, std::move(subgraph));
- auto tup = graph->insertNode(graph->createTuple({closure_node->output(), context}))->output();
+ auto tup =
+ graph->insertNode(graph->createTuple({closure_node->output(), context}))
+ ->output();
environment_stack->setVar(def.name().range(), def.name().name(), tup);
}
TypePtr result_type = def_stack_.back().declared_return_type_;
// result type is annotated, every return must convert to that type
if (result_type) {
- // this guard skips implicit conversion from None -> Tensor for the return type.
- // otherwise forgetting a return a function returning a tensor will cause a None to be
- // converted to a tensor.
- if (!(result_type->isSubtypeOf(DynamicType::get()) && result->type()->isSubtypeOf(NoneType::get()))) {
+ // this guard skips implicit conversion from None -> Tensor for the return
+ // type. Otherwise, forgetting a return in a function returning a tensor will
+ // cause a None to be converted to a tensor.
+ if (!(result_type->isSubtypeOf(DynamicType::get()) &&
+ result->type()->isSubtypeOf(NoneType::get()))) {
result = tryConvertToType(
- stmt.range(), *graph, result_type, result, /*allow_conversions=*/true);
+ stmt.range(),
+ *graph,
+ result_type,
+ result,
+ /*allow_conversions=*/true);
}
if (!result->type()->isSubtypeOf(result_type)) {
- throw ErrorReport(stmt.range()) << "Return value was annotated as having type " << result_type->python_str()
- << " but is actually of type " << result->type()->python_str();
+ throw ErrorReport(stmt.range())
+ << "Return value was annotated as having type "
+ << result_type->python_str() << " but is actually of type "
+ << result->type()->python_str();
}
} else {
result_type = def_stack_.back().merged_return_type_;
if (!result_type) {
result_type = result->type();
}
- if(!unifyTypes(result_type, result->type())) {
+ if (!unifyTypes(result_type, result->type())) {
throw ErrorReport(stmt.range())
<< "Previous return statement returned a value of type "
<< result_type->python_str()
environment_stack->setVar(stmt.range(), "$return", result);
}
- void emitStatements(List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end) {
+ void emitStatements(
+ List<Stmt>::const_iterator begin,
+ List<Stmt>::const_iterator end) {
for (; begin != end; ++begin) {
auto stmt = *begin;
switch (stmt.kind()) {
case TK_GLOBAL:
for (auto ident : Global(stmt).names()) {
const auto& name = Ident(ident).name();
- environment_stack->setVar(ident.range(), name, graph->addInput(name));
+ environment_stack->setVar(
+ ident.range(), name, graph->addInput(name));
}
break;
case TK_EXPR_STMT: {
auto expr = ExprStmt(stmt).expr();
emitSugaredExpr(expr, 0);
- }
- break;
+ } break;
case TK_RAISE:
emitRaise(Raise(stmt).range());
break;
return popFrame();
}
- Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
- return graph
- ->create(kind, n_outputs)
- ->setSourceLocation(std::make_shared<SourceRange>(loc));
+ Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
+ return graph->create(kind, n_outputs)
+ ->setSourceLocation(std::make_shared<SourceRange>(loc));
}
Value* emitTernaryIf(const TernaryIf& expr) {
Value* cond_value = emitCond(expr.cond());
- auto true_expr = [&] {
- return emitExpr(expr.true_expr());
- };
- auto false_expr = [&] {
- return emitExpr(expr.false_expr());
- };
+ auto true_expr = [&] { return emitExpr(expr.true_expr()); };
+ auto false_expr = [&] { return emitExpr(expr.false_expr()); };
return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
}
Value* emitShortCircuitIf(
const SourceRange& loc,
- const TreeRef & first_expr,
- const TreeRef & second_expr,
+ const TreeRef& first_expr,
+ const TreeRef& second_expr,
bool is_or) {
- Value * first_value = emitCond(Expr(first_expr));
+ Value* first_value = emitCond(Expr(first_expr));
- auto get_first_expr = [first_value] {
- return first_value;
- };
- auto get_second_expr = [&] {
- return emitCond(Expr(second_expr));
- };
+ auto get_first_expr = [first_value] { return first_value; };
+ auto get_second_expr = [&] { return emitCond(Expr(second_expr)); };
// if this is an OR, eval second expression if first expr is False.
// If this is an AND, eval second expression if first expr is True
}
}
- Value* emitIfExpr(const SourceRange& range, Value * cond_value,
- std::function<Value*()> true_expr, std::function<Value*()> false_expr) {
+ Value* emitIfExpr(
+ const SourceRange& range,
+ Value* cond_value,
+ std::function<Value*()> true_expr,
+ std::function<Value*()> false_expr) {
Node* n = graph->insertNode(create(prim::If, range, 0));
n->addInput(cond_value);
// a =
// ... = a # OK, a is defined along all paths
-
- //ordered set, because we want deterministic graph output
+ // ordered set, because we want deterministic graph output
std::set<std::string> mutated_variables;
- for(auto & v : save_true->definedVariables()) {
- if(save_false->findInAnyFrame(v)) {
+ for (auto& v : save_true->definedVariables()) {
+ if (save_false->findInAnyFrame(v)) {
mutated_variables.insert(v);
}
}
- for(auto & v : save_false->definedVariables()) {
- if(save_true->findInAnyFrame(v)) {
+ for (auto& v : save_false->definedVariables()) {
+ if (save_true->findInAnyFrame(v)) {
mutated_variables.insert(v);
}
}
auto fv = save_false->getVar(x, stmt.range());
auto unified = unifyTypes(tv->type(), fv->type());
- // attempt to unify the types. we allow variables to be set to different types
- // in each branch as long as that variable is not already in scope,
+ // attempt to unify the types. we allow variables to be set to different
+ // types in each branch as long as that variable is not already in scope,
// or if that variable does not get used later. here, we save the error
// so that the error message will be more informative in the case that it is
// used later. When a is accessed in (a + 1), the error will get printed
//
if (!unified) {
ErrorReport error(stmt);
- error << "Type mismatch: " << x << " is set to type " << tv->type()->str() << " in the true branch"
- << " and type " << fv->type()->str() << " in the false branch";
- if (save_true->findInParentFrame(x) || save_false->findInParentFrame(x)) {
+ error << "Type mismatch: " << x << " is set to type "
+ << tv->type()->str() << " in the true branch"
+ << " and type " << fv->type()->str() << " in the false branch";
+ if (save_true->findInParentFrame(x) ||
+ save_false->findInParentFrame(x)) {
throw error;
} else {
// error gets saved in the lowest environment because all
- // variables are scoped to the function. doesn't matter if this accessed
- // through save_true or save_false
+ // variables are scoped to the function. it doesn't matter if this is
+ // accessed through save_true or save_false
save_true->setVariableTypeError(x, error.what());
continue;
}
}
true_block->registerOutput(tv);
false_block->registerOutput(fv);
- environment_stack->setVar(stmt.range(), x, n->addOutput()->setType(*unified));
+ environment_stack->setVar(
+ stmt.range(), x, n->addOutput()->setType(*unified));
}
}
void emitIf(const If& stmt) {
- // NOTE: emitIf checks on If stmt condition to see if the cond AST kind == is/is not,
- // for such cases we do meta programming and disable emitting the corresponding branches
+ // NOTE: emitIf checks the If stmt condition to see if the cond AST kind is
+ // is/is not; in such cases we do metaprogramming and disable emitting the
+ // corresponding branches
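+ // For example (sketch), given `x: Optional[Tensor]`:
+ //   if x is None: ... else: ...
+ // emits only the true branch when x is statically ALWAYS None, only the
+ // false branch when it is NEVER None, and a normal prim::If otherwise.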
Expr cond = stmt.cond();
if (cond.kind() != TK_IS && cond.kind() != TK_ISNOT) {
emitIfElseBlocks(cond_value, stmt);
return;
}
- // meta programming on AST for is/is not cases and emit branches base on the possible output of cond
+ // metaprogramming on the AST for is/is not cases, emitting branches based
+ // on the possible output of cond
auto cond_op = BinOp(cond);
SugaredValuePtr lhs_val = emitSugaredExpr(cond_op.lhs(), 1);
SugaredValuePtr rhs_val = emitSugaredExpr(cond_op.rhs(), 1);
- List<Stmt> always_none_branch = cond.kind() == TK_IS? stmt.trueBranch(): stmt.falseBranch();
- List<Stmt> never_none_branch = cond.kind() == TK_IS? stmt.falseBranch(): stmt.trueBranch();
+ List<Stmt> always_none_branch =
+ cond.kind() == TK_IS ? stmt.trueBranch() : stmt.falseBranch();
+ List<Stmt> never_none_branch =
+ cond.kind() == TK_IS ? stmt.falseBranch() : stmt.trueBranch();
- auto lhs_none= lhs_val->isNone();
- auto rhs_none= rhs_val->isNone();
+ auto lhs_none = lhs_val->isNone();
+ auto rhs_none = rhs_val->isNone();
// Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
//
if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
// None is/is not None: only emit the always_none_branch
emitStatements(always_none_branch);
- } else if ((lhs_none == ALWAYS && rhs_none == NEVER) ||
- (lhs_none == NEVER && rhs_none == ALWAYS)){
+ } else if (
+ (lhs_none == ALWAYS && rhs_none == NEVER) ||
+ (lhs_none == NEVER && rhs_none == ALWAYS)) {
// lhs_val/rhs_val with A/N or N/A: only emit never_none_branch
emitStatements(never_none_branch);
- }
- else {
+ } else {
// all other cases for lhs_val and rhs_val
// emit the whole If stmt as usual, finish emitCond first
auto lhs_range = cond_op.lhs().get()->range();
*method.graph(),
kind,
c10::nullopt,
- {lhs_val->asValue(lhs_range, method), rhs_val->asValue(rhs_range, method)},
+ {lhs_val->asValue(lhs_range, method),
+ rhs_val->asValue(rhs_range, method)},
{},
/*required=*/true);
emitIfElseBlocks(cond_value, stmt);
-
}
-
}
// *********************** Loop Operators ************************************
// the format of the Loop instruction is:
// loop_carried_outputs* = Loop(max_trip_count, start_condition,
- // loop_carried_inputs*)
- // block0(loop_counter, loop_carried_block*) {
- // <body>
- // -> (continue_condition,
- // loop_carried_block_outputs*)
- // }
+ // loop_carried_inputs*)
+ // block0(loop_counter, loop_carried_block*) {
+ // <body>
+ // -> (continue_condition, loop_carried_block_outputs*)
+ // }
// all loop_carried_... lists are the same length and represent the value of
// loop-carried variables whose definitions are updated as the loop executes
// in a way that ensures single static assignment.
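+ // For instance, `while i < 10: i = i + 1` lowers schematically to:
+ //   %i.out = prim::Loop(%max_trip_count, %cond.in, %i.in)
+ //     block0(%iter, %i):
+ //       %i.next = aten::add(%i, %one)
+ //       %cond.next = aten::lt(%i.next, %ten)
+ //       -> (%cond.next, %i.next)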
max_trip_count_val = ensureInt(
max_trip_count->range(), emitExpr(max_trip_count.value()));
} else {
- max_trip_count_val =
- materializeConstant(std::numeric_limits<int64_t>::max(), *graph, range, integral_constants);
+ max_trip_count_val = materializeConstant(
+ std::numeric_limits<int64_t>::max(),
+ *graph,
+ range,
+ integral_constants);
}
if (cond) {
cond_val = emitCond(cond.value());
n->addInput(max_trip_count_val);
n->addInput(cond_val);
auto* body_block = n->addBlock();
- Value* trip_count = body_block->addInput()->setType(IntType::get()); // Iteration num
+ Value* trip_count =
+ body_block->addInput()->setType(IntType::get()); // Iteration num
{
pushFrame(body_block);
if (itr_ident) {
- environment_stack->setVar(itr_ident->range(), itr_ident->name(), trip_count);
+ environment_stack->setVar(
+ itr_ident->range(), itr_ident->name(), trip_count);
}
WithInsertPoint guard(body_block);
emitStatements(body);
body_frame->deleteExtraInputs(range);
// register node inputs/outputs for the true loop carried deps,
- for(size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
+ for (size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
auto x = body_frame->captured_inputs[i];
n->addInput(outer_frame->getVar(x, range));
// body_block->inputs(): loop_counter, lcd0, lcd1, ...
auto typ = body_block->inputs()[i + 1]->type();
outer_frame->setVar(range, x, n->addOutput()->setType(typ));
}
-
}
}
- void emitForRange(const SourceRange& range, const Ident& target, const List<Expr>& args, const List<Stmt>& body) {
+ void emitForRange(
+ const SourceRange& range,
+ const Ident& target,
+ const List<Expr>& args,
+ const List<Stmt>& body) {
// TODO: start, stop, step loop
if (args.size() != 1) {
throw ErrorReport(range)
<< "List of iterables is not supported currently.";
}
if (targets.size() != 1) {
- throw ErrorReport(stmt) << "Iteration variable unpacking is not supported";
+ throw ErrorReport(stmt)
+ << "Iteration variable unpacking is not supported";
}
if (targets[0].kind() != TK_VAR) {
- throw ErrorReport(targets[0]) << "unexpected expression in variable initialization of for loop";
+ throw ErrorReport(targets[0])
+ << "unexpected expression in variable initialization of for loop";
}
auto target = Var(targets[0]).name();
if (range_iterator.callee().kind() == TK_VAR) {
Var var = Var(range_iterator.callee());
if (var.name().name() == "range") {
- return emitForRange(stmt.range(), target, range_iterator.inputs(), body);
+ return emitForRange(
+ stmt.range(), target, range_iterator.inputs(), body);
}
}
}
- // it isn't a range(<expr>) loop, treat it as a sugared value that maybe can be
- // unrolled
+ // it isn't a range(<expr>) loop; treat it as a sugared value that can
+ // perhaps be unrolled
auto sv = emitSugaredExpr(itrs[0], 1);
auto instances = sv->asTuple(stmt.range(), method);
const std::string& target_name = target.name();
pushFrame(environment_stack->block());
- for(const auto& inst : instances) {
+ for (const auto& inst : instances) {
environment_stack->setSugaredVar(itrs[0].range(), target_name, inst);
emitStatements(body);
}
- for (const auto & n : environment_stack->definedVariables()) {
+ for (const auto& n : environment_stack->definedVariables()) {
if (environment_stack->findInParentFrame(n)) {
- environment_stack->next->setVar(stmt.range(), n, environment_stack->getVar(n, stmt.range()));
+ environment_stack->next->setVar(
+ stmt.range(), n, environment_stack->getVar(n, stmt.range()));
}
}
popFrame();
emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
}
-
// Currently we do not support assigning exceptions to variables,
// a = Exception("hi")
// raise a
/* true_block =*/n->addBlock();
auto* false_block = n->addBlock();
- //if assert test is false throw exception
+ // if assert test is false throw exception
pushFrame(false_block);
WithInsertPoint guard(false_block);
emitRaise(stmt.range());
popFrame();
}
-
// Validate that the `lhs` Expr's in an assignment statement are valid. That
// is:
//
// 1) All lhs Expr's are either Var or Starred nodes
// 2) There is at most one Starred node in the lhs Expr
- // 3) A Starred node can only appear when there is another non-Starred lhs Expr
- // Concretely this means that `*abc = func()` is illegal. Unpacking all
- // outputs into a tuple is covered by `abc = func()`.
+ // 3) A Starred node can only appear when there is another non-Starred lhs
+ // Expr. Concretely this means that `*abc = func()` is illegal. Unpacking
+ // all outputs into a tuple is covered by `abc = func()`.
bool calcNumStarredUnpack(const List<Expr>& lhs, const SourceRange& r) {
size_t num_normal_assign = 0;
size_t num_starred = 0;
if (num_starred > 0 && num_normal_assign == 0) {
throw ErrorReport(r) << "A Starred expression may only appear on the "
- << "lhs within the presence of another non-starred"
- << " expression.";
+ << "lhs within the presence of another non-starred"
+ << " expression.";
}
return num_starred;
// If the RHS is a tensor, return the corresponding ATen in-place op
// If it's a list of scalars, then return the corresponding list augment op
Symbol getAugOp(const AugAssign& stmt, bool isTensor) {
- switch (stmt.aug_op()) {
- case '+':
- return isTensor ? aten::add_ : aten::add;
- case '-':
- return isTensor ? aten::sub_ : aten::sub;
- case '/':
- return isTensor ? aten::div_ : aten::div;
- case '*':
- return isTensor ? aten::mul_ : aten::mul;
- default:
- throw ErrorReport(stmt) << "Unknown augmented assignment: "
- << kindToString(stmt.aug_op());
- }
+ switch (stmt.aug_op()) {
+ case '+':
+ return isTensor ? aten::add_ : aten::add;
+ case '-':
+ return isTensor ? aten::sub_ : aten::sub;
+ case '/':
+ return isTensor ? aten::div_ : aten::div;
+ case '*':
+ return isTensor ? aten::mul_ : aten::mul;
+ default:
+ throw ErrorReport(stmt)
+ << "Unknown augmented assignment: " << kindToString(stmt.aug_op());
+ }
}
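+ // e.g. (sketch): for `a += b`, getAugOp returns aten::add_ when `a` is a
+ // tensor (in-place update) and the functional aten::add otherwise.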
// Emit nodes for augmented assignments like `+=`
// in place op, and throw error for other unsupported types
void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
const auto lhs = Select(stmt.lhs());
- const auto lhsSugaredVar = environment_stack->getSugaredVar(Var(lhs.value()).name());
- const auto lhsValue = lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())->asValue(lhs.range(), method);
+ const auto lhsSugaredVar =
+ environment_stack->getSugaredVar(Var(lhs.value()).name());
+ const auto lhsValue =
+ lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
+ ->asValue(lhs.range(), method);
if (lhsValue->type()->isSubtypeOf(DynamicType::get())) {
// for module parameter/buffer assignment, only consider tensor types,
// emit the corresponding in-place op
/*required=*/true);
} else {
- throw ErrorReport(stmt.lhs())
- << "left-hand side of augmented assignment to module "
- << "parameters/buffers can only be tensor types";
+ throw ErrorReport(stmt.lhs())
+ << "left-hand side of augmented assignment to module "
+ << "parameters/buffers can only be tensor types";
}
}
} else {
// Special case: we tried to do "advanced indexing". Lower this expr
// into `index` and `index_put_` ops
- const auto indices = graph->insertNode(
- graph->createList(DynamicType::get(), tensorIndices))->output();
+ const auto indices = graph
+ ->insertNode(graph->createList(
+ DynamicType::get(), tensorIndices))
+ ->output();
const auto indexed =
graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range());
const auto augmented = emitBuiltinCall(
const SourceRange& stmtRange,
const Subscript& lhs,
const Expr& rhs) {
- emitSubscriptAssign(
- stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
+ emitSubscriptAssign(stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
}
void emitSubscriptAssign(
} else {
// Special case: we tried to do "advanced indexing" with a tensor.
// Dispatch to `aten::index_put_`.
- const auto indices = graph->insertNode(
- graph->createList(DynamicType::get(), tensorIndices))->output();
+ const auto indices = graph
+ ->insertNode(graph->createList(
+ DynamicType::get(), tensorIndices))
+ ->output();
graph->insert(
aten::index_put_, {slicedArg, indices, rhs}, {}, stmtRange);
}
- // Otherwise, this is a list. Dispatch to aten::_set_item to both select and
- // assign
+ // Otherwise, this is a list. Dispatch to aten::_set_item to both select
+ // and assign
} else {
const auto subscript = lhs.subscript_exprs();
if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) {
void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) {
size_t n_binders = tl.inputs().size();
bool starred_unpack = calcNumStarredUnpack(tl.inputs(), tl.range());
- if(starred_unpack)
+ if (starred_unpack)
n_binders--;
auto output = emitSugaredExpr(rhs, n_binders);
auto outputs = output->asTuple(
rhs.range(),
method,
starred_unpack ? c10::nullopt : c10::optional<size_t>{n_binders});
- if(outputs.size() < n_binders) {
+ if (outputs.size() < n_binders) {
throw ErrorReport(tl)
- << "need " << (starred_unpack ? "at least " : "")
- << n_binders << " values to unpack but found only "
- << outputs.size();
+ << "need " << (starred_unpack ? "at least " : "") << n_binders
+ << " values to unpack but found only " << outputs.size();
}
- if(outputs.size() > n_binders && !starred_unpack) {
- throw ErrorReport(tl)
- << "too many values to unpack: need " << n_binders << " but found "
- << outputs.size();
+ if (outputs.size() > n_binders && !starred_unpack) {
+ throw ErrorReport(tl) << "too many values to unpack: need " << n_binders
+ << " but found " << outputs.size();
}
int i = 0;
for (auto assignee : tl.inputs()) {
i++;
break;
case TK_VAR:
- environment_stack->setSugaredVar(assignee.range(), Var(assignee).name().name(), outputs.at(i));
+ environment_stack->setSugaredVar(
+ assignee.range(), Var(assignee).name().name(), outputs.at(i));
i++;
break;
case TK_STARRED: {
auto var = Starred(assignee).expr();
if (var.kind() != TK_VAR) {
- throw ErrorReport(var) << "Cannot pack a tuple into a non-variable.";
+ throw ErrorReport(var)
+ << "Cannot pack a tuple into a non-variable.";
}
size_t n_matched = outputs.size() - n_binders;
ArrayRef<std::shared_ptr<SugaredValue>> outputs_ref = outputs;
- auto values = fmap(outputs_ref.slice(i, n_matched), [&](const std::shared_ptr<SugaredValue>& v) {
- return v->asValue(assignee.range(), method);
- });
+ auto values = fmap(
+ outputs_ref.slice(i, n_matched),
+ [&](const std::shared_ptr<SugaredValue>& v) {
+ return v->asValue(assignee.range(), method);
+ });
auto tup = graph->insertNode(graph->createTuple(values))->output();
- environment_stack->setVar(
- var.range(), Var(var).name().name(), tup);
+ environment_stack->setVar(var.range(), Var(var).name().name(), tup);
i += n_matched;
} break;
default:
- throw ErrorReport(assignee) << "unexpected expression on the left-hand side";
+ throw ErrorReport(assignee)
+ << "unexpected expression on the left-hand side";
}
}
}
void emitAssignment(const Assign& stmt) {
- switch(stmt.lhs().kind()) {
+ switch (stmt.lhs().kind()) {
case TK_VAR: {
auto v = Var(stmt.lhs());
environment_stack->setSugaredVar(
emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), stmt.rhs());
break;
default:
- throw ErrorReport(stmt.lhs()) << "unexpected expression on left-hand side of assignment.";
+ throw ErrorReport(stmt.lhs())
+ << "unexpected expression on left-hand side of assignment.";
}
}
}
}
-
-
std::vector<NamedValue> getNamedValues(
const TreeList& trees,
bool maybe_unpack) {
std::vector<NamedValue> values;
for (const auto& tree : trees) {
- if(maybe_unpack && tree->kind() == TK_STARRED) {
+ if (maybe_unpack && tree->kind() == TK_STARRED) {
auto starred = Starred(tree);
- auto entries = emitSugaredExpr(starred.expr(), 1)->asTuple(starred.range(), method);
- for(const auto& entry : entries) {
+ auto entries = emitSugaredExpr(starred.expr(), 1)
+ ->asTuple(starred.range(), method);
+ for (const auto& entry : entries) {
values.emplace_back(
tree->range(), entry->asValue(starred.range(), method));
}
return getNamedValues(trees.tree()->trees(), maybe_unpack);
}
- std::vector<Value*> getValues(
- const TreeList& trees,
- bool maybe_unpack) {
+ std::vector<Value*> getValues(const TreeList& trees, bool maybe_unpack) {
return toValues(*graph, getNamedValues(trees, maybe_unpack));
}
- std::vector<Value*> getValues(
- const List<Expr>& trees,
- bool maybe_unpack) {
+ std::vector<Value*> getValues(const List<Expr>& trees, bool maybe_unpack) {
return getValues(trees.tree()->trees(), maybe_unpack);
}
std::vector<NamedValue> emitAttributes(const List<Attribute>& attributes) {
return fmap(attributes, [&](const Attribute& attr) {
- return NamedValue(attr.range(), attr.name().name(), emitExpr(attr.value()));
+ return NamedValue(
+ attr.range(), attr.name().name(), emitExpr(attr.value()));
});
}
void checkApplyExpr(Apply& apply, SourceRange& loc) {
if (apply.inputs().size() != 2) {
- throw ErrorReport(loc)
- << Var(apply.callee()).name().name()
- << " expected exactly two arguments but found "
- << apply.inputs().size();
+ throw ErrorReport(loc) << Var(apply.callee()).name().name()
+ << " expected exactly two arguments but found "
+ << apply.inputs().size();
}
if (apply.attributes().size() > 0) {
throw ErrorReport(loc)
- << Var(apply.callee()).name().name()
- << " takes no keyword arguments";
+ << Var(apply.callee()).name().name() << " takes no keyword arguments";
}
}
- std::shared_ptr<SugaredValue> emitApplyExpr(Apply &apply, size_t n_binders) {
+ std::shared_ptr<SugaredValue> emitApplyExpr(Apply& apply, size_t n_binders) {
auto sv = emitSugaredExpr(apply.callee(), 1);
auto loc = apply.callee().range();
if (auto fork_value = dynamic_cast<ForkValue*>(sv.get())) {
<< " but found " << expr->type()->python_str();
}
return std::make_shared<SimpleValue>(expr);
- } else if(auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
+ } else if (auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
checkApplyExpr(apply, loc);
auto obj = emitSugaredExpr(apply.inputs()[0], 1);
auto selector = apply.inputs()[1];
if (selector.kind() != TK_STRINGLITERAL) {
- throw ErrorReport(loc) << "getattr's second argument must be a string literal";
+ throw ErrorReport(loc)
+ << "getattr's second argument must be a string literal";
}
const std::string& name = StringLiteral(selector).text();
return obj->attr(apply.range(), method, name);
} else if (auto isinstance = dynamic_cast<IsInstanceValue*>(sv.get())) {
- // NOTE: for `isinstance` builtin call in JIT, we only check the static types
- // on the inputs to evaluate, and insert the corresponding constant node
- std::function<bool(Expr, Expr)> isInstanceCheck = [&](Expr obj, Expr classinfo) {
+ // NOTE: for `isinstance` builtin call in JIT, we only check the static
+ // types on the inputs to evaluate, and insert the corresponding constant
+ // node
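+ // e.g. (sketch): with `x: List[int]`, `isinstance(x, (int, list))` is
+ // decided statically here and folds to the constant True in the graph.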
+ std::function<bool(Expr, Expr)> isInstanceCheck = [&](Expr obj,
+ Expr classinfo) {
if (classinfo.kind() == TK_TUPLE_LITERAL) {
// handle the case for recursive tuple classinfo
// return true if obj is an instance of any of the types
- for (Expr e: TupleLiteral(classinfo).inputs()) {
+ for (Expr e : TupleLiteral(classinfo).inputs()) {
if (isInstanceCheck(obj, e)) {
return true;
}
}
auto type_name = parseBaseTypeName(classinfo);
if (!type_name) {
- throw ErrorReport(classinfo.range()) << "type must be a type identifier";
+ throw ErrorReport(classinfo.range())
+ << "type must be a type identifier";
}
auto val = emitExpr(obj);
- // Special casing for list and tuple since isintance(x, list) and isinstance(x, tuple)
- // does not accept List[int] / Tuple[int] like subscript type annotation in python
+ // Special casing for list and tuple, since isinstance(x, list) and
+ // isinstance(x, tuple) do not accept subscripted type annotations such
+ // as List[int] / Tuple[int] in Python
if (*type_name == "list" && val->type()->cast<ListType>()) {
return true;
} else if (*type_name == "tuple" && val->type()->cast<TupleType>()) {
return true;
} else if (val->type()->cast<OptionalType>()) {
throw ErrorReport(loc)
- << "Optional isinstance check is not supported, consider use is/isnot None instead";
+ << "Optional isinstance check is not supported, consider use is/isnot None instead";
} else {
TypePtr type = parseTypeFromExpr(classinfo);
if (val->type()->isSubtypeOf(type)) {
return false;
};
checkApplyExpr(apply, loc);
- bool is_instance_val = isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
- return std::make_shared<SimpleValue>(graph->insertConstant(is_instance_val, loc));
+ bool is_instance_val =
+ isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
+ return std::make_shared<SimpleValue>(
+ graph->insertConstant(is_instance_val, loc));
} else {
auto inputs = getNamedValues(apply.inputs(), true);
auto attributes = emitAttributes(apply.attributes());
} else if (kind == aten::ge) {
return aten::le;
}
- throw std::runtime_error("reverseComparision: unsupported NodeKind. File a bug");
+ throw std::runtime_error(
+ "reverseComparision: unsupported NodeKind. File a bug");
}
// any expression that can produce a SugaredValue is handled here
// or a = torch.jit.annotate(List[int], [])
// the caller is responsible for checking that the result matches type_hint
// emitSugaredExpr is free to ignore it.
- std::shared_ptr<SugaredValue> emitSugaredExpr(const Expr& tree, size_t n_binders, const TypePtr& type_hint=nullptr) {
- switch(tree.kind()) {
+ std::shared_ptr<SugaredValue> emitSugaredExpr(
+ const Expr& tree,
+ size_t n_binders,
+ const TypePtr& type_hint = nullptr) {
+ switch (tree.kind()) {
case TK_VAR:
return environment_stack->getSugaredVar(Var(tree).name());
case '.': {
}
}
- Value * emitNegate(const TreeRef& tree) {
+ Value* emitNegate(const TreeRef& tree) {
const auto& inputs = tree->trees();
auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
auto neg_val = emitBuiltinCall(
- tree->range(),
- *method.graph(),
- aten::neg,
- c10::nullopt,
- named_values,
- {},
- /*required=*/true);
+ tree->range(),
+ *method.graph(),
+ aten::neg,
+ c10::nullopt,
+ named_values,
+ {},
+ /*required=*/true);
// constant fold the input if possible
auto maybe_constant_input = toIValue(neg_val->node()->input());
// This function extracts a new graph from its original subgraph
std::shared_ptr<SugaredValue> emitForkExpr(
SourceRange loc,
- const std::shared_ptr<SugaredValue> &forked,
+ const std::shared_ptr<SugaredValue>& forked,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes) {
// Build the fork node without inputs
- auto fork_node = method.graph()->insertNode(method.graph()->create(prim::fork, 1))
- ->setSourceLocation(std::make_shared<SourceRange>(loc));
+ auto fork_node =
+ method.graph()
+ ->insertNode(method.graph()->create(prim::fork, 1))
+ ->setSourceLocation(std::make_shared<SourceRange>(loc));
auto body_block = fork_node->addBlock();
// Build a template of the graph to be executed
- Value *node_output;
+ Value* node_output;
{
WithInsertPoint guard(body_block);
auto fn_sugared_output = forked->call(loc, method, inputs, attributes, 1);
auto fn_simple_output = fn_sugared_output->asValue(loc, method);
body_block->registerOutput(fn_simple_output);
- node_output = fork_node->output()->setType(FutureType::create(fn_simple_output->type()));
+ node_output = fork_node->output()->setType(
+ FutureType::create(fn_simple_output->type()));
}
// Fork a new graph from its original owning graph
auto kind = getNodeKind(tree->kind(), inputs.size());
auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
return emitBuiltinCall(
- tree->range(),
- *method.graph(),
- kind,
- c10::nullopt,
- named_values,
- {},
- /*required=*/true);
+ tree->range(),
+ *method.graph(),
+ kind,
+ c10::nullopt,
+ named_values,
+ {},
+ /*required=*/true);
}
case TK_UNARY_MINUS: {
return emitNegate(tree);
case TK_OR: {
const auto& inputs = tree->trees();
return emitShortCircuitIf(
- tree->range(),
- inputs[0],
- inputs[1],
- tree->kind() == TK_OR);
+ tree->range(), inputs[0], inputs[1], tree->kind() == TK_OR);
}
case TK_STARRED: {
- throw ErrorReport(tree) << "Unexpected starred expansion. File a bug report.";
+ throw ErrorReport(tree)
+ << "Unexpected starred expansion. File a bug report.";
}
case TK_CONST: {
return emitConst(Const(tree));
<< *elem_type << " but found " << *v->type() << " instead";
}
}
- Value* result = graph->insertNode(graph->createList(elem_type, values))
- ->output();
+ Value* result =
+ graph->insertNode(graph->createList(elem_type, values))->output();
return result;
} break;
case TK_TUPLE_LITERAL: {
Value* emitConst(const Const& c) {
if (c.isFloatingPoint())
- return materializeConstant(c.asFloatingPoint(), *graph, c.range(), fp_constants);
+ return materializeConstant(
+ c.asFloatingPoint(), *graph, c.range(), fp_constants);
else
- return materializeConstant(c.asIntegral(), *graph, c.range(), integral_constants);
+ return materializeConstant(
+ c.asIntegral(), *graph, c.range(), integral_constants);
}
Value* emitStringLiteral(const StringLiteral& c) {
int64_t dim,
Value* index) {
return emitBuiltinCall(
- loc, *graph, aten::select, c10::nullopt,
- {input, graph->insertConstant(dim, loc), index}, {}, true);
+ loc,
+ *graph,
+ aten::select,
+ c10::nullopt,
+ {input, graph->insertConstant(dim, loc), index},
+ {},
+ true);
}
- // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end, 1)
+ // Desugars slice indexing:
+ // tensor[begin:end] -> tensor.slice(dim, begin, end, 1)
Value* emitSlice(
const SourceRange& loc,
Value* input,
}
if (input->type()->cast<TupleType>()) {
if (has_end) {
- return emitTupleSlice(loc, args[0], args[1], /*end*/args[2]);
+ return emitTupleSlice(loc, args[0], args[1], /*end*/ args[2]);
} else {
return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
}
}
NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, loc));
- return emitBuiltinCall(loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
+ return emitBuiltinCall(
+ loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
}
Value* emitIndex(
const SourceRange& loc,
Value* input,
at::ArrayRef<Value*> indices) {
- auto* index = graph->insertNode(
- graph->createList(DynamicType::get(), indices))->output();
- return emitBuiltinCall(loc, *graph, aten::index, c10::nullopt, {input, index}, {}, true);
+ auto* index =
+ graph->insertNode(graph->createList(DynamicType::get(), indices))
+ ->output();
+ return emitBuiltinCall(
+ loc, *graph, aten::index, c10::nullopt, {input, index}, {}, true);
}
// Emits multidimensional slicing with int and slice indices.
// Returns:
// - Value*: the input after it has been indexed by int and slice indices.
- // - vector<Value*>: A list of tensor Value* indices that have not been applied yet.
- // Should be NULL at indices where sliceable (post-slicing) isn't indexed by a tensor.
+ // - vector<Value*>: A list of tensor Value* indices that have not been
+ // applied yet.
+ // Should be NULL at indices where sliceable (post-slicing) isn't indexed by
+ // a tensor.
std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
const SourceRange& loc,
Value* sliceable,
dim++;
};
- for (const auto & subscript_expr : subscript_exprs) {
+ for (const auto& subscript_expr : subscript_exprs) {
if (subscript_expr.kind() == TK_SLICE_EXPR) {
sliceable = emitSlice(loc, sliceable, dim, SliceExpr(subscript_expr));
++dim;
continue;
}
throw ErrorReport(loc)
- << "Unsupported operation: indexing tensor with unsupported index type "
- << index->type()->str() << ". Only ints, slices, and tensors are supported.";
+ << "Unsupported operation: indexing tensor with unsupported index type "
+ << index->type()->str()
+ << ". Only ints, slices, and tensors are supported.";
}
// at::index takes in a TensorList where some tensors can be undefined.
// Convert NULL tensorIndices to undefined tensors to pass to at::index.
// enough dimensions to index".
//
// The strategy is to slice and select the tensor for int and slices first
- // in one pass and then apply at::index on the result of the slicing/selecting.
- // Call the tensor after we've applied slice / select the `sliced`.
- // tensor_indices should have the same size as sliced.dim():
+ // in one pass and then apply at::index on the result of the
+ // slicing/selecting; call the resulting tensor `sliced`.
+ // tensor_indices should have the same size as sliced.dim():
// - tensor_indices[i] = NULL if we should not index `sliced` at dim i
// - tensor_indices[i] = t if we should index `sliced` at dim i with tensor t.
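+ // For example (sketch): `x[1:3, 0, t]` first becomes
+ //   sliced = x.slice(0, 1, 3).select(1, 0)
+ // with tensor_indices = [NULL, t], after which at::index applies the
+ // remaining tensor index t to `sliced`.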
Value* emitMultidimSlicing(
const List<Expr>& subscript_exprs) {
if (!sliceable->type()->isSubtypeOf(DynamicType::get())) {
throw ErrorReport(loc)
- << "Unsupported operation: attempted to use multidimensional "
- << "indexing on a non-tensor type.";
+ << "Unsupported operation: attempted to use multidimensional "
+ << "indexing on a non-tensor type.";
}
std::vector<Value*> tensor_indices;
return emitSlice(loc, sliceable, maybe_dim, slice_exp);
}
- int64_t getTupleIndexVal(const SourceRange& loc,
- const TupleTypePtr& tuple_type,
- Value * idx_val,
+ int64_t getTupleIndexVal(
+ const SourceRange& loc,
+ const TupleTypePtr& tuple_type,
+ Value* idx_val,
bool allow_out_of_bounds) {
- int64_t index;
+ int64_t index;
at::optional<IValue> ivalue = toIValue(idx_val);
if (ivalue && ivalue->isInt()) {
index = ivalue->to<int64_t>();
} else {
- throw ErrorReport(loc)
- << "tuple indices must be integer constants";
+ throw ErrorReport(loc) << "tuple indices must be integer constants";
}
- // set index to be positive to simplify logic in runtime
+ // set index to be positive to simplify logic at runtime
int64_t adj_index = index;
int64_t tuple_len = tuple_type->elements().size();
if (index < 0) {
adj_index = tuple_len + index;
}
if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) {
- throw ErrorReport(loc)
- << "Tuple index out of range. Tuple is length " << tuple_len
- << " and index is " << index;
+ throw ErrorReport(loc) << "Tuple index out of range. Tuple is length "
+ << tuple_len << " and index is " << index;
}
return adj_index;
}
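+ // e.g. for a tuple of length 3, index -1 is adjusted to 2, while index 3
+ // (or -4) is rejected unless allow_out_of_bounds is set; tuple slicing
+ // passes allow_out_of_bounds and clamps the bounds instead.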
- Value* emitTupleIndex(const SourceRange& loc,
- Value * tuple_val,
- Value * idx_val) {
+ Value* emitTupleIndex(
+ const SourceRange& loc,
+ Value* tuple_val,
+ Value* idx_val) {
auto tuple_typ = tuple_val->type()->cast<TupleType>();
- auto adj_index = getTupleIndexVal(loc, tuple_typ, idx_val, /*allow_out_of_bounds*/false);
- return graph->insertNode(
- graph->createTupleIndex(tuple_val, adj_index))->output();
+ auto adj_index = getTupleIndexVal(
+ loc, tuple_typ, idx_val, /*allow_out_of_bounds*/ false);
+ return graph->insertNode(graph->createTupleIndex(tuple_val, adj_index))
+ ->output();
}
- Value* emitTupleSlice(const SourceRange& loc,
+ Value* emitTupleSlice(
+ const SourceRange& loc,
const NamedValue& tuple_val,
const NamedValue& beg_val,
const at::optional<NamedValue>& end_val) {
auto tuple_type = tuple_val.value(*graph)->type()->expect<TupleType>();
- int64_t beg = getTupleIndexVal(loc, tuple_type, beg_val.value(*graph), /*allow_out_of_bounds*/true);
+ int64_t beg = getTupleIndexVal(
+ loc, tuple_type, beg_val.value(*graph), /*allow_out_of_bounds*/ true);
int64_t end;
int64_t tuple_len = tuple_type->elements().size();
if (end_val) {
end = std::min(std::max((int64_t)0, end), tuple_len);
beg = std::min(std::max((int64_t)0, beg), tuple_len);
- return graph->insertNode(
- graph->createTupleSlice(tuple_val.value(*graph), beg, end))->output();
+ return graph
+ ->insertNode(graph->createTupleSlice(tuple_val.value(*graph), beg, end))
+ ->output();
}
Value* emitSubscript(const Subscript& subscript) {
// if it's a list, emit a regular index selection op
auto* idx = emitExpr(subscript_exprs[0]);
return emitBuiltinCall(
- loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, true);
+ loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, true);
} else if (gatherable->type()->isSubtypeOf(DynamicType::get())) {
return emitMultidimSlicing(loc, gatherable, subscript_exprs);
} else if (auto tuple_type = gatherable->type()->cast<TupleType>()) {
return emitTupleIndex(loc, gatherable, idx);
} else {
throw ErrorReport(loc)
- << "Indexing only supported on lists, tensors, and tuples.";
+ << "Indexing only supported on lists, tensors, and tuples.";
}
}
};
-void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, const SugaredValuePtr& self) {
+void defineMethodsInModule(
+ const std::shared_ptr<Module>& m,
+ const std::vector<Def>& definitions,
+ const std::vector<Resolver>& resolvers,
+ const SugaredValuePtr& self) {
JIT_ASSERT(definitions.size() == resolvers.size());
auto resolver_it = resolvers.begin();
std::vector<Method*> methods;
std::unordered_map<std::string, Method*> function_table;
- for(const Def& def : definitions) {
+ for (const Def& def : definitions) {
const std::string& name = def.name().name();
auto resolver = *resolver_it++;
JIT_ASSERT(resolver);
- if(!self) {
- // if self is defined, then these are methods and do not go into the global namespace
- // otherwise, they get defined together so we add them to the function table
- // so the methods can see each other
+ if (!self) {
+ // if self is defined, then these are methods and do not go into the
+ // global namespace; otherwise, they get defined together, so we add them
+ // to the function table so the methods can see each other
resolver = [resolver, &function_table](
const std::string& name,
Method& m,
}
auto creator = [def, resolver, self](Method& method) {
JIT_ASSERT(resolver);
- to_ir(def, resolver, self, method);
+ to_ir(def, resolver, self, method);
};
Method& method = m->create_method(name, creator);
function_table[name] = &method;
methods.push_back(&method);
}
- for(Method* method : methods) {
+ for (Method* method : methods) {
method->ensure_defined();
}
didFinishEmitModule(m);
}
-void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::string& source, const Resolver& resolver, const SugaredValuePtr& self) {
+void defineMethodsInModule(
+ const std::shared_ptr<Module>& m,
+ const std::string& source,
+ const Resolver& resolver,
+ const SugaredValuePtr& self) {
Parser p(source);
std::vector<Def> definitions;
std::vector<Resolver> resolvers;
defineMethodsInModule(m, definitions, resolvers, self);
}
-
} // namespace script
} // namespace jit
} // namespace torch
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/error_report.h>
-#include <torch/csrc/jit/script/tree_views.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/script/sugared_value.h>
+#include <torch/csrc/jit/script/tree_views.h>
namespace torch {
namespace jit {
namespace script {
-using Resolver = std::function<std::shared_ptr<SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>;
+using Resolver = std::function<std::shared_ptr<SugaredValue>(
+ const std::string& name,
+ Method& m,
+ const SourceRange& loc)>;
-inline std::shared_ptr<SugaredValue> nativeResolver(const std::string& name, Method& m, const SourceRange& loc){
+inline std::shared_ptr<SugaredValue> nativeResolver(
+ const std::string& name,
+ Method& m,
+ const SourceRange& loc) {
if (name == "torch") {
return std::make_shared<BuiltinModule>("aten");
}
}
TORCH_API void defineMethodsInModule(
- const std::shared_ptr<Module>& m,
- const std::vector<Def>& definitions,
- const std::vector<Resolver>& resolvers, /* determines how we handle free variables in each definition*/
- const std::shared_ptr<SugaredValue>& self /* if non-null, the first argument to each def, is bound to this value */
+ const std::shared_ptr<Module>& m,
+ const std::vector<Def>& definitions,
+ const std::vector<Resolver>& resolvers, /* determines how we handle free
+ variables in each definition */
+ const std::shared_ptr<SugaredValue>&
+ self /* if non-null, the first argument to each def is bound to this
+ value */
);
// same as above but parse the definitions from source
-TORCH_API void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::string& source, const Resolver& resolver, const std::shared_ptr<SugaredValue>& self);
+TORCH_API void defineMethodsInModule(
+ const std::shared_ptr<Module>& m,
+ const std::string& source,
+ const Resolver& resolver,
+ const std::shared_ptr<SugaredValue>& self);
} // namespace script
} // namespace jit
explicit ErrorReport(const SourceRange& r)
: context(std::make_shared<SourceRange>(r)) {}
explicit ErrorReport(std::shared_ptr<SourceLocation> loc)
- : context(std::move(loc)) {}
+ : context(std::move(loc)) {}
explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {}
explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {}
const char* what() const noexcept override {
-#include <torch/csrc/jit/script/final_returns.h>
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/final_returns.h>
namespace torch {
namespace jit {
void checkNoReturn(const TreeRef& ref) {
if (ref->kind() == TK_RETURN)
throw ErrorReport(ref) << "return is not allowed from a loop.";
- for(const TreeRef& child : ref->trees()) {
+ for (const TreeRef& child : ref->trees()) {
checkNoReturn(child);
}
}
// transform stmts so that the last action is to return, or to report that
// the block never returns.
// return_none - if true, add an implicit `return None` to the end of the block
-// this handles the case where the return is implicit at the end of the function.
-ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmts, bool return_none);
+// this handles the case where the return is implicit at the end of the
+// function.
+ReturnInfo makeReturnsFinal(
+ const SourceRange& range,
+ at::ArrayRef<TreeRef> stmts,
+ bool return_none);
ReturnInfo makeReturnsFinal(const List<Stmt>& stmts, bool return_none) {
return makeReturnsFinal(stmts.range(), stmts.get()->trees(), return_none);
}
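// For illustration, the early-return rewrite handled below (case 3) turns
//
//   if cond:
//     return a
//   rest()
//   return b
//
// into a block whose last action is always a return, by moving the trailing
// statements into a synthesized else branch:
//
//   if cond:
//     return a
//   else:
//     rest()
//     return b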
-ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmts, bool return_none) {
+ReturnInfo makeReturnsFinal(
+ const SourceRange& range,
+ at::ArrayRef<TreeRef> stmts,
+ bool return_none) {
std::vector<TreeRef> changed;
changed.reserve(stmts.size());
- for(size_t i = 0; i < stmts.size(); ++i) {
+ for (size_t i = 0; i < stmts.size(); ++i) {
const TreeRef& stmt = stmts[i];
- switch(stmt->kind()) {
+ switch (stmt->kind()) {
case TK_IF: {
auto if_stmt = If(stmt);
auto true_final = makeReturnsFinal(if_stmt.trueBranch(), false);
// (3) early return an if statement without an else block:
if (true_final.returns_ && if_stmt.falseBranch().size() == 0) {
- auto rest_final = makeReturnsFinal(range, stmts.slice(i + 1), return_none);
+ auto rest_final =
+ makeReturnsFinal(range, stmts.slice(i + 1), return_none);
if (!rest_final.returns_) {
throw ErrorReport(if_stmt)
- << "This if statement performs an early return, but the block of code that follows it does not return."
- << " Early returns are only allowed when the block following them also returns.";
+ << "This if statement performs an early return, but the block of code that follows it does not return."
+ << " Early returns are only allowed when the block following them also returns.";
}
- changed.emplace_back(if_stmt.withNewBranches(true_final.stmts_, rest_final.stmts_));
+ changed.emplace_back(
+ if_stmt.withNewBranches(true_final.stmts_, rest_final.stmts_));
return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
}
}
// (2) all branches return
if (true_final.returns_ && false_final.returns_) {
- changed.emplace_back(if_stmt.withNewBranches(true_final.stmts_, false_final.stmts_));
+ changed.emplace_back(
+ if_stmt.withNewBranches(true_final.stmts_, false_final.stmts_));
return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
}
throw ErrorReport(if_stmt)
- << "This if statement contains some paths that return and some paths that do not. "
- << "If statements must either entirely return or never return.";
+ << "This if statement contains some paths that return and some paths that do not. "
+ << "If statements must either entirely return or never return.";
} break;
case TK_WHILE:
case TK_FOR:
}
if (return_none) {
// add an implicit return none node
- changed.emplace_back(Return::create(range, Expr(Compound::create(TK_NONE, range, {}))));
+ changed.emplace_back(
+ Return::create(range, Expr(Compound::create(TK_NONE, range, {}))));
}
// we reach the end of the block, no returns have happened
// unless we just inserted a return_none implicit return.
#include <memory>
#include <string>
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/tree_views.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch {
namespace jit {
// In particular we allow:
// 1. If statements where neither <true> nor <false> branch returns.
// 2. If statements where both <true> and <false> always return.
-// 3. An 'early return' if statement where <true> always returns <false> is empty, and <rest>
-// always returns.
+// 3. An 'early return' if statement where <true> always returns, <false> is
+// empty, and <rest> always returns.
//
// We do not allow returns from loops in any case.
//
// 2. Both branches return, so we recursively transform the program such that
// <true> and <false>'s final action is to return. We then delete <rest>
// because the code is dead. The remaining program preserves the inductive
-// property that its last action is to return since both branches end in a return.
-// 3. In this case we know that <true> and <rest> always returns, and <false> is empty.
+// property that its last action is to return since both branches end in a
+// return.
+// 3. In this case we know that <true> and <rest> always return, and <false> is
+// empty.
// We transform the graph to:
// if <cond>:
// <true>
TORCH_API List<Stmt> moveAllReturnsToEnd(const List<Stmt>& stmts);
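// For contrast, a form these checks reject, because one path returns and the
// other does not (neither case 2 nor the empty-else early return applies):
//
//   if cond:
//     return a
//   else:
//     b = x + 1
//   return b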
-}
+} // namespace script
} // namespace jit
} // namespace torch
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/python_tracer.h>
-#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/constants.h>
-#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/function_schema.h>
-#include <torch/csrc/jit/script/parser.h>
-#include <torch/csrc/jit/import_method.h>
#include <torch/csrc/jit/hooks_for_testing.h>
+#include <torch/csrc/jit/import_method.h>
#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/passes/to_batch.h>
+#include <torch/csrc/jit/pybind_utils.h>
+#include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
#include <ATen/ATen.h>
+#include <pybind11/functional.h>
#include <cstddef>
#include <memory>
#include <sstream>
#include <tuple>
#include <utility>
#include <vector>
-#include <pybind11/functional.h>
-
namespace torch {
namespace jit {
bool is_submodule = false);
struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
- PythonValue(py::object self)
- : self(std::move(self)) {}
+ PythonValue(py::object self) : self(std::move(self)) {}
FunctionSchema getSchema(const size_t n_args, const size_t n_binders) {
auto annotations = py::module::import("torch.jit.annotations");
if (!signature.is_none()) {
std::vector<TypePtr> arg_types;
TypePtr ret_type;
- std::tie(arg_types, ret_type) = py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
+ std::tie(arg_types, ret_type) =
+ py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
args.reserve(arg_types.size());
size_t idx = 0; // Fake argument names by putting in the index
- for (auto &arg_type : arg_types) {
- args.push_back(Argument(std::to_string(idx++), std::move(arg_type), {}, {}, false));
+ for (auto& arg_type : arg_types) {
+ args.push_back(Argument(
+ std::to_string(idx++), std::move(arg_type), {}, {}, false));
}
rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
} else {
// Construct the default signature: all arguments and returns will be
// DynamicType
args.reserve(actual_n_args);
- for (size_t i=0; i < actual_n_args; ++i) {
- args.push_back(Argument(std::to_string(i), DynamicType::get(), {}, {}, false));
+ for (size_t i = 0; i < actual_n_args; ++i) {
+ args.push_back(
+ Argument(std::to_string(i), DynamicType::get(), {}, {}, false));
}
TypePtr ret_type = DynamicType::get();
- if(n_binders != 1) {
+ if (n_binders != 1) {
std::vector<TypePtr> tuple_values(n_binders, ret_type);
ret_type = TupleType::create(std::move(tuple_values));
}
}
// call it like a function, e.g. `outputs = this(inputs)`
- std::shared_ptr<SugaredValue> call(const SourceRange& loc, Method & m, at::ArrayRef<NamedValue> inputs_, at::ArrayRef<NamedValue> attributes, size_t n_binders) override {
+ std::shared_ptr<SugaredValue> call(
+ const SourceRange& loc,
+ Method& m,
+ at::ArrayRef<NamedValue> inputs_,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) override {
auto inputs = toValues(*m.graph(), inputs_);
auto schema = getSchema(inputs.size(), n_binders);
std::stringstream failure_messages;
- c10::optional<MatchedSchema> matched_schema =
- tryMatchSchema(schema, loc, *m.graph(), c10::nullopt, inputs_, attributes, failure_messages, /*conv_tensor_to_num*/true);
+ c10::optional<MatchedSchema> matched_schema = tryMatchSchema(
+ schema,
+ loc,
+ *m.graph(),
+ c10::nullopt,
+ inputs_,
+ attributes,
+ failure_messages,
+ /*conv_tensor_to_num*/ true);
if (!matched_schema)
throw ErrorReport(loc) << failure_messages.str();
py::object func = self;
std::string cconv(inputs.size(), 'd');
Node* new_node = m.graph()->insertNode(m.graph()->createPythonOp(
- THPObjectPtr(func.release().ptr()), cconv, {}));
+ THPObjectPtr(func.release().ptr()), cconv, {}));
new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
- for(auto &i : matched_schema->inputs)
+ for (auto& i : matched_schema->inputs)
new_node->addInput(i);
JIT_ASSERT(matched_schema->return_types.size() == 1);
- Value* output = new_node->addOutput()->setType(matched_schema->return_types.at(0));
+ Value* output =
+ new_node->addOutput()->setType(matched_schema->return_types.at(0));
return std::make_shared<SimpleValue>(output);
}
return ss.str();
}
-protected:
-
+ protected:
py::object getattr(const SourceRange& loc, const std::string& name) {
try {
return py::getattr(self, name.c_str());
};
struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
- explicit ConstantPythonTupleValue(py::object tup) : PythonValue(std::move(tup)) {}
+ explicit ConstantPythonTupleValue(py::object tup)
+ : PythonValue(std::move(tup)) {}
std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
Method& m,
return result;
}
- Value* asValue(
- const SourceRange& loc,
- Method& m) override {
+ Value* asValue(const SourceRange& loc, Method& m) override {
std::vector<Value*> values;
for (const auto& sugared_item : asTuple(loc, m)) {
values.push_back(sugared_item->asValue(loc, m));
// anticipating we will eventually need to replace Module with a py::object
// holding the actual nn.Module class.
-
struct ModuleValue : public SugaredValue {
- ModuleValue(std::shared_ptr<Module> module)
- : module(std::move(module)) {}
+ ModuleValue(std::shared_ptr<Module> module) : module(std::move(module)) {}
std::string kind() const override {
return "module";
}
// select an attribute on it, e.g. `this.field`
- std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override {
+ std::shared_ptr<SugaredValue> attr(
+ const SourceRange& loc,
+ Method& m,
+ const std::string& field) override {
// workaround to make self.training work
// it adds a buffer 'training' to the model if one doesn't exist
// and then loads that parameter, casting it to bool
if (!v) {
py::object py_module = py::cast(module);
bool training = py::cast<bool>(py::getattr(py_module, "training"));
- auto t = autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
+ auto t =
+ autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
module->register_parameter("training", std::move(t), true);
v = module->find_parameter(field);
}
return std::make_shared<SimpleValue>(the_bool);
}
- if(NamedModule* v = module->find_module(field)) {
+ if (NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleValue>(v->module);
- } else if(Method* v = module->find_method(field)) {
+ } else if (Method* v = module->find_method(field)) {
return std::make_shared<MethodValue>(module, *v);
- } else if(NamedParameter* v = module->find_parameter(field)) {
+ } else if (NamedParameter* v = module->find_parameter(field)) {
return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
}
// This can also be a call to a non-script module, or a plain
// python method. If so, return this as a python value.
py::object py_module = py::cast(module);
- if(py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
+ if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
if (py::isinstance<py::function>(attr) ||
py::isinstance(attr, py::module::import("torch.nn").attr("Module")) ||
py_module.attr("_constants_set").contains(field.c_str())) {
return toSugaredValue(attr, m, loc, true);
} else {
- throw ErrorReport(loc) << "attribute '" << field << "' of type '" << typeString(attr) << "' is not usable in a script method (did you forget to add it __constants__?)";
+ throw ErrorReport(loc)
+ << "attribute '" << field << "' of type '" << typeString(attr)
+ << "' is not usable in a script method (did you forget to add it to __constants__?)";
}
}
throw ErrorReport(loc) << "module has no attribute '" << field << "'";
}
// call module.forward
- std::shared_ptr<SugaredValue> call(const SourceRange& loc, Method & caller, at::ArrayRef<NamedValue> inputs, at::ArrayRef<NamedValue> attributes, size_t n_binders) override {
- return attr(loc, caller, "forward")->call(loc, caller, inputs, attributes, n_binders);
+ std::shared_ptr<SugaredValue> call(
+ const SourceRange& loc,
+ Method& caller,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) override {
+ return attr(loc, caller, "forward")
+ ->call(loc, caller, inputs, attributes, n_binders);
}
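  // Illustration: a call like `self.submodule(x)` in script resolves through
  // this override and therefore desugars to `self.submodule.forward(x)`.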
std::vector<std::shared_ptr<SugaredValue>> asTuple(
Method& m,
const c10::optional<size_t>& size_hint = {}) override {
py::object py_module = py::cast(module);
- if(!py::isinstance(py_module, py::module::import("torch.jit").attr("_ConstModuleList")))
+ if (!py::isinstance(
+ py_module,
+ py::module::import("torch.jit").attr("_ConstModuleList")))
return SugaredValue::asTuple(loc, m, size_hint);
std::vector<std::shared_ptr<SugaredValue>> result;
- for(py::handle module : py_module) {
+ for (py::handle module : py_module) {
py::object obj = py::reinterpret_borrow<py::object>(module);
result.push_back(toSugaredValue(
obj,
const auto v = static_cast<int64_t>(dtype->scalar_type);
return toSimple(g.insertConstant(v, loc));
} else if (py::isinstance<py::tuple>(obj)) {
- return std::make_shared<ConstantPythonTupleValue>(obj);
+ return std::make_shared<ConstantPythonTupleValue>(obj);
}
}
return std::make_shared<PythonModuleValue>(obj);
} else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) {
return std::make_shared<ForkValue>();
- } else if (obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
+ } else if (
+ obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
return std::make_shared<AnnotateValue>();
}
- py::object builtin_name = py::module::import("torch.jit").attr("_find_builtin")(obj);
+ py::object builtin_name =
+ py::module::import("torch.jit").attr("_find_builtin")(obj);
if (!builtin_name.is_none()) {
return std::make_shared<BuiltinFunction>(
Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
return py::cast(autograd::as_variable_ref(outputs[0]));
} else {
py::tuple tuple(outputs.size());
- for(size_t i = 0; i < outputs.size(); i++) {
+ for (size_t i = 0; i < outputs.size(); i++) {
tuple[i] = py::cast(autograd::as_variable_ref(outputs[i]));
}
return tuple;
}
}
-static void gatherParametersAndBuffers(std::vector<at::Tensor*> & values, const Module & m) {
- for(auto & param : m.get_parameters()) {
+static void gatherParametersAndBuffers(
+ std::vector<at::Tensor*>& values,
+ const Module& m) {
+ for (auto& param : m.get_parameters()) {
values.push_back(param->slot());
}
- for(const auto & sub : m.get_modules()) {
+ for (const auto& sub : m.get_modules()) {
gatherParametersAndBuffers(values, *sub->module);
}
}
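// Note: parameters of each module are gathered before recursing into its
// submodules. _create_method_from_trace below appends these slots to the
// traced inputs and then hands the same vector to create_method, so the
// trailing graph inputs line up with the gathered slots.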
};
}
-}
+} // namespace
FunctionSchema getSchemaWithNameAndDefaults(
const SourceRange& range,
} else {
value = toIValue(it->second, arg.type());
}
- new_args.emplace_back(arg.name(), arg.type(), arg.N(), value, arg.kwarg_only());
+ new_args.emplace_back(
+ arg.name(), arg.type(), arg.N(), value, arg.kwarg_only());
} catch (py::cast_error& e) {
throw ErrorReport(range)
<< "Expected a default value of type " << arg.type()->str()
// public.
py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
.def(py::init<>())
- .def("save", [](std::shared_ptr<Module> m, const std::string& filename) {
- m->save(filename);
- })
- .def("save_to_buffer", [](std::shared_ptr<Module> m) {
- std::ostringstream buf;
- m->save(buf);
- return py::bytes(buf.str());
- })
+ .def(
+ "save",
+ [](std::shared_ptr<Module> m, const std::string& filename) {
+ m->save(filename);
+ })
+ .def(
+ "save_to_buffer",
+ [](std::shared_ptr<Module> m) {
+ std::ostringstream buf;
+ m->save(buf);
+ return py::bytes(buf.str());
+ })
.def("_set_optimized", &Module::set_optimized)
.def(
"_define",
[](std::shared_ptr<Module> m,
const std::string& script,
- ResolutionCallback rcb, bool has_self) {
+ ResolutionCallback rcb,
+ bool has_self) {
auto self = has_self ? std::make_shared<ModuleValue>(m) : nullptr;
defineMethodsInModule(m, script, pythonResolver(rcb), self);
})
- .def("_create_methods", [](std::shared_ptr<Module> m,
- const std::vector<Def>& defs,
- const std::vector<ResolutionCallback>& rcbs,
- const std::vector<FunctionDefaults>& defaults) {
- std::vector<Resolver> resolvers;
- resolvers.reserve(rcbs.size());
- for(auto & callback : rcbs) {
- resolvers.push_back(pythonResolver(callback));
- }
- defineMethodsInModule(
- m,
- defs,
- resolvers,
- std::make_shared<ModuleValue>(m));
-
- // Stitch in default arguments for each Def if provided
- auto defaults_it = defaults.begin();
- auto defs_it = defs.begin();
- while (defs_it != defs.end()) {
- auto& method = m->get_method((*defs_it).name().name());
- method.setSchema(getSchemaWithNameAndDefaults(
- defs_it->range(), method.getSchema(), at::nullopt, *defaults_it));
- ++defs_it;
- ++defaults_it;
- }
- didFinishEmitModule(m);
- })
- .def("_get_method",
- [](Module& self, const std::string& name) -> const Method& {
- return self.get_method(name);
- }, py::return_value_policy::reference_internal)
+ .def(
+ "_create_methods",
+ [](std::shared_ptr<Module> m,
+ const std::vector<Def>& defs,
+ const std::vector<ResolutionCallback>& rcbs,
+ const std::vector<FunctionDefaults>& defaults) {
+ std::vector<Resolver> resolvers;
+ resolvers.reserve(rcbs.size());
+ for (auto& callback : rcbs) {
+ resolvers.push_back(pythonResolver(callback));
+ }
+ defineMethodsInModule(
+ m, defs, resolvers, std::make_shared<ModuleValue>(m));
+
+ // Stitch in default arguments for each Def if provided
+ auto defaults_it = defaults.begin();
+ auto defs_it = defs.begin();
+ while (defs_it != defs.end()) {
+ auto& method = m->get_method((*defs_it).name().name());
+ method.setSchema(getSchemaWithNameAndDefaults(
+ defs_it->range(),
+ method.getSchema(),
+ at::nullopt,
+ *defaults_it));
+ ++defs_it;
+ ++defaults_it;
+ }
+ didFinishEmitModule(m);
+ })
+ .def(
+ "_get_method",
+ [](Module& self, const std::string& name) -> const Method& {
+ return self.get_method(name);
+ },
+ py::return_value_policy::reference_internal)
.def("_register_parameter", &Module::register_parameter)
.def("_register_module", &Module::register_module)
.def("_set_parameter", &Module::set_parameter)
.def("_get_parameter", &Module::get_parameter)
.def("_get_module", &Module::get_module)
- .def("_get_modules", [](Module& self) -> py::tuple {
- auto & modules = self.get_modules();
- py::tuple result(modules.size());
- for(size_t i = 0; i < modules.size(); ++i) {
- auto & item = modules[i];
- result[i] = std::make_pair(item.key(), item.value().module);
- }
- return result;
- })
- .def("_get_parameters", [](Module& self) -> py::tuple {
- auto & parameters = self.get_parameters();
- py::tuple result(parameters.size());
- for(size_t i = 0; i < parameters.size(); ++i) {
- auto & p = parameters[i];
- py::tuple r(3);
- result[i] = std::make_tuple(
- p.key(),
- autograd::as_variable_ref(*p->slot()),
- p->is_buffer);
-
- }
- return result;
- })
- .def("_has_parameter", [](Module& self, const std::string& name) {
- if(auto r = self.find_parameter(name)) {
- return !r->is_buffer;
- }
- return false;
- })
- .def("_has_buffer", [](Module& self, const std::string& name) {
- if(auto r = self.find_parameter(name)) {
- return r->is_buffer;
- }
- return false;
- })
- .def("_has_module", [](Module& self, const std::string& name) {
- return bool(self.find_module(name));
- })
- .def("_has_method", [](Module& self, const std::string& name) {
- return bool(self.find_method(name));
- })
- .def("_method_names", [](Module& self) {
- using Item = torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
- return fmap(self.get_methods(), [](const Item & item) {
- return (*item)->name();
- });
- })
- .def("_create_method_from_graph", [](
- Module& self,
- const std::string& name,
- std::shared_ptr<Graph> graph
- ){
- self.create_method(name, std::move(graph), {});
- })
- .def("_create_method_from_trace", [](
- std::shared_ptr<Module> self,
- const std::string& name,
- py::function func,
- py::tuple input_tuple,
- py::function var_lookup_fn,
- bool force_outplace) {
- // prereq: Module's buffers and parameters are unique
- // this was ensured in python before calling this function
- std::vector<at::Tensor*> parameters;
- gatherParametersAndBuffers(parameters, *self);
- Stack inputs = toStack(input_tuple);
- for(at::Tensor* param : parameters) {
- inputs.emplace_back(*param);
- }
- auto graph = tracer::createGraphByTracing(
- func, inputs, var_lookup_fn, force_outplace, input_tuple.size());
- self->create_method(name, std::move(graph), std::move(parameters));
- didFinishEmitModule(self);
- })
- .def("graph_for", [](py::args args, py::kwargs kwargs) {
- // [pybind11 varargs] note: old version of pybind11 have a bug that leaks memory
- // when py::args is mixed with positional arguments
- // https://github.com/pybind/pybind11/pull/1216
- // we work around this by not mixing positional arguments with varargs
- Module& self = py::cast<Module&>(args[0]);
- if (self.find_method("forward")) {
- Method & m = self.get_method("forward");
- return m.graph_for(
- createStackForSchema(m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
- }
- throw std::runtime_error("Attempted to call graph_for on a Module without a compiled forward()");
- })
- .def("get_debug_state", [](Module& self) {
- if (self.find_method("forward")) {
- Method & m = self.get_method("forward");
- return m.getDebugState();
- }
- throw std::runtime_error("Attempted to call get_debug_state on a Module without a compiled forward()");
- })
- .def("debug_disable_autodiff_subgraph_inlining", [](Module& self) {
- if (self.find_method("forward")) {
- Method & m = self.get_method("forward");
- m.debugDisableAutodiffSubgraphInlining();
- }
- })
- .def("forward", [](py::args args, py::kwargs kwargs) {
- // We implement this in C++ to avoid incurring the pybind11 dispatch
- // overhead twice: once to call into the method lookup for "forward"
- // and once to actually invoke the method.
- //
- // There is a thin wrapper on top of this method in the C++ version of
- // ScriptModule.
-
- // see: [pybind11 varargs]
- Module& self = py::cast<Module&>(args[0]);
- return invokeScriptMethodFromPython(self.get_method("forward"), tuple_slice(std::move(args), 1), std::move(kwargs));
- })
- .def("_python_print", [](Module& self) {
- std::ostringstream ss;
- std::vector<at::Tensor> tensors;
- PythonPrint(ss, self, tensors, true);
- return std::make_pair(ss.str(), tensors);
- })
- .def_property_readonly("code", [](Module& self) {
- std::ostringstream ss;
- std::vector<at::Tensor> tensors;
- PythonPrint(ss, self, tensors, false);
- return ss.str();
- })
+ .def(
+ "_get_modules",
+ [](Module& self) -> py::tuple {
+ auto& modules = self.get_modules();
+ py::tuple result(modules.size());
+ for (size_t i = 0; i < modules.size(); ++i) {
+ auto& item = modules[i];
+ result[i] = std::make_pair(item.key(), item.value().module);
+ }
+ return result;
+ })
+ .def(
+ "_get_parameters",
+ [](Module& self) -> py::tuple {
+ auto& parameters = self.get_parameters();
+ py::tuple result(parameters.size());
+ for (size_t i = 0; i < parameters.size(); ++i) {
+ auto& p = parameters[i];
+ py::tuple r(3);
+ result[i] = std::make_tuple(
+ p.key(), autograd::as_variable_ref(*p->slot()), p->is_buffer);
+ }
+ return result;
+ })
+ .def(
+ "_has_parameter",
+ [](Module& self, const std::string& name) {
+ if (auto r = self.find_parameter(name)) {
+ return !r->is_buffer;
+ }
+ return false;
+ })
+ .def(
+ "_has_buffer",
+ [](Module& self, const std::string& name) {
+ if (auto r = self.find_parameter(name)) {
+ return r->is_buffer;
+ }
+ return false;
+ })
+ .def(
+ "_has_module",
+ [](Module& self, const std::string& name) {
+ return bool(self.find_module(name));
+ })
+ .def(
+ "_has_method",
+ [](Module& self, const std::string& name) {
+ return bool(self.find_method(name));
+ })
+ .def(
+ "_method_names",
+ [](Module& self) {
+ using Item =
+ torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
+ return fmap(self.get_methods(), [](const Item& item) {
+ return (*item)->name();
+ });
+ })
+ .def(
+ "_create_method_from_graph",
+ [](Module& self,
+ const std::string& name,
+ std::shared_ptr<Graph> graph) {
+ self.create_method(name, std::move(graph), {});
+ })
+ .def(
+ "_create_method_from_trace",
+ [](std::shared_ptr<Module> self,
+ const std::string& name,
+ py::function func,
+ py::tuple input_tuple,
+ py::function var_lookup_fn,
+ bool force_outplace) {
+ // prereq: Module's buffers and parameters are unique
+ // this was ensured in python before calling this function
+ std::vector<at::Tensor*> parameters;
+ gatherParametersAndBuffers(parameters, *self);
+ Stack inputs = toStack(input_tuple);
+ for (at::Tensor* param : parameters) {
+ inputs.emplace_back(*param);
+ }
+ auto graph = tracer::createGraphByTracing(
+ func,
+ inputs,
+ var_lookup_fn,
+ force_outplace,
+ input_tuple.size());
+ self->create_method(name, std::move(graph), std::move(parameters));
+ didFinishEmitModule(self);
+ })
+ .def(
+ "graph_for",
+ [](py::args args, py::kwargs kwargs) {
+ // [pybind11 varargs] note: old versions of pybind11 have a bug that
+ // leaks memory when py::args is mixed with positional arguments
+ // https://github.com/pybind/pybind11/pull/1216
+ // we work around this by not mixing positional arguments with
+ // varargs
+ Module& self = py::cast<Module&>(args[0]);
+ if (self.find_method("forward")) {
+ Method& m = self.get_method("forward");
+ return m.graph_for(createStackForSchema(
+ m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
+ }
+ throw std::runtime_error(
+ "Attempted to call graph_for on a Module without a compiled forward()");
+ })
+ .def(
+ "get_debug_state",
+ [](Module& self) {
+ if (self.find_method("forward")) {
+ Method& m = self.get_method("forward");
+ return m.getDebugState();
+ }
+ throw std::runtime_error(
+ "Attempted to call get_debug_state on a Module without a compiled forward()");
+ })
+ .def(
+ "debug_disable_autodiff_subgraph_inlining",
+ [](Module& self) {
+ if (self.find_method("forward")) {
+ Method& m = self.get_method("forward");
+ m.debugDisableAutodiffSubgraphInlining();
+ }
+ })
+ .def(
+ "forward",
+ [](py::args args, py::kwargs kwargs) {
+ // We implement this in C++ to avoid incurring the pybind11 dispatch
+ // overhead twice: once to call into the method lookup for "forward"
+ // and once to actually invoke the method.
+ //
+ // There is a thin wrapper on top of this method in the C++ version
+ // of ScriptModule.
+
+ // see: [pybind11 varargs]
+ Module& self = py::cast<Module&>(args[0]);
+ return invokeScriptMethodFromPython(
+ self.get_method("forward"),
+ tuple_slice(std::move(args), 1),
+ std::move(kwargs));
+ })
+ .def(
+ "_python_print",
+ [](Module& self) {
+ std::ostringstream ss;
+ std::vector<at::Tensor> tensors;
+ PythonPrint(ss, self, tensors, true);
+ return std::make_pair(ss.str(), tensors);
+ })
+ .def_property_readonly(
+ "code",
+ [](Module& self) {
+ std::ostringstream ss;
+ std::vector<at::Tensor> tensors;
+ PythonPrint(ss, self, tensors, false);
+ return ss.str();
+ })
.def("apply", &Module::apply)
.def("_copy_into", &Module::copy_into);
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
- .def("graph", [&](Method& self) {
- return self.graph();
- })
- .def("__call__", [](py::args args, py::kwargs kwargs) {
- // see: [pybind11 varargs]
- Method& method = py::cast<Method&>(args[0]);
- return invokeScriptMethodFromPython(method, tuple_slice(std::move(args), 1), std::move(kwargs));
- })
- .def_property_readonly("graph", [](Method& m) {
- return m.graph();
- })
- .def("propagate_shapes", &Method::propagate_shapes)
- .def("propagate_and_assign_input_and_output_shapes", &Method::propagate_and_assign_input_and_output_shapes)
- .def("params", &Method::params)
- .def("graph_for", [](py::args args, py::kwargs kwargs) {
- // see: [pybind11 varargs]
- Method& self = py::cast<Method&>(args[0]);
- return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
- })
- .def("debug_disable_autodiff_subgraph_inlining", &Method::debugDisableAutodiffSubgraphInlining)
- .def("schema", &Method::getSchema)
- .def("pretty_print_schema", &Method::pretty_print_schema)
- .def("python_print", [](Method &m) {
- std::ostringstream oss;
- std::vector<at::Tensor> constants;
- PythonPrint(oss, m, constants, true);
- return std::make_pair(oss.str(), std::move(constants));
- });
-
- m.def("_jit_script_compile", [](std::shared_ptr<Module> mod, const Def &def, ResolutionCallback rcb, FunctionDefaults defaults) {
- auto def_f = def.withName("forward");
- defineMethodsInModule(mod, {def_f}, {pythonResolver(rcb)}, nullptr);
- auto& method = mod->get_method("forward");
- method.setSchema(getSchemaWithNameAndDefaults(
- def.range(), method.getSchema(), def.name().name(), defaults));
- didFinishEmitModule(mod);
- return mod;
- });
+ .def("graph", [&](Method& self) { return self.graph(); })
+ .def(
+ "__call__",
+ [](py::args args, py::kwargs kwargs) {
+ // see: [pybind11 varargs]
+ Method& method = py::cast<Method&>(args[0]);
+ return invokeScriptMethodFromPython(
+ method, tuple_slice(std::move(args), 1), std::move(kwargs));
+ })
+ .def_property_readonly("graph", [](Method& m) { return m.graph(); })
+ .def("propagate_shapes", &Method::propagate_shapes)
+ .def(
+ "propagate_and_assign_input_and_output_shapes",
+ &Method::propagate_and_assign_input_and_output_shapes)
+ .def("params", &Method::params)
+ .def(
+ "graph_for",
+ [](py::args args, py::kwargs kwargs) {
+ // see: [pybind11 varargs]
+ Method& self = py::cast<Method&>(args[0]);
+ return self.graph_for(createStackForSchema(
+ self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
+ })
+ .def(
+ "debug_disable_autodiff_subgraph_inlining",
+ &Method::debugDisableAutodiffSubgraphInlining)
+ .def("schema", &Method::getSchema)
+ .def("pretty_print_schema", &Method::pretty_print_schema)
+ .def("python_print", [](Method& m) {
+ std::ostringstream oss;
+ std::vector<at::Tensor> constants;
+ PythonPrint(oss, m, constants, true);
+ return std::make_pair(oss.str(), std::move(constants));
+ });
+
+ m.def(
+ "_jit_script_compile",
+ [](std::shared_ptr<Module> mod,
+ const Def& def,
+ ResolutionCallback rcb,
+ FunctionDefaults defaults) {
+ auto def_f = def.withName("forward");
+ defineMethodsInModule(mod, {def_f}, {pythonResolver(rcb)}, nullptr);
+ auto& method = mod->get_method("forward");
+ method.setSchema(getSchemaWithNameAndDefaults(
+ def.range(), method.getSchema(), def.name().name(), defaults));
+ didFinishEmitModule(mod);
+ return mod;
+ });
m.def("parse_type_comment", [](const std::string& comment) {
Parser p(comment);
return p.parseTypeComment();
});
m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
- m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename,
- py::object map_location) {
- c10::optional<at::Device> optional_device;
- if (!map_location.is(py::none())) {
- AT_ASSERT(THPDevice_Check(map_location.ptr()));
- optional_device = reinterpret_cast<THPDevice*>(map_location.ptr())->device;
- }
- import_ir_module(module_lookup, filename, optional_device);
- });
- m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup,
- const std::string& buffer, py::object map_location) {
- std::istringstream in(buffer);
- c10::optional<at::Device> optional_device;
- if (!map_location.is(py::none())) {
- AT_ASSERT(THPDevice_Check(map_location.ptr()));
- optional_device = reinterpret_cast<THPDevice*>(map_location.ptr())->device;
- }
- import_ir_module(module_lookup, in, optional_device);
- });
+ m.def(
+ "import_ir_module",
+ [](ModuleLookup module_lookup,
+ const std::string& filename,
+ py::object map_location) {
+ c10::optional<at::Device> optional_device;
+ if (!map_location.is(py::none())) {
+ AT_ASSERT(THPDevice_Check(map_location.ptr()));
+ optional_device =
+ reinterpret_cast<THPDevice*>(map_location.ptr())->device;
+ }
+ import_ir_module(module_lookup, filename, optional_device);
+ });
+ m.def(
+ "import_ir_module_from_buffer",
+ [](ModuleLookup module_lookup,
+ const std::string& buffer,
+ py::object map_location) {
+ std::istringstream in(buffer);
+ c10::optional<at::Device> optional_device;
+ if (!map_location.is(py::none())) {
+ AT_ASSERT(THPDevice_Check(map_location.ptr()));
+ optional_device =
+ reinterpret_cast<THPDevice*>(map_location.ptr())->device;
+ }
+ import_ir_module(module_lookup, in, optional_device);
+ });
m.def("_jit_import_methods", import_methods);
m.def("_jit_set_emit_module_hook", setEmitModuleHook);
}
namespace torch {
namespace jit {
-struct JITException
- : public std::runtime_error {
+struct JITException : public std::runtime_error {
JITException() = default;
- explicit JITException(const std::string& msg)
- : std::runtime_error(msg) {}
+ explicit JITException(const std::string& msg) : std::runtime_error(msg) {}
};
} // namespace jit
#include <c10/util/Exception.h>
+#include <mutex>
#include <string>
#include <unordered_map>
-#include <mutex>
namespace torch {
namespace jit {
namespace script {
static const std::unordered_map<int, int> binary_prec = {
- {TK_IF, 1},
- {TK_AND, 2},
- {TK_OR, 2},
+ {TK_IF, 1},
+ {TK_AND, 2},
+ {TK_OR, 2},
// reserve a level for unary not
- {'<', 4},
- {'>', 4},
- {TK_IS, 4},
- {TK_ISNOT, 4},
- {TK_EQ, 4},
- {TK_LE, 4},
- {TK_GE, 4},
- {TK_NE, 4},
- {'|', 5},
- {'^', 6},
- {'&', 7},
- {'+', 8},
- {'-', 8},
- {'*', 9},
- {'/', 9},
- {TK_FLOOR_DIV, 9},
- {'%', 9},
- {'@', 9},
- {TK_POW, 10},
+ {'<', 4},
+ {'>', 4},
+ {TK_IS, 4},
+ {TK_ISNOT, 4},
+ {TK_EQ, 4},
+ {TK_LE, 4},
+ {TK_GE, 4},
+ {TK_NE, 4},
+ {'|', 5},
+ {'^', 6},
+ {'&', 7},
+ {'+', 8},
+ {'-', 8},
+ {'*', 9},
+ {'/', 9},
+ {TK_FLOOR_DIV, 9},
+ {'%', 9},
+ {'@', 9},
+ {TK_POW, 10},
};
static const std::unordered_map<int, int> unary_prec = {
- {TK_NOT, 3},
- {'-', 9},
- {'*', 9},
+ {TK_NOT, 3},
+ {'-', 9},
+ {'*', 9},
};
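// For example, with these tables `a + b * c` parses as `a + (b * c)` because
// '*' (9) binds tighter than '+' (8), and `not a == b` parses as
// `not (a == b)` because unary TK_NOT (3) is weaker than TK_EQ (4).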
bool SharedParserData::isUnary(int kind, int* prec) {
for (char tok : std::string(valid_single_char_tokens))
str_to_kind[std::string(1, tok)] = tok;
#define DEFINE_CASE(tok, _, str) \
- if (std::string(str) != "") str_to_kind[str] = tok;
+ if (std::string(str) != "") \
+ str_to_kind[str] = tok;
TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
#undef DEFINE_CASE
});
#pragma once
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/source_range.h>
+#include <torch/csrc/utils/memory.h>
#include <algorithm>
+#include <clocale>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
-#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/source_range.h>
-#include <torch/csrc/utils/memory.h>
-#include <clocale>
namespace torch {
namespace jit {
_(TK_DOTS, "dots", "...") \
_(TK_PASS, "pass", "pass")
-
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|";
enum TokenKind {
head->insert(str.c_str(), *c);
}
-#define ADD_CASE(tok, _, tokstring) \
+#define ADD_CASE(tok, _, tokstring) \
if (*(tokstring) != '\0') { \
- head->insert((tokstring), (tok)); \
+ head->insert((tokstring), (tok)); \
}
TC_FORALL_TOKEN_KINDS(ADD_CASE)
#undef ADD_CASE
}
#ifdef _WIN32
- static double strtod_c(const char * str, char** end) {
+ static double strtod_c(const char* str, char** end) {
/// NOLINTNEXTLINE(hicpp-signed-bitwise)
static _locale_t loc = _create_locale(LC_ALL, "C");
return _strtod_l(str, end, loc);
}
#else
- static double strtod_c(const char * str, char** end) {
+ static double strtod_c(const char* str, char** end) {
/// NOLINTNEXTLINE(hicpp-signed-bitwise)
static locale_t loc = newlocale(LC_ALL_MASK, "C", nullptr);
return strtod_l(str, end, loc);
}
bool isCharCount(char c, const std::string& str, size_t start, int len) {
- //count checks from [start, start + len)
- return start + len <= str.size() && std::count(str.begin() + start, str.begin() + start + len, c) == len;
+ // count checks from [start, start + len)
+ return start + len <= str.size() &&
+ std::count(str.begin() + start, str.begin() + start + len, c) == len;
}
// python concatenates all adjacent strings "a" "b" == "ab"
return false;
int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;
- //end is now set past the opening quotation marks
+ // end is now set past the opening quotation marks
size_t end = start + quote_len;
- while(end < str.size() && !isCharCount(quote, str, end, quote_len)) {
+ while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
if (str[end] == '\n' && quote_len != 3) {
return false;
}
- //handle escaped characters. advances past escaped quotation marks,
- //escaped newlines and escaped backslashes
- //multi-char escapes like \x1A are handled fine here because the
- //remainder of the escape are valid string characters anyway
+ // handle escaped characters. advances past escaped quotation marks,
+ // escaped newlines and escaped backslashes
+ // multi-char escapes like \x1A are handled fine here because the
+ // remainder of the escape are valid string characters anyway
if (str[end] == '\\') {
end++;
}
end++;
}
- //set length equal to the complete string including quotations
+ // set length equal to the complete string including quotations
*len = end - start + quote_len;
- //if end finished without going past the last character of the string than
- //there is a match
+ // if end finished without going past the last character of the string, then
+ // there is a match
return end < str.size();
}
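// For example, scanning the literal in `'abc' + x` from start 0 uses
// quote_len 1, stops end at the closing quote (index 4), and reports
// *len = 5; an unterminated `'abc` fails because end runs off the string.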
return match_string == type_string;
}
// find the longest match of str.substring(pos) against a token, return true
- // if successful
- // filling in kind, start,and len
+ // if successful, filling in kind, start, and len
bool match(
const std::string& str,
size_t pos,
str, pos + 1, continuation, !continuation, kind, start, len);
}
}
- // we handle white space before EOF because in the case we have something like
- // the following where we need to generate the dedent token
- // if foo:
+ // we handle whitespace before EOF because in a case like the following we
+ // need to generate the dedent token:
+ //   if foo:
// ...
// else:
// pass
// identifier 'max'
if (cur) {
size_t child_offset = 0;
- for (size_t e = cur->child_chars.size(); child_offset < e; ++child_offset) {
+ for (size_t e = cur->child_chars.size(); child_offset < e;
+ ++child_offset) {
if (cur->child_chars[child_offset] == str[pos + i])
- break;
+ break;
}
cur = (child_offset == cur->child_chars.size())
- ? nullptr
- : cur->child_tries[child_offset].get();
+ ? nullptr
+ : cur->child_tries[child_offset].get();
if (cur && cur->kind != 0) {
matched = true;
[[noreturn]] void reportError(const std::string& what) {
reportError(what, cur());
- }[[noreturn]] void reportError(const std::string& what, const Token& t) {
+ }
+ [[noreturn]] void reportError(const std::string& what, const Token& t) {
std::stringstream ss;
ss << what << ":\n";
t.range.highlight(ss);
<< "' here:\n";
t.range.highlight(ss);
throw std::runtime_error(ss.str());
- }[[noreturn]] void expected(const std::string& what) {
+ }
+ [[noreturn]] void expected(const std::string& what) {
expected(what, cur());
}
// Check that the current token has a given kind, return the current token,
#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/script/module.h>
-#include <torch/csrc/jit/script/compiler.h>
-#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/script/schema_matching.h>
-namespace torch { namespace jit { namespace script {
-
+namespace torch {
+namespace jit {
+namespace script {
struct RecursiveMethodCallError : public std::exception {};
void placeholderCreator(Method&) {
try {
callee.ensure_defined();
} catch (RecursiveMethodCallError&) {
- throw ErrorReport(loc) << " method '" << callee.name()
+ throw ErrorReport(loc)
+ << " method '" << callee.name()
<< "' is called recursively involving this call site. Recursive calls are not supported";
}
auto fn = callee.graph();
auto matched_schema = tryMatchSchema(
- callee.getSchema(),
- loc, graph, std::move(self), args, kwargs, failure_messages, conv_tensors_to_nums);
- if(!matched_schema)
+ callee.getSchema(),
+ loc,
+ graph,
+ std::move(self),
+ args,
+ kwargs,
+ failure_messages,
+ conv_tensors_to_nums);
+ if (!matched_schema)
return nullptr;
// parameters to callee method (which become parameters to _this_ method
// if they were not already)
- for(at::Tensor* member : callee.params()) {
- if(!caller) {
- throw ErrorReport(loc) << " attempting to call a method with parameters from a raw graph. File a bug report";
+ for (at::Tensor* member : callee.params()) {
+ if (!caller) {
+ throw ErrorReport(loc)
+ << " attempting to call a method with parameters from a raw graph. File a bug report";
}
matched_schema->inputs.push_back(caller->get_or_add_parameter(member));
}
return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
}
-Value* Method::emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs) {
+Value* Method::emit_call_to(
+ const SourceRange& loc,
+ Method& callee,
+ ArrayRef<NamedValue> args,
+ ArrayRef<NamedValue> kwargs) {
JIT_ASSERT(!executor);
std::stringstream failure_messages;
if (auto result = try_emit_call_to(
}
void Method::ensure_defined() {
- if(method_creator) {
+ if (method_creator) {
auto creator = method_creator;
method_creator = placeholderCreator;
creator(*this);
}
}
-}}}
+} // namespace script
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/argument_spec.h>
-#include <torch/csrc/jit/function_schema.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/graph_executor.h>
+#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/named_value.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/source_range.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
#include <torch/csrc/utils/memory.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <functional>
#include <memory>
#include <mutex>
+#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>
-#include <ostream>
// This file contains classes which assist in desugaring Python style
// modules and their methods into flattened graphs which don't have any
// function calls.
-namespace torch { namespace jit { namespace script {
+namespace torch {
+namespace jit {
+namespace script {
// A method in a module, e.g. f in:
//
struct Module;
struct Method {
- Method(Module* owner, std::string name, bool optimize,
- std::shared_ptr<Graph> graph,
- std::vector<at::Tensor*> initial_members,
- std::function<void(Method&)> method_creator)
- : owner_(owner)
- , name_(std::move(name))
- , graph_(std::move(graph))
- , optimize(optimize)
- , member_inputs(std::move(initial_members))
- , method_creator(std::move(method_creator)) {
+ Method(
+ Module* owner,
+ std::string name,
+ bool optimize,
+ std::shared_ptr<Graph> graph,
+ std::vector<at::Tensor*> initial_members,
+ std::function<void(Method&)> method_creator)
+ : owner_(owner),
+ name_(std::move(name)),
+ graph_(std::move(graph)),
+ optimize(optimize),
+ member_inputs(std::move(initial_members)),
+ method_creator(std::move(method_creator)) {
JIT_ASSERT(graph_->inputs().size() >= member_inputs.size());
int i = graph_->inputs().size() - member_inputs.size();
- for(at::Tensor* member : member_inputs) {
+ for (at::Tensor* member : member_inputs) {
member_input_index[member] = i++;
}
}
- void run(Stack & stack) {
- for(at::Tensor* tp : member_inputs) {
+ void run(Stack& stack) {
+ for (at::Tensor* tp : member_inputs) {
stack.emplace_back(*tp);
}
get_executor().run(stack);
}
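  // Sketch of direct invocation (input tensor hypothetical): callers push
  // only the explicit arguments; member tensors are appended above.
  //
  //   Stack stack;
  //   stack.emplace_back(input);
  //   method.run(stack); // stack now holds the outputs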
std::shared_ptr<Graph> graph_for(Stack inputs) {
- for(at::Tensor* tp : member_inputs) {
+ for (at::Tensor* tp : member_inputs) {
inputs.emplace_back(*tp);
}
return get_executor().graphFor(inputs);
return graph_;
}
- TORCH_API const std::string & name() const {
+ TORCH_API const std::string& name() const {
return name_;
}
// emit a function call by inlining the callees Graph into this one
// adding any extra parameters necessary to do this call
- // defined here to keep details of member_input handling confined to this class
- Value* emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs);
+ // defined here to keep details of member_input handling confined to this
+ // class
+ Value* emit_call_to(
+ const SourceRange& loc,
+ Method& callee,
+ ArrayRef<NamedValue> args,
+ ArrayRef<NamedValue> kwargs);
// if this isn't yet defined, run its method_creator function
TORCH_API void ensure_defined();
-
size_t num_inputs() const {
return graph()->inputs().size() - member_inputs.size();
}
- TORCH_API Value * get_or_add_parameter(at::Tensor* slot) {
+ TORCH_API Value* get_or_add_parameter(at::Tensor* slot) {
auto it = member_input_index.find(slot);
- if(it != member_input_index.end()) {
+ if (it != member_input_index.end()) {
return graph()->inputs().at(it->second);
}
// add it as a new parameter
return graph()->addInput();
}
- std::shared_ptr<Graph> propagate_shapes(std::vector<at::Tensor> inputs, bool with_grad=false) {
+ std::shared_ptr<Graph> propagate_shapes(
+ std::vector<at::Tensor> inputs,
+ bool with_grad = false) {
auto retval = graph_->copy();
Stack stack;
stack.reserve(inputs.size() + member_inputs.size());
- for (at::Tensor & i : inputs) {
+ for (at::Tensor& i : inputs) {
stack.emplace_back(std::move(i));
}
for (at::Tensor* inp : member_inputs) {
return retval;
}
- std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, bool with_grad=false, bool propagate=true) {
+ std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
+ std::vector<at::Tensor> inputs,
+ std::vector<at::Tensor> outputs,
+ bool with_grad = false,
+ bool propagate = true) {
auto retval = graph_->copy();
for (auto inp : member_inputs) {
inputs.push_back(*inp);
}
if (propagate) {
- setInputTypes(*retval, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
+ setInputTypes(
+ *retval,
+ ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
PropagateInputShapes(retval);
}
JIT_ASSERT(retval->inputs().size() == inputs.size());
- for (size_t i=0; i < retval->inputs().size(); ++i) {
+ for (size_t i = 0; i < retval->inputs().size(); ++i) {
auto scalar_type = inputs[i].type().scalarType();
auto sizes = inputs[i].sizes();
- auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+ auto type =
+ torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
retval->inputs()[i]->setType(type);
}
at::ArrayRef<Value*> output_values = retval->outputs();
// patch this to still work if we are returning a tuple of multiple values
if (output_values.at(0)->type()->kind() == TupleType::Kind) {
- JIT_ASSERT(output_values.at(0)->node()->kind()== prim::TupleConstruct);
+ JIT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
output_values = output_values.at(0)->node()->inputs();
}
JIT_ASSERT(output_values.size() == outputs.size());
- for (size_t i=0; i < retval->outputs().size(); ++i) {
+ for (size_t i = 0; i < retval->outputs().size(); ++i) {
auto scalar_type = outputs[i].type().scalarType();
auto sizes = outputs[i].sizes();
- auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+ auto type =
+ torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
output_values[i]->setType(type);
}
return retval;
}
TORCH_API const FunctionSchema& getSchema() const {
- if(schema == nullptr) {
+ if (schema == nullptr) {
schema = make_unique<FunctionSchema>(defaultSchemaFor(*this));
}
return *schema;
graph()->outputs().size() == 1,
"Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
}
-private:
+ private:
static FunctionSchema defaultSchemaFor(const Method& method) {
std::vector<Argument> args;
std::vector<Argument> returns;
Graph& g = *method.graph();
size_t num_inputs = method.num_inputs();
- for(size_t i = 0; i < num_inputs; ++i) {
+ for (size_t i = 0; i < num_inputs; ++i) {
const Value* v = g.inputs().at(i);
- std::string name = v->hasUniqueName() ? v->uniqueNameBase() : ("argument_" + std::to_string(i));
+ std::string name = v->hasUniqueName() ? v->uniqueNameBase()
+ : ("argument_" + std::to_string(i));
args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
}
- for(size_t i = 0; i < g.outputs().size(); ++i) {
+ for (size_t i = 0; i < g.outputs().size(); ++i) {
returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
}
- return { method.name(), std::move(args), std::move(returns) };
+ return {method.name(), std::move(args), std::move(returns)};
}
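  // E.g. a graph with a named input `x` and one unnamed input yields a schema
  // with arguments `x` and `argument_1`, each carrying its unshaped type;
  // returns are left unnamed.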
GraphExecutor& get_executor() {
// Do we have more inputs than the schema accepts?
AT_CHECK(
inputs.size() <= schema.arguments().size(),
- "Expected at most ", schema.arguments().size(),
- " argument(s) for operator '", schema.name(), "', but received ",
- inputs.size(), " argument(s). Declaration: ", schema);
+ "Expected at most ",
+ schema.arguments().size(),
+ " argument(s) for operator '",
+ schema.name(),
+ "', but received ",
+ inputs.size(),
+ " argument(s). Declaration: ",
+ schema);
for (size_t pos = 0; pos < schema.arguments().size(); ++pos) {
const auto& argument = schema.arguments()[pos];
// and should be replaced with a function isSubvalueOf(ivalue, type)
// That asks if the specific value is a valid instance of type.
const TypePtr inputType = incompleteInferTypeFrom(inputs[pos]);
- AT_CHECK(inputType->isSubtypeOf(argument.type()),
- "Expected value of type ", *argument.type(),
- " for argument '", argument.name(),
- "' in position ", pos,
- ", but instead got value of type ", *inputType,
- ". Declaration: ", schema);
+ AT_CHECK(
+ inputType->isSubtypeOf(argument.type()),
+ "Expected value of type ",
+ *argument.type(),
+ " for argument '",
+ argument.name(),
+ "' in position ",
+ pos,
+ ", but instead got value of type ",
+ *inputType,
+ ". Declaration: ",
+ schema);
} else if (argument.default_value()) {
inputs.push_back(*argument.default_value());
} else {
- AT_ERROR(schema.name(), "() is missing value for argument '",
- argument.name(), "'. Declaration: ", schema);
+ AT_ERROR(
+ schema.name(),
+ "() is missing value for argument '",
+ argument.name(),
+ "'. Declaration: ",
+ schema);
}
}
}
-
// Methods are uniquely owned by a single module. This raw pointer allows
// looking up the module.
Module* owner_;
std::once_flag executor_init;
- // an optional function that actually creates the method when emit_call_to(this,...)
- // is first called.
- // this is used by the compiler so that it can construct methods out of order
+ // an optional function that actually creates the method when
+ // emit_call_to(this,...) is first called. This is used by the compiler so
+ // that it can construct methods out of order
std::function<void(Method&)> method_creator;
// if absent, then we generate a default schema based on the graph
struct NamedParameter {
NamedParameter(std::string name, at::Tensor tensor, bool is_buffer)
- : name(std::move(name))
- , is_buffer(is_buffer)
- , parameter(torch::make_unique<at::Tensor>(std::move(tensor))) {}
+ : name(std::move(name)),
+ is_buffer(is_buffer),
+ parameter(torch::make_unique<at::Tensor>(std::move(tensor))) {}
const std::string name;
bool is_buffer; // buffers are part of the module state but
- // are not modified by optimizers during SGD
+ // are not modified by optimizers during SGD
at::Tensor* slot() const {
return parameter.get();
}
-private:
+
+ private:
// the extra level of indirection allows Methods to safely store pointers
// to the slots where parameters are kept while also allow parameters
// to be reassigned
struct Module {
TH_DISALLOW_COPY_AND_ASSIGN(Module);
Module()
- : modules("Module")
- , parameters("Parameter")
- , methods("Method")
- , optimize(true) {}
+ : modules("Module"),
+ parameters("Parameter"),
+ methods("Method"),
+ optimize(true) {}
// note this doesn't change the flags of existing methods just ones
// added afterward.
return get_method("forward")(std::move(inputs));
}
- void register_parameter(const std::string & name, autograd::Variable v, bool is_buffer) {
- if(auto p = parameters.find(name)){
+ void register_parameter(
+ const std::string& name,
+ autograd::Variable v,
+ bool is_buffer) {
+ if (auto p = parameters.find(name)) {
*p->slot() = v;
p->is_buffer = is_buffer;
return;
}
parameters.insert(name, NamedParameter(name, std::move(v), is_buffer));
}
- void register_module(const std::string& name, std::shared_ptr<Module> module) {
+ void register_module(
+ const std::string& name,
+ std::shared_ptr<Module> module) {
modules.insert(name, {name, std::move(module)});
}
- Method& create_method(const std::string & name, std::shared_ptr<Graph> graph, std::vector<at::Tensor*> member_inputs) {
+ Method& create_method(
+ const std::string& name,
+ std::shared_ptr<Graph> graph,
+ std::vector<at::Tensor*> member_inputs) {
JIT_ASSERT(graph);
- std::unique_ptr<Method> method(new Method(this, name, optimize, std::move(graph), std::move(member_inputs), nullptr));
+ std::unique_ptr<Method> method(new Method(
+ this,
+ name,
+ optimize,
+ std::move(graph),
+ std::move(member_inputs),
+ nullptr));
return *methods.insert(name, std::move(method));
}
- Method& create_method(const std::string & name, std::function<void(Method&)> creator) {
- std::unique_ptr<Method> method(new Method(this, name, optimize, std::make_shared<Graph>(), {}, std::move(creator)));
+ Method& create_method(
+ const std::string& name,
+ std::function<void(Method&)> creator) {
+ std::unique_ptr<Method> method(new Method(
+ this,
+ name,
+ optimize,
+ std::make_shared<Graph>(),
+ {},
+ std::move(creator)));
return *methods.insert(name, std::move(method));
}
- at::Tensor* parameter_slot(const std::string & name) const {
+ at::Tensor* parameter_slot(const std::string& name) const {
return parameters[name].slot();
}
- void set_parameter(const std::string & name, at::Tensor v) {
+ void set_parameter(const std::string& name, at::Tensor v) {
*parameter_slot(name) = std::move(v);
}
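  // A minimal usage sketch of the registration API above (graph construction
  // elided; names hypothetical):
  //
  //   auto m = std::make_shared<Module>();
  //   m->register_parameter("weight", torch::randn({2, 2}), /*is_buffer=*/false);
  //   auto g = std::make_shared<Graph>(); // ... add inputs/nodes/outputs ...
  //   m->create_method("forward", g, {m->parameter_slot("weight")});
  //   m->set_parameter("weight", torch::ones({2, 2}));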
const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
return modules;
}
- const torch::OrderedDict<std::string, NamedParameter>& get_parameters() const {
+ const torch::OrderedDict<std::string, NamedParameter>& get_parameters()
+ const {
return parameters;
}
- const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods() const {
+ const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
+ const {
return methods;
}
return nullptr;
}
void apply(std::function<void(Module&)> fn) {
- for (auto &submod : get_modules()) {
+ for (auto& submod : get_modules()) {
submod.value().module->apply(fn);
}
fn(*this);
void save(const std::string& filename);
- void copy_into(std::function<std::shared_ptr<Module>(
- std::vector<std::string>)> module_lookup,
- // parameter_remap is needed when a parent module uses a parameter of a submodule
+ void copy_into(
+ std::function<std::shared_ptr<Module>(std::vector<std::string>)>
+ module_lookup,
+ // parameter_remap is needed when a parent module uses a parameter of a
+ // submodule
std::unordered_map<at::Tensor*, at::Tensor*>& parameter_remap,
std::vector<std::string> names = {}) const {
auto curr = module_lookup(names);
- for (auto &kv : parameters) {
- curr->register_parameter(kv.key(), *kv.value().slot(), kv.value().is_buffer);
+ for (auto& kv : parameters) {
+ curr->register_parameter(
+ kv.key(), *kv.value().slot(), kv.value().is_buffer);
parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
}
- for (auto &kv : modules) {
+ for (auto& kv : modules) {
names.push_back(kv.key());
// Submodules must be translated first, otherwise parameter_remap entries
// will not be filled in for methods of this module.
kv.value().module->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
- for (auto &kv : methods) {
+ for (auto& kv : methods) {
std::vector<at::Tensor*> params;
- for (auto &p : kv.value()->params()) {
+ for (auto& p : kv.value()->params()) {
params.push_back(parameter_remap.at(p));
}
curr->create_method(kv.key(), kv.value()->graph()->copy(), params);
// unit, and not a method), then nullptr can be passed as caller.
Method* caller,
bool conv_tensors_to_nums);
-}}}
+} // namespace script
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/script/lexer.h>
-#include <torch/csrc/jit/script/error_report.h>
#include <c10/util/Optional.h>
+#include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/lexer.h>
namespace torch {
namespace jit {
namespace script {
inline bool isCharCount(char c, const std::string& str, size_t start, int len) {
- //count checks from [start, start + len)
- return start + len <= str.size() && std::count(str.begin() + start, str.begin() + start + len, c) == len;
+ // count checks from [start, start + len)
+ return start + len <= str.size() &&
+ std::count(str.begin() + start, str.begin() + start + len, c) == len;
}
inline static bool isOctal(char c) {
if (pos + 3 >= str.size())
return c10::nullopt;
size_t c = 0;
- for(size_t i = 1, b = 64; i < 4; ++i, b /= 8) {
+ for (size_t i = 1, b = 64; i < 4; ++i, b /= 8) {
int d = str[pos + i];
if (d < '0' || d > '7')
return c10::nullopt;
c += b * (d - '0');
}
- if(c >= 256)
+ if (c >= 256)
return c10::nullopt;
return c;
}
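A worked example of the octal arithmetic above (illustrative only):
// Given the four source characters \101 with the backslash at `pos`:
//   i=1, b=64: d='1' -> c += 64 * 1 = 64
//   i=2, b=8:  d='0' -> c += 8 * 0  =  0
//   i=3, b=1:  d='1' -> c += 1 * 1  =  1
// c == 65, i.e. 'A'. A non-octal digit, or any c >= 256 (e.g. \777 = 511),
// yields c10::nullopt instead.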
-inline std::string parseStringLiteral(const SourceRange& range, const std::string &str) {
+inline std::string parseStringLiteral(
+ const SourceRange& range,
+ const std::string& str) {
int quote_len = isCharCount(str[0], str, 0, 3) ? 3 : 1;
auto ret_str = str.substr(quote_len, str.size() - quote_len * 2);
size_t pos = ret_str.find('\\');
- while(pos != std::string::npos) {
- //invariant: pos has to escape a character because it is a valid string
+ while (pos != std::string::npos) {
+ // invariant: the backslash at pos must begin an escape, since the string
+ // already lexed as valid
char c = ret_str[pos + 1];
size_t to_erase = 2;
switch (ret_str[pos + 1]) {
c = '\t';
break;
case 'h':
- throw ErrorReport(range)
- << "unsupported hex specifier";
+ throw ErrorReport(range) << "unsupported hex specifier";
default:
// \0NN
if (auto v = parseOctal(str, pos + 1)) {
to_erase = 4;
c = *v;
} else {
- throw ErrorReport(range)
- << " ill formed octal specifier";
+ throw ErrorReport(range) << " ill formed octal specifier";
}
}
ret_str.replace(pos, to_erase, /* num copies */ 1, c);
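End to end, the escape loop above behaves like this (a minimal sketch, assuming the definitions in this header):
// source text 'a\tb'   : quote_len is 1, the quotes are stripped, and the
//                        two characters \t are replaced by a single tab
// source text '''ab''' : isCharCount sees the triple quote, quote_len is 3
// an unsupported escape (e.g. the hex specifier) raises ErrorReport(range)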
+#include <c10/util/Optional.h>
#include <torch/csrc/jit/script/lexer.h>
-#include <torch/csrc/jit/script/tree.h>
+#include <torch/csrc/jit/script/parse_string_literal.h>
#include <torch/csrc/jit/script/parser.h>
+#include <torch/csrc/jit/script/tree.h>
#include <torch/csrc/jit/script/tree_views.h>
-#include <c10/util/Optional.h>
-#include <torch/csrc/jit/script/parse_string_literal.h>
namespace torch {
namespace jit {
namespace script {
-Decl mergeTypesFromTypeComment(const Decl& decl, const Decl& type_annotation_decl, bool is_method) {
+Decl mergeTypesFromTypeComment(
+ const Decl& decl,
+ const Decl& type_annotation_decl,
+ bool is_method) {
auto expected_num_annotations = decl.params().size();
if (is_method) {
// `self` argument
expected_num_annotations -= 1;
}
if (expected_num_annotations != type_annotation_decl.params().size()) {
- throw ErrorReport(type_annotation_decl.range()) << "Number of type annotations ("
- << type_annotation_decl.params().size() << ") did not match the number of "
- << "function parameters (" << expected_num_annotations << ")";
+ throw ErrorReport(type_annotation_decl.range())
+ << "Number of type annotations ("
+ << type_annotation_decl.params().size()
+ << ") did not match the number of "
+ << "function parameters (" << expected_num_annotations << ")";
}
auto old = decl.params();
auto _new = type_annotation_decl.params();
for (; i < decl.params().size(); ++i, ++j) {
new_params.emplace_back(old[i].withType(_new[j].type()));
}
- return Decl::create(decl.range(), List<Param>::create(decl.range(), new_params), type_annotation_decl.return_type());
+ return Decl::create(
+ decl.range(),
+ List<Param>::create(decl.range(), new_params),
+ type_annotation_decl.return_type());
}
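A sketch of the intended use, since the off-by-one bookkeeping is easy to misread (the Python-side syntax is illustrative):
// def forward(self, x, y):
//     # type: (Tensor, int) -> Tensor
// decl carries 3 params (self, x, y); with is_method the expected count
// drops to 2, matching the 2 annotations. Each non-self param is rebuilt
// via withType(...), and the merged Decl adopts the comment's return type.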
struct ParserImpl {
}
static bool followsTuple(int kind) {
- switch(kind) {
+ switch (kind) {
case TK_PLUS_EQ:
case TK_MINUS_EQ:
case TK_TIMES_EQ:
// exp | expr, | expr, expr, ...
Expr parseExpOrExpTuple() {
auto prefix = parseExp();
- if(L.cur().kind == ',') {
- std::vector<Expr> exprs = { prefix };
- while(L.nextIf(',')) {
+ if (L.cur().kind == ',') {
+ std::vector<Expr> exprs = {prefix};
+ while (L.nextIf(',')) {
if (followsTuple(L.cur().kind))
break;
exprs.push_back(parseExp());
} break;
}
}
- TreeRef
- parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) {
+ TreeRef parseTrinary(
+ TreeRef true_branch,
+ const SourceRange& range,
+ int binary_prec) {
auto cond = parseExp();
L.expect(TK_ELSE);
auto false_branch = parseExp(binary_prec);
// precedence strictly greater than 'precedence'
// precedence == 0 will parse _all_ expressions
// this is the core loop of 'top-down precedence parsing'
- Expr parseExp() { return parseExp(0); }
+ Expr parseExp() {
+ return parseExp(0);
+ }
Expr parseExp(int precedence) {
TreeRef prefix = nullptr;
int unary_prec;
auto kind = L.cur().kind;
auto pos = L.cur().range;
L.next();
- auto unary_kind = kind == '*' ? TK_STARRED :
- kind == '-' ? TK_UNARY_MINUS :
- kind;
+ auto unary_kind =
+ kind == '*' ? TK_STARRED : kind == '-' ? TK_UNARY_MINUS : kind;
auto subexp = parseExp(unary_prec);
// fold '-' into constant numbers, so that attributes can accept
// things like -1
- if(unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
+ if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
} else {
prefix = c(unary_kind, pos, {subexp});
}
return Expr(prefix);
}
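For readers skimming the hunks, a compressed sketch of the precedence-climbing loop parseExp implements (pseudocode; the real body also threads the unary and ternary handling shown above):
// prefix = unary-or-primary expression
// while next token is a binary op with binary_prec > precedence:  // strict
//   op = L.next(); rhs = parseExp(binary_prec);  // recurse at op's level,
//                                                // giving left associativity
//   prefix = c(op.kind, op.range, {prefix, rhs});
// return Expr(prefix);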
- template<typename T>
+ template <typename T>
List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
auto r = L.cur().range;
if (begin != TK_NOTHING)
StringLiteral parseConcatenatedStringLiterals() {
auto range = L.cur().range;
std::stringstream ss;
- while(L.cur().kind == TK_STRINGLITERAL) {
+ while (L.cur().kind == TK_STRINGLITERAL) {
auto literal_range = L.cur().range;
ss << parseStringLiteral(literal_range, L.next().text());
}
auto ident = parseIdent();
L.expect('=');
auto v = parseAttributeValue();
- attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
+ attributes.push_back(
+ Attribute::create(ident.range(), Ident(ident), v));
} else {
inputs.push_back(parseExp());
}
if (L.cur().kind != ',' && L.cur().kind != ']') {
second = parseExp();
}
- auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first)) : Maybe<Expr>::create(range);
- auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second)) : Maybe<Expr>::create(range);
+ auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
+ : Maybe<Expr>::create(range);
+ auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
+ : Maybe<Expr>::create(range);
return SliceExpr::create(range, maybe_first, maybe_second);
} else {
return Expr(first);
TreeRef parseSubscript(const TreeRef& value) {
const auto range = L.cur().range;
- auto subscript_exprs = parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
+ auto subscript_exprs =
+ parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
return Subscript::create(range, Expr(value), subscript_exprs);
}
} else {
def = Maybe<Expr>::create(L.cur().range);
}
- return Param::create(type->range(), Ident(ident), Expr(type), Maybe<Expr>(def));
+ return Param::create(
+ type->range(), Ident(ident), Expr(type), Maybe<Expr>(def));
}
Param parseBareTypeAnnotation() {
auto type = parseExp();
- return Param::create(type.range(), Ident::create(type.range(), ""), type, Maybe<Expr>::create(type.range()));
+ return Param::create(
+ type.range(),
+ Ident::create(type.range(), ""),
+ type,
+ Maybe<Expr>::create(type.range()));
}
Decl parseTypeComment() {
auto range = L.cur().range;
L.expect(TK_TYPE_COMMENT);
- auto param_types = parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
+ auto param_types =
+ parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
TreeRef return_type;
if (L.nextIf(TK_ARROW)) {
auto return_type_range = L.cur().range;
throw ErrorReport(lhs.range())
<< " augmented assignment can only have one LHS expression";
}
- return AugAssign::create(
- lhs.range(), lhs, AugAssignKind(op), Expr(rhs));
+ return AugAssign::create(lhs.range(), lhs, AugAssignKind(op), Expr(rhs));
}
}
return parseFor();
case TK_GLOBAL: {
auto range = L.next().range;
- auto idents = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
+ auto idents =
+ parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
L.expect(TK_NEWLINE);
return Global::create(range, idents);
}
auto range = L.next().range;
auto cond = parseExp();
Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
- if (L.nextIf(',')) {
+ if (L.nextIf(',')) {
auto msg = parseExp();
maybe_first = Maybe<Expr>::create(range, Expr(msg));
}
}
return list;
}
- TreeRef parseIf(bool expect_if=true) {
+ TreeRef parseIf(bool expect_if = true) {
auto r = L.cur().range;
if (expect_if)
L.expect(TK_IF);
auto range = L.cur().range;
false_branch = makeList(range, {parseIf(false)});
}
- return If::create(r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
+ return If::create(
+ r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
}
TreeRef parseWhile() {
auto r = L.cur().range;
TreeRef parseFor() {
auto r = L.cur().range;
L.expect(TK_FOR);
- auto targets = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
+ auto targets =
+ parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
L.expect(TK_IN);
auto itrs = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
L.expect(':');
return For::create(r, targets, itrs, body);
}
- TreeRef parseStatements(bool expect_indent=true) {
+ TreeRef parseStatements(bool expect_indent = true) {
auto r = L.cur().range;
if (expect_indent) {
L.expect(TK_INDENT);
TreeList stmts;
do {
stmts.push_back(parseStmt());
- } while(!L.nextIf(TK_DEDENT));
+ } while (!L.nextIf(TK_DEDENT));
return c(TK_LIST, r, std::move(stmts));
}
TreeRef return_type;
Maybe<Expr> return_annotation = parseReturnAnnotation();
L.expect(':');
- return Decl::create(paramlist.range(), List<Param>(paramlist), return_annotation);
+ return Decl::create(
+ paramlist.range(), List<Param>(paramlist), return_annotation);
}
TreeRef parseFunction(bool is_method) {
}
auto stmts_list = parseStatements(false);
- return Def::create(name.range(), Ident(name), Decl(decl),
- List<Stmt>(stmts_list));
+ return Def::create(
+ name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
}
Lexer& lexer() {
return L;
SharedParserData& shared;
};
-Parser::Parser(const std::string& src)
-: pImpl(new ParserImpl(src)) {}
+Parser::Parser(const std::string& src) : pImpl(new ParserImpl(src)) {}
Parser::~Parser() = default;
#pragma once
-#include <memory>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/script/tree.h>
+#include <memory>
namespace torch {
namespace jit {
struct ParserImpl;
struct Lexer;
-TORCH_API Decl mergeTypesFromTypeComment(const Decl& decl, const Decl& type_annotation_decl, bool is_method);
+TORCH_API Decl mergeTypesFromTypeComment(
+ const Decl& decl,
+ const Decl& type_annotation_decl,
+ bool is_method);
struct TORCH_API Parser {
explicit Parser(const std::string& str);
Decl parseTypeComment();
Lexer& lexer();
~Parser();
-private:
+
+ private:
std::unique_ptr<ParserImpl> pImpl;
};
namespace py = pybind11;
-namespace torch { namespace jit { namespace script {
+namespace torch {
+namespace jit {
+namespace script {
struct SourceRangeFactory {
SourceRangeFactory(std::string source)
- : source_(std::make_shared<std::string>(std::move(source))) {
+ : source_(std::make_shared<std::string>(std::move(source))) {
size_t pos = 0;
do {
line_len_prefix_sum_.push_back(pos);
} while ((pos = source_->find('\n', pos)) != std::string::npos);
}
SourceRange create(int line, int start_col, int end_col) {
- // Python has a weird convention where col_offset points to the column *before*
- // the token starts.
+ // Python has a weird convention where col_offset points to the column
+ // *before* the token starts.
start_col++;
end_col++;
// Also, lines are counted from 1.
std::vector<size_t> line_len_prefix_sum_;
};
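A small worked example of the bookkeeping above (illustrative values):
// source: "a = 1\nb = 2\n"
// the ctor walks the newlines once, caching one cumulative offset per line;
// create(2, 0, 5) then bumps both columns by one (Python's col_offset
// convention), maps the 1-based line through line_len_prefix_sum_, and
// returns a SourceRange of absolute offsets into *source_.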
-template<typename T>
+template <typename T>
List<T> wrap_list(const SourceRange& fallback_pos, std::vector<T>&& vec) {
if (vec.empty())
return List<T>::create(fallback_pos, std::move(vec));
return List<T>::create(vec.front().range(), std::move(vec));
}
-template<typename T>
+template <typename T>
Maybe<T> wrap_maybe(const SourceRange& fallback_pos, T* val) {
- return val ? Maybe<T>::create(val->range(), *val) : Maybe<T>::create(fallback_pos);
+ return val ? Maybe<T>::create(val->range(), *val)
+ : Maybe<T>::create(fallback_pos);
}
-void initTreeViewBindings(PyObject *module) {
+void initTreeViewBindings(PyObject* module) {
auto _C = py::handle(module).cast<py::module>();
auto m = _C.def_submodule("_jit_tree_views");
py::class_<SourceRange>(m, "SourceRange")
- .def("highlight", [](const SourceRange& self) {
- std::ostringstream stream;
- self.highlight(stream);
- return stream.str();
- })
- .def_property_readonly("start", &SourceRange::start)
- .def_property_readonly("end", &SourceRange::end);
+ .def(
+ "highlight",
+ [](const SourceRange& self) {
+ std::ostringstream stream;
+ self.highlight(stream);
+ return stream.str();
+ })
+ .def_property_readonly("start", &SourceRange::start)
+ .def_property_readonly("end", &SourceRange::end);
py::class_<SourceRangeFactory>(m, "SourceRangeFactory")
- .def(py::init<std::string&&>())
- .def("make_range", &SourceRangeFactory::create)
- .def("make_raw_range", [](const SourceRangeFactory& self, size_t start, size_t end) {
- return SourceRange(self.source_, start, end);
- })
- .def_property_readonly("source", [](const SourceRangeFactory& self) {
- return *self.source_;
- });
+ .def(py::init<std::string&&>())
+ .def("make_range", &SourceRangeFactory::create)
+ .def(
+ "make_raw_range",
+ [](const SourceRangeFactory& self, size_t start, size_t end) {
+ return SourceRange(self.source_, start, end);
+ })
+ .def_property_readonly("source", [](const SourceRangeFactory& self) {
+ return *self.source_;
+ });
py::class_<TreeView>(m, "TreeView")
- .def("range", &TreeView::range)
- .def("__str__", [](const TreeView& tree) {
- std::ostringstream stream;
- stream << tree.get();
- return stream.str();
- });
+ .def("range", &TreeView::range)
+ .def("__str__", [](const TreeView& tree) {
+ std::ostringstream stream;
+ stream << tree.get();
+ return stream.str();
+ });
py::class_<Ident, TreeView>(m, "Ident")
.def(py::init(&Ident::create))
"name", [](const Ident& self) { return self.name(); });
py::class_<Param, TreeView>(m, "Param")
- .def(py::init([](const Expr& type, const Ident& name) {
- return Param::create(name.range(), name, type, Maybe<Expr>::create(name.range()));
- }));
+ .def(py::init([](const Expr& type, const Ident& name) {
+ return Param::create(
+ name.range(), name, type, Maybe<Expr>::create(name.range()));
+ }));
py::class_<Attribute, TreeView>(m, "Attribute")
- .def(py::init([](const Ident& name, const Expr& value) {
- return Attribute::create(name.range(), name, value);
- }));
+ .def(py::init([](const Ident& name, const Expr& value) {
+ return Attribute::create(name.range(), name, value);
+ }));
m.def("TrueLiteral", [](const SourceRange& range) {
return Expr(Compound::create(TK_TRUE, range, {}));
});
py::class_<Stmt, TreeView>(m, "Stmt"); // NOLINT(bugprone-unused-raii)
py::class_<Expr, TreeView>(m, "Expr"); // NOLINT(bugprone-unused-raii)
- py::class_<Def, TreeView>(m, "Def")
- .def(py::init([](const Ident& name,
- Decl decl,
- std::vector<Stmt> body) {
- const auto& r = name.range();
- return Def::create(r,
- name,
- decl,
- wrap_list(r, std::move(body)));
- }));
- py::class_<Decl, TreeView>(m, "Decl")
- .def(py::init([](const SourceRange& r,
- std::vector<Param> params,
- Expr *return_type) {
- return Decl::create(r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
- }));
-
+ py::class_<Def, TreeView>(m, "Def").def(
+ py::init([](const Ident& name, Decl decl, std::vector<Stmt> body) {
+ const auto& r = name.range();
+ return Def::create(r, name, decl, wrap_list(r, std::move(body)));
+ }));
+ py::class_<Decl, TreeView>(m, "Decl").def(py::init(
+ [](const SourceRange& r, std::vector<Param> params, Expr* return_type) {
+ return Decl::create(
+ r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
+ }));
py::class_<Assign, Stmt>(m, "Assign")
- .def(py::init([](const Expr& lhs, const Expr& rhs) {
- return Assign::create(lhs.range(), lhs, rhs);
- }));
+ .def(py::init([](const Expr& lhs, const Expr& rhs) {
+ return Assign::create(lhs.range(), lhs, rhs);
+ }));
py::class_<AugAssign, Stmt>(m, "AugAssign")
- .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) {
- const auto& r = lhs.range();
- auto kind = AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
- return AugAssign::create(r, lhs, kind, rhs);
- }));
+ .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) {
+ const auto& r = lhs.range();
+ auto kind =
+ AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
+ return AugAssign::create(r, lhs, kind, rhs);
+ }));
py::class_<Return, Stmt>(m, "Return")
- .def(py::init([](const SourceRange& range, Expr* value) {
- return Return::create(range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
- }));
+ .def(py::init([](const SourceRange& range, Expr* value) {
+ return Return::create(
+ range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
+ }));
py::class_<Raise, Stmt>(m, "Raise")
- .def(py::init([](const SourceRange& range, Expr *expr) {
- return Raise::create(range, wrap_maybe(range, expr));
- }));
+ .def(py::init([](const SourceRange& range, Expr* expr) {
+ return Raise::create(range, wrap_maybe(range, expr));
+ }));
py::class_<Assert, Stmt>(m, "Assert")
- .def(py::init([](const SourceRange& range, const Expr& test, Expr *msg) {
- return Assert::create(range, test, wrap_maybe(range, msg));
- }));
- py::class_<Pass, Stmt>(m, "Pass")
- .def(py::init([](const SourceRange& range) {
- return Pass::create(range);
- }));
- py::class_<If, Stmt>(m, "If")
- .def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> true_branch, std::vector<Stmt> false_branch) {
- return If::create(range, cond,
- wrap_list(range, std::move(true_branch)),
- wrap_list(range, std::move(false_branch)));
- }));
+ .def(py::init([](const SourceRange& range, const Expr& test, Expr* msg) {
+ return Assert::create(range, test, wrap_maybe(range, msg));
+ }));
+ py::class_<Pass, Stmt>(m, "Pass").def(
+ py::init([](const SourceRange& range) { return Pass::create(range); }));
+ py::class_<If, Stmt>(m, "If").def(
+ py::init([](const SourceRange& range,
+ const Expr& cond,
+ std::vector<Stmt> true_branch,
+ std::vector<Stmt> false_branch) {
+ return If::create(
+ range,
+ cond,
+ wrap_list(range, std::move(true_branch)),
+ wrap_list(range, std::move(false_branch)));
+ }));
py::class_<While, Stmt>(m, "While")
- .def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> body) {
- return While::create(range, cond, wrap_list(range, std::move(body)));
- }));
+ .def(py::init([](const SourceRange& range,
+ const Expr& cond,
+ std::vector<Stmt> body) {
+ return While::create(range, cond, wrap_list(range, std::move(body)));
+ }));
py::class_<For, Stmt>(m, "For").def(py::init([](const SourceRange range,
std::vector<Expr>& targets,
std::vector<Expr>& itrs,
wrap_list(range, std::move(itrs)),
wrap_list(range, std::move(body)));
}));
- py::class_<ExprStmt, Stmt>(m, "ExprStmt")
- .def(py::init([](const Expr& expr) {
- return ExprStmt::create(expr.range(), expr);
- }));
+ py::class_<ExprStmt, Stmt>(m, "ExprStmt").def(py::init([](const Expr& expr) {
+ return ExprStmt::create(expr.range(), expr);
+ }));
py::class_<Var, Expr>(m, "Var")
- .def(py::init([](const Ident& name) {
- return Var::create(name.range(), name);
- }))
- .def_property_readonly("name", [](const Var& var) { return var.name(); });
+ .def(py::init(
+ [](const Ident& name) { return Var::create(name.range(), name); }))
+ .def_property_readonly("name", [](const Var& var) { return var.name(); });
py::class_<BinOp, Expr>(m, "BinOp")
- .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) {
- return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
- }));
- // NB: we take range here, because unary ops precede their exprs, so we need to include them
+ .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) {
+ return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
+ }));
+ // NB: we take range here, because unary ops precede their exprs, so we need
+ // to include them
py::class_<UnaryOp, Expr>(m, "UnaryOp")
- .def(py::init([](const SourceRange& range, std::string kind, const Expr& expr) {
- auto resolved_kind = stringToKind(kind);
- resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind;
- return UnaryOp::create(range, resolved_kind, expr);
- }));
+ .def(py::init(
+ [](const SourceRange& range, std::string kind, const Expr& expr) {
+ auto resolved_kind = stringToKind(kind);
+ resolved_kind =
+ resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind;
+ return UnaryOp::create(range, resolved_kind, expr);
+ }));
py::class_<Const, Expr>(m, "Const")
- .def(py::init([](const SourceRange& range, std::string value) {
- return Const::create(range, value);
- }));
+ .def(py::init([](const SourceRange& range, std::string value) {
+ return Const::create(range, value);
+ }));
py::class_<StringLiteral, Expr>(m, "StringLiteral")
- .def(py::init([](const SourceRange& range, std::string value) {
- return StringLiteral::create(range, value);
- }));
+ .def(py::init([](const SourceRange& range, std::string value) {
+ return StringLiteral::create(range, value);
+ }));
py::class_<Apply, Expr>(m, "Apply")
- .def(py::init([](const Expr& expr, std::vector<Expr> args, std::vector<Attribute> kwargs) {
- const auto& r = expr.range();
- return Apply::create(expr.range(), expr,
- wrap_list(r, std::move(args)), wrap_list(r, std::move(kwargs)));
- }));
+ .def(py::init([](const Expr& expr,
+ std::vector<Expr> args,
+ std::vector<Attribute> kwargs) {
+ const auto& r = expr.range();
+ return Apply::create(
+ expr.range(),
+ expr,
+ wrap_list(r, std::move(args)),
+ wrap_list(r, std::move(kwargs)));
+ }));
py::class_<Select, Expr>(m, "Select")
- .def(py::init([](const Expr& expr, const Ident& field) {
- const auto& r = expr.range();
- return Select::create(expr.range(), expr, field);
- }));
+ .def(py::init([](const Expr& expr, const Ident& field) {
+ const auto& r = expr.range();
+ return Select::create(expr.range(), expr, field);
+ }));
py::class_<TernaryIf, Expr>(m, "TernaryIf")
- .def(py::init([](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
- return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
- }));
+ .def(py::init(
+ [](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
+ return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
+ }));
py::class_<ListLiteral, Expr>(m, "ListLiteral")
- .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
- return ListLiteral::create(range, wrap_list(range, std::move(args)));
- }));
+ .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
+ return ListLiteral::create(range, wrap_list(range, std::move(args)));
+ }));
py::class_<TupleLiteral, Expr>(m, "TupleLiteral")
- .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
- return TupleLiteral::create(range, wrap_list(range, std::move(args)));
- }));
+ .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
+ return TupleLiteral::create(range, wrap_list(range, std::move(args)));
+ }));
py::class_<Subscript, Expr>(m, "Subscript")
- .def(py::init([](const Expr& base, std::vector<Expr> subscript_exprs) {
- return Subscript::create(base.range(), base, wrap_list(base.range(), std::move(subscript_exprs)));
- }));
+ .def(py::init([](const Expr& base, std::vector<Expr> subscript_exprs) {
+ return Subscript::create(
+ base.range(),
+ base,
+ wrap_list(base.range(), std::move(subscript_exprs)));
+ }));
py::class_<SliceExpr, Expr>(m, "SliceExpr")
- .def(py::init([](const SourceRange& range, Expr *lower, Expr *upper) {
- return SliceExpr::create(range, wrap_maybe(range, lower), wrap_maybe(range, upper));
- }));
+ .def(py::init([](const SourceRange& range, Expr* lower, Expr* upper) {
+ return SliceExpr::create(
+ range, wrap_maybe(range, lower), wrap_maybe(range, upper));
+ }));
py::class_<Starred, Expr>(m, "Starred")
- .def(py::init([](const SourceRange& range, Expr expr){
- return Starred::create(range, expr);
- }));
+ .def(py::init([](const SourceRange& range, Expr expr) {
+ return Starred::create(range, expr);
+ }));
}
-}}} // namespace torch::jit::script
+} // namespace script
+} // namespace jit
+} // namespace torch
#include <torch/csrc/python_headers.h>
-namespace torch { namespace jit { namespace script {
+namespace torch {
+namespace jit {
+namespace script {
-void initTreeViewBindings(PyObject *module);
-
-}}} // namespace torch::jit::script
+void initTreeViewBindings(PyObject* module);
+} // namespace script
+} // namespace jit
+} // namespace torch
-#include <torch/csrc/jit/script/schema_matching.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/builtin_functions.h>
#include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/schema_matching.h>
namespace torch {
namespace jit {
inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
auto list_type = list_type_->cast<ListType>();
- if(!list_type) {
+ if (!list_type) {
return false;
}
- if(type->isSubtypeOf(list_type_)) {
+ if (type->isSubtypeOf(list_type_)) {
return true;
}
- if(auto tuple = type->cast<TupleType>()) {
+ if (auto tuple = type->cast<TupleType>()) {
return std::all_of(
tuple->elements().begin(),
tuple->elements().end(),
return false;
}
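Given the three checks above, representative outcomes (JIT type notation):
// int[]        -> int[] : true   (isSubtypeOf)
// (int, int)   -> int[] : true   (homogeneous tuple, every element passes)
// (int, float) -> int[] : false  (one element fails)
// int          -> int[] : false  (neither a matching list nor a tuple)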
-// applies implict conversion from value trying to turn it into type concrete_type
-// it succeeds if the return_value->isSubclassOf(concrete_type)
+// applies implicit conversions to `value`, trying to turn it into type
+// concrete_type; it succeeds if the returned value->isSubtypeOf(concrete_type)
Value* tryConvertToType(
const SourceRange& loc,
Graph& graph,
const TypePtr& concrete_type,
Value* value,
bool allow_conversions) {
-
if (auto value_tuple = value->type()->cast<TupleType>()) {
// Allow homogeneous tuples to be cast implicitly to lists of appropriate
// types
if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
auto unpacked = createTupleUnpack(value);
- auto elem_type = unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
+ auto elem_type =
+ unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
}
// inductively apply implicit conversions to tuples
}
}
- if (value->type()->isSubtypeOf(NoneType::get()) && !concrete_type->isSubtypeOf(NoneType::get())){
+ if (value->type()->isSubtypeOf(NoneType::get()) &&
+ !concrete_type->isSubtypeOf(NoneType::get())) {
if (concrete_type->isSubtypeOf(OptionalType::ofTensor())) {
// create an undefined tensor when None is passed to an optional[tensor] formal arg
value = graph.insertNode(graph.createUndefined())->output();
} else if (auto optional_type = concrete_type->cast<OptionalType>()) {
- value = graph.insertNode(graph.createNone(optional_type->getElementType()))->output();
+ value =
+ graph.insertNode(graph.createNone(optional_type->getElementType()))
+ ->output();
}
}
- //implicit conversions
- if(allow_conversions) {
- if(concrete_type->isSubtypeOf(NumberType::get())
- && value->type()->isSubtypeOf(DynamicType::get())) {
+ // implicit conversions
+ if (allow_conversions) {
+ if (concrete_type->isSubtypeOf(NumberType::get()) &&
+ value->type()->isSubtypeOf(DynamicType::get())) {
auto n = graph.createImplicitTensorToNum(concrete_type, value);
value = graph.insertNode(n)
- ->setSourceLocation(std::make_shared<SourceRange>(loc))
- ->output();
+ ->setSourceLocation(std::make_shared<SourceRange>(loc))
+ ->output();
}
if (value->type()->isSubtypeOf(StringType::get()) &&
- DeviceObjType::get()->isSubtypeOf(concrete_type)) {
- return graph.insert(aten::device, { value }, {}, loc);
+ DeviceObjType::get()->isSubtypeOf(concrete_type)) {
+ return graph.insert(aten::device, {value}, {}, loc);
}
}
const NamedValue& named_value,
const std::function<std::ostream&()>& err,
bool allow_conversions,
- TypeEnv & type_env) {
+ TypeEnv& type_env) {
Value* value = named_value.value(graph);
// some functions take a list of integers or floats for a fixed-size array;
// a single int/float is then repeated to the length of that list
if (isIntOrFloatUsedAsList(value, arg)) {
std::vector<Value*> repeated(*arg.N(), value);
- value = graph.insertNode(graph.createList(value->type(), repeated))->output();
+ value =
+ graph.insertNode(graph.createList(value->type(), repeated))->output();
}
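Concretely (a sketch): for a formal such as int[3], the caller may pass a single scalar and the branch above expands it before matching:
// formal: int[3] stride, call site passes stride=1
// repeated = {v, v, v};                           // *arg.N() == 3 copies
// value = createList(value->type(), repeated);    // now matches int[3]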
const MatchTypeReturn matched_type =
value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
- if(!value->type()->isSubtypeOf(concrete_type)) {
- err() << "expected a value of type " << concrete_type->str() << " for argument '" << arg.name() << "' but found "
+ if (!value->type()->isSubtypeOf(concrete_type)) {
+ err() << "expected a value of type " << concrete_type->str()
+ << " for argument '" << arg.name() << "' but found "
<< value->type()->str() << "\n"
<< named_value.locOr(loc);
return nullptr;
c10::optional<size_t> findInputWithName(
const std::string& name,
at::ArrayRef<NamedValue> kwargs) {
- for(size_t i = 0; i < kwargs.size(); ++i) {
- if(kwargs[i].name() == name)
+ for (size_t i = 0; i < kwargs.size(); ++i) {
+ if (kwargs[i].name() == name)
return i;
}
return c10::nullopt;
at::ArrayRef<NamedValue> varargs,
const std::function<std::ostream&()>& err,
bool convert_tensor_to_num,
- TypeEnv & type_env) {
+ TypeEnv& type_env) {
Argument elem_arg("<varargs>", elem_type);
std::vector<Value*> list_ctor;
- for(const auto& a : varargs) {
- Value* av = tryMatchArgument(elem_arg, graph, loc, a, err, convert_tensor_to_num, type_env);
- if(!av)
+ for (const auto& a : varargs) {
+ Value* av = tryMatchArgument(
+ elem_arg, graph, loc, a, err, convert_tensor_to_num, type_env);
+ if (!av)
return nullptr;
list_ctor.push_back(av);
}
self = c10::nullopt;
} else if (!arg.kwarg_only() && used_args < args.size()) {
// allow zeros(IntList sizes) to work with zeros(1, 2) or zeros(1)
- if (arg.type()->kind() == TypeKind::ListType && // the formal must be a list
- !arg.N() && // it must not be a broadcasting list like int[3], otherwise
- // a single int is a valid input
+ if (arg.type()->kind() ==
+ TypeKind::ListType && // the formal must be a list
+ !arg.N() && // it must not be a broadcasting list like int[3],
+ // otherwise a single int is a valid input
(schema_i + 1 == schema.arguments().size() ||
schema.arguments()[schema_i + 1]
.kwarg_only())) { // must be the last position argument
if (actual_type->kind() != TypeKind::ListType &&
!convertibleToList(
actual_type,
- unwrapOptional(arg.type()))) { // and the actual should not be a list already
- auto elem_type = unwrapOptional(arg.type())->expect<ListType>()->getElementType();
+ unwrapOptional(arg.type()))) { // and the actual should not be a
+ // list already
+ auto elem_type =
+ unwrapOptional(arg.type())->expect<ListType>()->getElementType();
Value* list = tryCreateList(
elem_type,
graph,
<< loc;
return c10::nullopt;
}
- Value* positional = tryMatchArgument(
- arg, graph, loc, *v, err, allow_conversions, type_env);
+ Value* positional =
+ tryMatchArgument(arg, graph, loc, *v, err, allow_conversions, type_env);
if (!positional)
return c10::nullopt;
positional_inputs.push_back(positional);
}
// check for unused self argument
- if(self != c10::nullopt) {
+ if (self != c10::nullopt) {
err() << "provided self argument not used in schema\n";
}
if (schema.is_vararg()) {
- for(;used_args < args.size(); ++used_args) {
+ for (; used_args < args.size(); ++used_args) {
positional_inputs.push_back(args[used_args].value(graph));
}
}
return MatchedSchema{std::move(positional_inputs), std::move(return_types)};
}
-
-// pack outputs of a function following python rules. If there is a single value return
-// a SimpleValue, otherwise pack all the values into a Tuple.
+// pack outputs of a function following python rules. If there is a single value
+// return a SimpleValue, otherwise pack all the values into a Tuple.
Value* packOutputs(Graph& g, at::ArrayRef<Value*> values) {
- if(values.size() == 1) {
+ if (values.size() == 1) {
return values[0];
}
return g.insertNode(g.createTuple(values))->output();
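Spelled out, the python-style packing above does (sketch):
// packOutputs(g, {v})      -> v                       (single value, as-is)
// packOutputs(g, {v1, v2}) -> output of an inserted createTuple({v1, v2})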
Graph& graph,
Symbol name) {
auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
- ->setSourceLocation(std::make_shared<SourceRange>(loc));
+ ->setSourceLocation(std::make_shared<SourceRange>(loc));
- for(auto & ret : matched_schema.return_types) {
+ for (auto& ret : matched_schema.return_types) {
n->addOutput()->setType(ret);
}
return packOutputs(graph, n->outputs());
}
-static std::string prefixLine(const std::string& str, const std::string& prefix) {
+static std::string prefixLine(
+ const std::string& str,
+ const std::string& prefix) {
std::stringstream ss;
bool was_newline = true;
- for(auto c : str) {
- if(was_newline)
+ for (auto c : str) {
+ if (was_newline)
ss << prefix;
ss.put(c);
was_newline = c == '\n';
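A quick check of the flag logic: was_newline starts true, so the first line is prefixed too, and nothing is emitted after a trailing newline:
// prefixLine("a\nb", "  ") == "  a\n  b"
// prefixLine("x\n",  "> ") == "> x\n"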
// Search for operators matching the provided symbol name and input types.
// If one is found, emit a node to the graph for that operator.
Value* emitBuiltinCall(
- const SourceRange& loc,
- Graph& graph,
- Symbol name,
- const c10::optional<NamedValue>& self,
- at::ArrayRef<NamedValue> inputs,
- at::ArrayRef<NamedValue> attributes,
- // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
- // otherwise it will return nullptr if the builtin is not found.
- bool required) {
-
-
+ const SourceRange& loc,
+ Graph& graph,
+ Symbol name,
+ const c10::optional<NamedValue>& self,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ // if true, emitBuiltinCall will throw an exception if this builtin does not
+ // exist, otherwise it will return nullptr if the builtin is not found.
+ bool required) {
const auto& variants = getAllOperatorsFor(name);
const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
std::stringstream failure_messages;
- //first we try to match the schema without any conversion
- //if no schema matches then insert ImplicitTensorToNum
+ // first we try to match the schema without any conversion
+ // if no schema matches then insert ImplicitTensorToNum
for (bool allow_conversions : {false, true}) {
// clear previous error messages
failure_messages.str("");
if (!required) {
return nullptr;
}
- if(variants.size() == 0) {
+ if (variants.size() == 0) {
throw ErrorReport(loc) << "unknown builtin op";
}
throw ErrorReport(loc) << "arguments for call are not valid:\n"
<< "for call at";
}
-
} // namespace script
} // namespace jit
} // namespace torch
#pragma once
-#include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/named_value.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/named_value.h>
+#include <torch/csrc/jit/type.h>
namespace torch {
namespace jit {
namespace script {
- // try to match a list if inputs and keyword 'attributes' to this schema,
- // if it works return the flat list of positional inputs to the call
- // if it returns nullopt, then failure_messages contains a good error report
- // set convert_tensor_to_num to true if ImplicitTensorToNums should be inserted to
- // match the schema
+// try to match a list of inputs and keyword 'attributes' to this schema;
+// if it works, return the flat list of positional inputs to the call.
+// if it returns nullopt, then failure_messages contains a good error report.
+// set convert_tensor_to_num to true if ImplicitTensorToNums should be inserted
+// to match the schema
struct MatchedSchema {
std::vector<Value*> inputs;
};
TORCH_API c10::optional<MatchedSchema> tryMatchSchema(
- const FunctionSchema& schema,
- const SourceRange& loc,
- Graph& graph,
- c10::optional<NamedValue> self,
- at::ArrayRef<NamedValue> inputs,
- at::ArrayRef<NamedValue> attributes,
- std::ostream& failure_messages,
- bool allow_conversions);
+ const FunctionSchema& schema,
+ const SourceRange& loc,
+ Graph& graph,
+ c10::optional<NamedValue> self,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ std::ostream& failure_messages,
+ bool allow_conversions);
TORCH_API Value* emitBuiltinCall(
- const SourceRange& loc,
- Graph& graph,
- Symbol name,
- const c10::optional<NamedValue>& self,
- at::ArrayRef<NamedValue> inputs,
- at::ArrayRef<NamedValue> attributes,
- // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
- // otherwise it will return nullptr if the builtin is not found.
- bool required);
+ const SourceRange& loc,
+ Graph& graph,
+ Symbol name,
+ const c10::optional<NamedValue>& self,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ // if true, emitBuiltinCall will throw an exception if this builtin does not
+ // exist, otherwise it will return nullptr if the builtin is not found.
+ bool required);
TORCH_API c10::optional<size_t> findInputWithName(
- const std::string& name,
- at::ArrayRef<NamedValue> kwargs);
+ const std::string& name,
+ at::ArrayRef<NamedValue> kwargs);
-// applies implict conversion from value trying to turn it into type concrete_type
-// it succeeds if the return_value->isSubclassOf(concrete_type)
+// applies implicit conversions to `value`, trying to turn it into type
+// concrete_type; it succeeds if the returned value->isSubtypeOf(concrete_type)
TORCH_API Value* tryConvertToType(
const SourceRange& loc,
Graph& graph,
Value* value,
bool allow_conversions);
-}
+} // namespace script
} // namespace jit
} // namespace torch
-#include <torch/csrc/jit/script/type_parser.h>
#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/script/tree_views.h>
#include <torch/csrc/jit/script/sugared_value.h>
+#include <torch/csrc/jit/script/tree_views.h>
+#include <torch/csrc/jit/script/type_parser.h>
namespace torch {
namespace jit {
};
std::shared_ptr<SugaredValue> PrintValue::call(
- const SourceRange& loc,
- Method & m,
- at::ArrayRef<NamedValue> inputs,
- at::ArrayRef<NamedValue> attributes,
- size_t n_binders) {
- auto& g = *m.graph();
- if (!attributes.empty())
- throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
+ const SourceRange& loc,
+ Method& m,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) {
+ auto& g = *m.graph();
+ if (!attributes.empty())
+ throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
- //temporary hack to allow print statements to work in python 2, where
- //print(a, b) is treated as a (a, b) tuple input.
+ // temporary hack to allow print statements to work in python 2, where
+ // print(a, b) is treated as a (a, b) tuple input.
- std::vector<Value*> lowered_inputs = toValues(*m.graph(), inputs);
- if(lowered_inputs.size() == 1 && lowered_inputs.at(0)->node()->kind() == prim::TupleConstruct) {
- auto input = lowered_inputs[0];
- for(size_t j = 0; j < input->node()->inputs().size(); ++j) {
- lowered_inputs.insert(lowered_inputs.begin() + 1 + j, input->node()->inputs().at(j));
- }
- lowered_inputs.erase(lowered_inputs.begin());
+ std::vector<Value*> lowered_inputs = toValues(*m.graph(), inputs);
+ if (lowered_inputs.size() == 1 &&
+ lowered_inputs.at(0)->node()->kind() == prim::TupleConstruct) {
+ auto input = lowered_inputs[0];
+ for (size_t j = 0; j < input->node()->inputs().size(); ++j) {
+ lowered_inputs.insert(
+ lowered_inputs.begin() + 1 + j, input->node()->inputs().at(j));
}
- g.insertNode(g.create(prim::Print, lowered_inputs, 0)
- ->setSourceLocation(std::make_shared<SourceRange>(loc)));
- return std::make_shared<NoneValue>();
+ lowered_inputs.erase(lowered_inputs.begin());
+ }
+ g.insertNode(g.create(prim::Print, lowered_inputs, 0)
+ ->setSourceLocation(std::make_shared<SourceRange>(loc)));
+ return std::make_shared<NoneValue>();
}
-static const std::unordered_map<std::string, std::string> &builtin_cast_methods() {
+static const std::unordered_map<std::string, std::string>&
+builtin_cast_methods() {
static std::unordered_map<std::string, std::string> builtin_cast_methods = {
- {"byte", "_cast_Byte"},
- {"char", "_cast_Char"},
- {"double", "_cast_Double"},
- {"float", "_cast_Float"},
- {"int", "_cast_Int"},
- {"long", "_cast_Long"},
- {"short", "_cast_Short"},
- {"half", "_cast_Half"}
- };
+ {"byte", "_cast_Byte"},
+ {"char", "_cast_Char"},
+ {"double", "_cast_Double"},
+ {"float", "_cast_Float"},
+ {"int", "_cast_Int"},
+ {"long", "_cast_Long"},
+ {"short", "_cast_Short"},
+ {"half", "_cast_Half"}};
return builtin_cast_methods;
}
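Per the table above, a method-style cast on a Tensor is sugar for the corresponding _cast_ builtin (the dispatch happens in SimpleValue::attr below):
// x.int()  -> _cast_Int(x)
// x.half() -> _cast_Half(x)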
// support syntax sugar for x.foo(y, z) by allowing x.foo to return a
// callable value that will resolve to foo(x, y, z) when called.
-std::shared_ptr<SugaredValue> SimpleValue::attr(const SourceRange& loc, Method & m, const std::string& field) {
+std::shared_ptr<SugaredValue> SimpleValue::attr(
+ const SourceRange& loc,
+ Method& m,
+ const std::string& field) {
// Allow method-style casts on Tensor types, e.g. x.int()
if (value->type()->isSubtypeOf(DynamicType::get())) {
if (builtin_cast_methods().count(field)) {
// functions that are just direct property lookups on tensor
// must be registered as prim::<name>(Tensor t) -> <return_type>
static const std::unordered_set<std::string> fields = {
- "dtype",
- "device",
- "shape",
- "is_cuda",
- "requires_grad",
+ "dtype",
+ "device",
+ "shape",
+ "is_cuda",
+ "requires_grad",
};
if (fields.count(field)) {
- auto r = m.graph()->insert(Symbol::fromQualString("prim::"+field), {value});
+ auto r =
+ m.graph()->insert(Symbol::fromQualString("prim::" + field), {value});
return std::make_shared<SimpleValue>(r);
}
}
const SourceRange& loc,
Method& m,
const c10::optional<size_t>& size_hint) {
- static const auto make_simple_value = [](Value* v) -> std::shared_ptr<SugaredValue> {
+ static const auto make_simple_value =
+ [](Value* v) -> std::shared_ptr<SugaredValue> {
return std::make_shared<SimpleValue>(v);
};
- if(value->type()->kind() == TypeKind::TupleType) {
+ if (value->type()->kind() == TypeKind::TupleType) {
auto outputs = createTupleUnpack(value);
return fmap(outputs, make_simple_value);
} else if (value->type()->kind() == TypeKind::ListType) {
if (!size_hint) {
- throw ErrorReport(loc) << "cannot statically infer the expected size of a list in this context";
+ throw ErrorReport(loc)
+ << "cannot statically infer the expected size of a list in this context";
}
auto graph = value->owningGraph();
- Node *unpack = graph->insertNode(graph->createListUnpack(value, *size_hint));
+ Node* unpack =
+ graph->insertNode(graph->createListUnpack(value, *size_hint));
return fmap(unpack->outputs(), make_simple_value);
}
- throw ErrorReport(loc) << value->type()->str() << " cannot be used as a tuple";
+ throw ErrorReport(loc) << value->type()->str()
+ << " cannot be used as a tuple";
}
} // namespace script
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/error_report.h>
-#include <torch/csrc/jit/script/tree_views.h>
#include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/script/tree_views.h>
namespace torch {
namespace jit {
// that separates their behavior from the AST -> IR converter itself.
// This allows us to keep dependencies on python minimal.
-enum NoneStatus {
- ALWAYS,
- MAYBE,
- NEVER
-};
+enum NoneStatus { ALWAYS, MAYBE, NEVER };
struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
// what is this node? for error reporting (e.g. Module, python function)
// what can we do with this thing?
// use it as a value e.g. `this + 4`
- virtual Value * asValue(const SourceRange& loc, Method & m) {
+ virtual Value* asValue(const SourceRange& loc, Method& m) {
throw ErrorReport(loc) << kind() << " cannot be used as a value";
}
// select an attribute on it, e.g. `this.field`
- virtual std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) {
+ virtual std::shared_ptr<SugaredValue> attr(
+ const SourceRange& loc,
+ Method& m,
+ const std::string& field) {
throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
}
virtual NoneStatus isNone() {
// call it like a function, e.g. `outputs = this(inputs)`
virtual std::shared_ptr<SugaredValue> call(
- const SourceRange& loc,
- Method & m,
- // note: names for args will be 'argument 0', 'argument 1', etc..
- at::ArrayRef<NamedValue> inputs_,
- at::ArrayRef<NamedValue> attributes,
- size_t n_binders) {
-// n_binders is always set to the number of variables an expression is
-// syntactically bound to:
-// a = foo() # 1 binder (note in this case the single binder might be a tuple)
-// a, * b = foo() # 1 binder
-// a, b = foo() # 2 binders
-// foo() # 0 binders
-//
-// In subexpressions, like bar() in foo(bar()), n_binders is always set to
-// 1. n_binders is used as a hint to subexpressions to determine how many
-// values they should return when that number is ambiguous statically. In
-// particular it is currently used to decide how many tensors a call to a
-// python function will return. It is only a hint, functions do not have to
-// check that n_binders match the number of things they are returning, the
-// assignment logic will do that anyway.
+ const SourceRange& loc,
+ Method& m,
+ // note: names for args will be 'argument 0', 'argument 1', etc.
+ at::ArrayRef<NamedValue> inputs_,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) {
+ // n_binders is always set to the number of variables an expression is
+ // syntactically bound to:
+ // a = foo()      # 1 binder (note: the single binder might be a tuple)
+ // a, * b = foo() # 1 binder
+ // a, b = foo()   # 2 binders
+ // foo()          # 0 binders
+ //
+ // In subexpressions, like bar() in foo(bar()), n_binders is always set to
+ // 1. n_binders is used as a hint to subexpressions to determine how many
+ // values they should return when that number is ambiguous statically. In
+ // particular it is currently used to decide how many tensors a call to a
+ // python function will return. It is only a hint, functions do not have to
+ // check that n_binders match the number of things they are returning, the
+ // assignment logic will do that anyway.
throw ErrorReport(loc) << "cannot call a " << kind();
}
// most things in the environment are just simple value types
// and not special python syntax sugar types
struct TORCH_API SimpleValue : public SugaredValue {
- SimpleValue(Value * value)
- : value(value) {}
+ SimpleValue(Value* value) : value(value) {}
std::string kind() const override {
return "value";
}
- Value * asValue(const SourceRange& range, Method & m) override {
+ Value* asValue(const SourceRange& range, Method& m) override {
return value;
}
NoneStatus isNone() override {
const SourceRange& loc,
Method& m,
const c10::optional<size_t>& size_hint = {}) override;
- std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override;
+ std::shared_ptr<SugaredValue> attr(
+ const SourceRange& loc,
+ Method& m,
+ const std::string& field) override;
Value* getValue() const {
return value;
}
-private:
+
+ private:
Value* value;
};
};
struct TORCH_API BuiltinModule : public SugaredValue {
- BuiltinModule(std::string name,
- c10::optional<int64_t> version = at::nullopt)
- : name(std::move(name))
- , version(std::move(version)) {}
+ BuiltinModule(std::string name, c10::optional<int64_t> version = at::nullopt)
+ : name(std::move(name)), version(std::move(version)) {}
std::string kind() const override {
return "builtin module";
}
- std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override {
- return std::make_shared<BuiltinFunction>(Symbol::fromQualString(name+"::"+field), c10::nullopt);
+ std::shared_ptr<SugaredValue> attr(
+ const SourceRange& loc,
+ Method& m,
+ const std::string& field) override {
+ return std::make_shared<BuiltinFunction>(
+ Symbol::fromQualString(name + "::" + field), c10::nullopt);
}
-private:
+ private:
std::string name;
// when we add operator versioning, emit this op as it existed at 'version'
// if not set, use the latest version
// defines how a method obtained from a module behaves in script
struct MethodValue : public SugaredValue {
MethodValue(std::shared_ptr<Module> module, Method& method)
- : module(std::move(module)) //insurance that method stays alive
- , method(method) {}
+ : module(std::move(module)), // insurance that method stays alive
+ method(method) {}
std::string kind() const override {
return "method";
}
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
- return std::make_shared<SimpleValue>(caller.emit_call_to(loc, method, inputs, attributes));
+ return std::make_shared<SimpleValue>(
+ caller.emit_call_to(loc, method, inputs, attributes));
}
private:
std::shared_ptr<Module> module;
Method& method;
-
};
struct TORCH_API PrintValue : public SugaredValue {
return "print";
}
std::shared_ptr<SugaredValue> call(
- const SourceRange& loc,
- Method & m,
- at::ArrayRef<NamedValue> inputs,
- at::ArrayRef<NamedValue> attributes,
- size_t n_binders) override;
+ const SourceRange& loc,
+ Method& m,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) override;
};
// expressions like int(x)
// is a noop when the input is a subtype of 'type'
struct TORCH_API CastValue : public BuiltinFunction {
CastValue(TypePtr type, c10::Symbol method)
- : BuiltinFunction(method, c10::nullopt)
- , type_(std::move(type)) {}
+ : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {}
std::shared_ptr<SugaredValue> call(
- const SourceRange& loc,
- Method & m,
- at::ArrayRef<NamedValue> inputs,
- at::ArrayRef<NamedValue> attributes,
- size_t n_binders) override {
- if(inputs.size() == 1 && attributes.size() == 0) {
- auto v = inputs[0].value(*m.graph());
- if (v->type()->isSubtypeOf(type_)) {
- return std::make_shared<SimpleValue>(v);
- }
+ const SourceRange& loc,
+ Method& m,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) override {
+ if (inputs.size() == 1 && attributes.size() == 0) {
+ auto v = inputs[0].value(*m.graph());
+ if (v->type()->isSubtypeOf(type_)) {
+ return std::make_shared<SimpleValue>(v);
}
- return BuiltinFunction::call(loc, m , inputs, attributes, n_binders);
+ }
+ return BuiltinFunction::call(loc, m, inputs, attributes, n_binders);
}
-private:
+
+ private:
TypePtr type_;
};
-
// These SugaredValues have special handling in the compiler because they
// change the normal evaluation order of the expression they participate in.
// They are exposed here so that the python frontend can inject them
}
};
-static inline std::vector<Value*> toValues(Graph& g, at::ArrayRef<NamedValue> nvs) {
- return fmap(nvs, [&](const NamedValue& v) {
- return v.value(g);
- });
+static inline std::vector<Value*> toValues(
+ Graph& g,
+ at::ArrayRef<NamedValue> nvs) {
+ return fmap(nvs, [&](const NamedValue& v) { return v.value(g); });
}
-}
+} // namespace script
} // namespace jit
} // namespace torch
#pragma once
+#include <functional>
#include <memory>
#include <vector>
-#include <functional>
#include <torch/csrc/jit/script/lexer.h>
void matchNumSubtrees(int k, size_t expected_subtrees) {
return matchNumSubtreesD(k, "unknown", 0, expected_subtrees, false);
}
- void matchNumSubtreesD(int k, const char* filename, int lineno,
- size_t expected_subtrees, bool allow_more) {
+ void matchNumSubtreesD(
+ int k,
+ const char* filename,
+ int lineno,
+ size_t expected_subtrees,
+ bool allow_more) {
if (kind() != k) {
std::stringstream ss;
ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
if (trees().size() < expected_subtrees ||
(!allow_more && trees().size() != expected_subtrees)) {
std::stringstream ss;
- ss << filename << ":" << lineno << ": expected at least " << expected_subtrees
- << " subtrees, but found only " << trees().size() << "\n";
+ ss << filename << ":" << lineno << ": expected at least "
+ << expected_subtrees << " subtrees, but found only " << trees().size()
+ << "\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
}
}
struct Compound : public Tree {
- Compound(int kind, SourceRange range) : Tree(kind), range_(std::move(range)) {}
+ Compound(int kind, SourceRange range)
+ : Tree(kind), range_(std::move(range)) {}
Compound(int kind, const SourceRange& range_, TreeList&& trees_)
: Tree(kind),
range_(mergeRanges(range_, trees_)),
const TreeList& trees() const override {
return trees_;
}
- static TreeRef
- create(int kind, const SourceRange& range_, TreeList&& trees_) {
+ static TreeRef create(
+ int kind,
+ const SourceRange& range_,
+ TreeList&& trees_) {
return std::make_shared<Compound>(kind, range_, std::move(trees_));
}
bool isAtom() const override {
namespace jit {
namespace script {
+// clang-format off
// TreeView provides a statically-typed way to traverse the tree, which should
// be formed according to the grammar below.
//
// | Global(List<Ident> idents) TK_GLOBAL
// -- NB: the only type of Expr's allowed on lhs are Var
// Or a tuple containing Var with an optional terminating Starred
-// | Assign(Expr lhs, Expr rhs) TK_ASSIGN
+// | Assign(Expr lhs, Expr rhs) TK_ASSIGN
// | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs) TK_AUG_ASSIGN
// | Return(List<Expr> values) TK_RETURN
// | ExprStmt(List<Expr> expr) TK_EXPR_STMT
// changes to the structure of Ident are always made right here rather
// than both in the parser and in this code.
// XXX: these structs should have no fields to prevent slicing when passing by value
+// clang-format on
struct TreeView {
explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}
TreeRef tree() const {
return tree_->kind();
}
-protected:
+ protected:
const TreeRef& subtree(size_t i) const {
return tree_->trees().at(i);
}
TreeRef tree_;
};
-template<typename T>
+template <typename T>
struct ListIterator {
ListIterator(TreeList::const_iterator it) : it(it) {}
- bool operator!=(const ListIterator& rhs) const { return it != rhs.it; }
- bool operator==(const ListIterator& rhs) const { return it == rhs.it; }
- T operator*() const { return T(*it); }
- ListIterator& operator+=(std::ptrdiff_t n) { it += n; return *this; }
- ListIterator& operator++() { ++it; return *this; }
- ListIterator& operator--() { --it; return *this; }
-
-private:
+ bool operator!=(const ListIterator& rhs) const {
+ return it != rhs.it;
+ }
+ bool operator==(const ListIterator& rhs) const {
+ return it == rhs.it;
+ }
+ T operator*() const {
+ return T(*it);
+ }
+ ListIterator& operator+=(std::ptrdiff_t n) {
+ it += n;
+ return *this;
+ }
+ ListIterator& operator++() {
+ ++it;
+ return *this;
+ }
+ ListIterator& operator--() {
+ --it;
+ return *this;
+ }
+
+ private:
TreeList::const_iterator it;
};
tree->match(TK_LIST);
// Iterate over list to temporarily instantiate Ts that will check the type
for (const T& elem : *this) {
- (void) elem; //silence unused warning
+ (void)elem; // silence unused warning
}
}
iterator begin() const {
return tree_->map([&](TreeRef v) { return fn(T(v)); });
}
static List create(const SourceRange& range, const std::vector<T>& subtrees) {
- TreeList type_erased_sub {subtrees.begin(), subtrees.end()};
+ TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
}
static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
case TK_DEF:
return;
default:
- throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt";
+ throw ErrorReport(tree)
+ << kindToString(tree->kind()) << " is not a valid Stmt";
}
}
};
case '|':
return;
default:
- throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Expr";
+ throw ErrorReport(tree)
+ << kindToString(tree->kind()) << " is not a valid Expr";
}
}
};
Expr value() const {
return Expr(subtree(1));
}
- static Attribute create(const SourceRange& range, const Ident& name, const TreeRef& value) {
+ static Attribute create(
+ const SourceRange& range,
+ const Ident& name,
+ const TreeRef& value) {
return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
}
};
-
struct Param : public TreeView {
explicit Param(const TreeRef& tree) : TreeView(tree) {
tree_->match(TK_PARAM);
}
- static Param create(const SourceRange& range, const Ident& ident, const Expr& type, const Maybe<Expr>& def) {
+ static Param create(
+ const SourceRange& range,
+ const Ident& ident,
+ const Expr& type,
+ const Maybe<Expr>& def) {
return Param(Compound::create(TK_PARAM, range, {ident, type, def}));
}
Ident ident() const {
Maybe<Expr> return_type() const {
return Maybe<Expr>(subtree(1));
}
- static Decl create(const SourceRange& range, const List<Param>& params, const Maybe<Expr>& return_type) {
+ static Decl create(
+ const SourceRange& range,
+ const List<Param>& params,
+ const Maybe<Expr>& return_type) {
return Decl(Compound::create(TK_DECL, range, {params, return_type}));
}
};
const Ident& name,
const Decl& decl,
const List<Stmt>& stmts) {
- return Def(Compound::create(
- TK_DEF, range, {name, decl, stmts}));
+ return Def(Compound::create(TK_DEF, range, {name, decl, stmts}));
}
};
-
////////////////////////////////////////////////////////////////////////////////
// Statements
////////////////////////////////////////////////////////////////////////////////
List<Stmt> falseBranch() const {
return List<Stmt>(subtree(2));
}
- If withNewBranches(const List<Stmt>& true_branch, const List<Stmt>& false_branch) const {
+ If withNewBranches(
+ const List<Stmt>& true_branch,
+ const List<Stmt>& false_branch) const {
return create(range(), cond(), true_branch, false_branch);
}
static If create(
const Expr& cond,
const List<Stmt>& true_branch,
const List<Stmt>& false_branch) {
- return If(Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
+ return If(
+ Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
}
};
List<Stmt> body() const {
return List<Stmt>(subtree(1));
}
- static While create(const SourceRange& range, const Expr& cond, const List<Stmt>& body) {
+ static While create(
+ const SourceRange& range,
+ const Expr& cond,
+ const List<Stmt>& body) {
return While(Compound::create(TK_WHILE, range, {cond, body}));
}
};
}
};
-
struct Assign : public Stmt {
explicit Assign(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_ASSIGN);
explicit Pass(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_PASS);
}
- static Pass create(
- const SourceRange& range) {
+ static Pass create(const SourceRange& range) {
return Pass(Compound::create(TK_PASS, range, {}));
}
};
-
struct ExprStmt : public Stmt {
explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_EXPR_STMT);
}
};
-
////////////////////////////////////////////////////////////////////////////////
// Expressions
////////////////////////////////////////////////////////////////////////////////
case '|':
case TK_FLOOR_DIV:
if (tree->trees().size() != 2)
- throw ErrorReport(tree) << "BinOp expected 2 subtrees, found " << tree->trees().size();
+ throw ErrorReport(tree)
+ << "BinOp expected 2 subtrees, found " << tree->trees().size();
return;
default:
- throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid BinOp";
+ throw ErrorReport(tree)
+ << kindToString(tree->kind()) << " is not a valid BinOp";
}
}
Expr lhs() const {
Expr rhs() const {
return Expr(subtree(1));
}
- static BinOp create(const SourceRange& range, int kind, const Expr& lhs, const Expr& rhs) {
+ static BinOp create(
+ const SourceRange& range,
+ int kind,
+ const Expr& lhs,
+ const Expr& rhs) {
return BinOp(Compound::create(kind, range, {lhs, rhs}));
}
};
case TK_UNARY_MINUS:
case TK_NOT:
if (tree->trees().size() != 1)
- throw ErrorReport(tree) << "UnaryOp expected 1 subtree, found " << tree->trees().size();
+ throw ErrorReport(tree)
+ << "UnaryOp expected 1 subtree, found " << tree->trees().size();
return;
default:
- throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid UnaryOp";
+ throw ErrorReport(tree)
+ << kindToString(tree->kind()) << " is not a valid UnaryOp";
}
}
static UnaryOp create(const SourceRange& range, int kind, const Expr& expr) {
return std::stoll(subtree(0)->stringValue());
}
double asFloatingPoint() const {
- return SharedParserData::strtod_c(subtree(0)->stringValue().c_str(), nullptr);
+ return SharedParserData::strtod_c(
+ subtree(0)->stringValue().c_str(), nullptr);
}
const std::string& text() const {
return subtree(0)->stringValue();
const std::string& text() const {
return subtree(0)->stringValue();
}
- static StringLiteral create(const SourceRange& range, const std::string& value) {
- return StringLiteral(Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
+ static StringLiteral create(
+ const SourceRange& range,
+ const std::string& value) {
+ return StringLiteral(
+ Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
}
};
const Expr& callee,
const List<Expr>& inputs,
const List<Attribute>& attributes) {
- return Apply(Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
+ return Apply(
+ Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
}
};
Ident selector() const {
return Ident(subtree(1));
}
- static Select create(const SourceRange& range, const Expr& value, const Ident& selector) {
+ static Select create(
+ const SourceRange& range,
+ const Expr& value,
+ const Ident& selector) {
return Select(Compound::create('.', range, {value, selector}));
}
};
const Maybe<Expr>& end) {
return SliceExpr(Compound::create(TK_SLICE_EXPR, range, {start, end}));
}
-private:
+
+ private:
Expr createInt(int value) const {
return Expr(Const::create(range(), std::to_string(value)));
}
const SourceRange& range,
const Expr& value,
const List<Expr>& subscript_exprs) {
- return Subscript(Compound::create(TK_SUBSCRIPT, range, {value, subscript_exprs}));
+ return Subscript(
+ Compound::create(TK_SUBSCRIPT, range, {value, subscript_exprs}));
}
};
Expr false_expr() const {
return Expr(subtree(2));
}
- static TernaryIf create(const SourceRange& range,
- const Expr& cond,
- const Expr& true_expr,
- const Expr& false_expr) {
- return TernaryIf(Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
+ static TernaryIf create(
+ const SourceRange& range,
+ const Expr& cond,
+ const Expr& true_expr,
+ const Expr& false_expr) {
+ return TernaryIf(
+ Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
};
};
-
struct ListLiteral : public Expr {
explicit ListLiteral(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_LIST_LITERAL);
List<Expr> inputs() const {
return subtree(0);
}
- static ListLiteral create(const SourceRange& range, const List<Expr>& inputs) {
+ static ListLiteral create(
+ const SourceRange& range,
+ const List<Expr>& inputs) {
return ListLiteral(Compound::create(TK_LIST_LITERAL, range, {inputs}));
}
};
List<Expr> inputs() const {
return subtree(0);
}
- static TupleLiteral create(const SourceRange& range, const List<Expr>& inputs) {
+ static TupleLiteral create(
+ const SourceRange& range,
+ const List<Expr>& inputs) {
return TupleLiteral(Compound::create(TK_TUPLE_LITERAL, range, {inputs}));
}
};
namespace std {
-template<typename T>
+template <typename T>
struct iterator_traits<torch::jit::script::ListIterator<T>>
- : std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};
+ : std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};
} // namespace std
-#include <torch/csrc/jit/script/type_parser.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/tree_views.h>
+#include <torch/csrc/jit/script/type_parser.h>
namespace torch {
namespace jit {
namespace script {
-const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
+const std::unordered_map<std::string, TypePtr>& ident_to_type_lut() {
static std::unordered_map<std::string, TypePtr> map = {
- {"Tensor", DynamicType::get()},
- {"int", IntType::get()},
- {"float", FloatType::get()},
- {"bool", BoolType::get()},
- {"str", StringType::get()},
- {"Device", DeviceObjType::get()},
- // technically this is not a python type but we need it when
- // parsing serialized methods that use implicit converions to Scalar
- {"number", NumberType::get()},
- {"None", NoneType::get()},
+ {"Tensor", DynamicType::get()},
+ {"int", IntType::get()},
+ {"float", FloatType::get()},
+ {"bool", BoolType::get()},
+ {"str", StringType::get()},
+ {"Device", DeviceObjType::get()},
+ // technically this is not a python type but we need it when
+ // parsing serialized methods that use implicit conversions to Scalar
+ {"number", NumberType::get()},
+ {"None", NoneType::get()},
};
return map;
}
-const std::unordered_map<std::string, std::function<TypePtr(Subscript)>> &subscript_to_type_fns() {
- static std::unordered_map<std::string, std::function<TypePtr(Subscript)>> map = {
- {"Tuple", [](Subscript subscript) -> TypePtr {
- std::vector<TypePtr> subscript_expr_types;
- for (auto expr : subscript.subscript_exprs()) {
- subscript_expr_types.push_back(parseTypeFromExpr(expr));
- }
- return TupleType::create(subscript_expr_types);
- }},
- {"List", [](Subscript subscript) -> TypePtr {
- if (subscript.subscript_exprs().size() != 1) {
- throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
- }
- auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
- return ListType::create(elem_type);
- }},
- {"Optional", [](Subscript subscript) -> TypePtr {
- if (subscript.subscript_exprs().size() != 1) {
- throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
- }
- auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
- return OptionalType::create(elem_type);
- }},
- {"Future", [](Subscript subscript) -> TypePtr {
- if (subscript.subscript_exprs().size() != 1) {
- throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
- }
- auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
- return FutureType::create(elem_type);
- }},
- };
+const std::unordered_map<std::string, std::function<TypePtr(Subscript)>>&
+subscript_to_type_fns() {
+ static std::unordered_map<std::string, std::function<TypePtr(Subscript)>>
+ map = {
+ {"Tuple",
+ [](Subscript subscript) -> TypePtr {
+ std::vector<TypePtr> subscript_expr_types;
+ for (auto expr : subscript.subscript_exprs()) {
+ subscript_expr_types.push_back(parseTypeFromExpr(expr));
+ }
+ return TupleType::create(subscript_expr_types);
+ }},
+ {"List",
+ [](Subscript subscript) -> TypePtr {
+ if (subscript.subscript_exprs().size() != 1) {
+ throw ErrorReport(subscript)
+ << " expected exactly one element type but found "
+ << subscript.subscript_exprs().size();
+ }
+ auto elem_type =
+ parseTypeFromExpr(*subscript.subscript_exprs().begin());
+ return ListType::create(elem_type);
+ }},
+ {"Optional",
+ [](Subscript subscript) -> TypePtr {
+ if (subscript.subscript_exprs().size() != 1) {
+ throw ErrorReport(subscript)
+ << " expected exactly one element type but found "
+ << subscript.subscript_exprs().size();
+ }
+ auto elem_type =
+ parseTypeFromExpr(*subscript.subscript_exprs().begin());
+ return OptionalType::create(elem_type);
+ }},
+ {"Future",
+ [](Subscript subscript) -> TypePtr {
+ if (subscript.subscript_exprs().size() != 1) {
+ throw ErrorReport(subscript)
+ << " expected exactly one element type but found "
+ << subscript.subscript_exprs().size();
+ }
+ auto elem_type =
+ parseTypeFromExpr(*subscript.subscript_exprs().begin());
+ return FutureType::create(elem_type);
+ }},
+ };
return map;
}
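As a sketch of the dispatch above, nested annotations compose through the recursive calls to parseTypeFromExpr:

    // "Tuple[int, str]"       -> TupleType::create({IntType::get(), StringType::get()})
    // "Optional[List[float]]" -> OptionalType::create(ListType::create(FloatType::get()))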
return expr.kind() == TK_VAR && Var(expr).name().name() == "torch";
}
-
-
-c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(const Expr& expr) {
+c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(
+ const Expr& expr) {
if (expr.kind() != TK_SUBSCRIPT)
return c10::nullopt;
auto subscript = Subscript(expr);
auto subscript_exprs = subscript.subscript_exprs();
// handle the case where the BroadcastingList is wrapped in a Optional type
- if(var.name().name() == "Optional") {
+ if (var.name().name() == "Optional") {
auto broadcast_list = parseBroadcastList(subscript_exprs[0]);
if (broadcast_list) {
TypePtr opt_type = OptionalType::create(broadcast_list->first);
if (subscript_exprs.size() != 1)
throw ErrorReport(subscript.subscript_exprs().range())
- << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
+ << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
auto typ = subscript_exprs[0];
auto len = var.name().name().substr(strlen("BroadcastingList"));
if (typ.kind() != TK_VAR)
- throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
+ throw ErrorReport(subscript.value().range())
+ << "Subscripted type must be a type identifier";
auto value_name = Var(typ).name().name();
if (value_name != "float" && value_name != "int")
- throw ErrorReport(subscript.value().range()) << "Broadcastable lists only supported for int or float";
+ throw ErrorReport(subscript.value().range())
+ << "Broadcastable lists only supported for int or float";
auto elem_ptr = ident_to_type_lut().find(value_name);
JIT_ASSERT(elem_ptr != ident_to_type_lut().end());
auto subscript = Subscript(expr);
auto value_name = parseBaseTypeName(subscript.value());
if (!value_name) {
- throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
+ throw ErrorReport(subscript.value().range())
+ << "Subscripted type must be a type identifier";
}
if (!subscript_to_type_fns().count(*value_name)) {
- throw ErrorReport(subscript.range()) << "Unknown type constructor " << *value_name;
+ throw ErrorReport(subscript.range())
+ << "Unknown type constructor " << *value_name;
}
return subscript_to_type_fns().at(*value_name)(subscript);
} else if (auto name = parseBaseTypeName(expr)) {
}
throw ErrorReport(expr) << "Unknown type name " << *name;
}
- throw ErrorReport(expr.range()) << "Expression of type " << kindToString(expr.kind())
- << " cannot be used in a type expression";
+ throw ErrorReport(expr.range())
+ << "Expression of type " << kindToString(expr.kind())
+ << " cannot be used in a type expression";
}
} // namespace script
} // namespace jit
struct Expr;
TORCH_API c10::optional<std::string> parseBaseTypeName(const Expr& expr);
TORCH_API c10::TypePtr parseTypeFromExpr(const Expr& expr);
-TORCH_API c10::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList(const Expr& expr);
-}
+TORCH_API c10::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList(
+ const Expr& expr);
+} // namespace script
} // namespace jit
} // namespace torch
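For reference, a hedged sketch of the entry points declared above (the annotation strings stand for hypothetical parsed Exprs; the broadcast-list length comes from the digits after "BroadcastingList" in the definition earlier):

    // parseTypeFromExpr("int")                       -> IntType::get()
    // parseBroadcastList("BroadcastingList2[float]") -> pair(list-of-float type, 2)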
#include <stdexcept>
#include <string>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// SourceLocation represents source code-level debug information for a node.
// It contains information about where a node got generated.
// In the case of tracing this will be a python stack trace.
// by a SourceRange object
struct SourceLocation {
virtual ~SourceLocation() = default;
- virtual void highlight(std::ostream & out) const = 0;
+ virtual void highlight(std::ostream& out) const = 0;
- std::string wrapException(const std::exception & e, const std::string & additional = "") {
+ std::string wrapException(
+ const std::exception& e,
+ const std::string& additional = "") {
std::stringstream msg;
msg << "\n" << e.what() << ":\n";
- if(!additional.empty()) {
+ if (!additional.empty()) {
msg << additional << ":\n";
}
highlight(msg);
return msg.str();
}
- void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") {
+ void wrapAndRethrowException(
+ const std::exception& e,
+ const std::string& additional = "") {
throw std::runtime_error(wrapException(e, additional));
}
-
};
inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) {
return out;
}
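A usage sketch for the wrapping helpers above (`loc` is any concrete SourceLocation; the failing call is hypothetical):

    try {
      runTracedBlock();  // hypothetical operation that throws
    } catch (const std::exception& e) {
      loc.wrapAndRethrowException(e, "while running this block");
    }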
-
// normally a python stack trace
struct StringSourceLocation : public SourceLocation {
- StringSourceLocation(std::string context)
- : context(std::move(context)) {}
- void highlight(std::ostream & out) const override {
+ StringSourceLocation(std::string context) : context(std::move(context)) {}
+ void highlight(std::ostream& out) const override {
out << context;
}
-private:
+
+ private:
std::string context;
};
-}}
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/source_location.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/source_location.h>
+#include <algorithm>
+#include <memory>
namespace torch {
namespace jit {
// that
// range.
struct SourceRange : public SourceLocation {
- SourceRange(
- std::shared_ptr<std::string> file_,
- size_t start_,
- size_t end_)
+ SourceRange(std::shared_ptr<std::string> file_, size_t start_, size_t end_)
: file_(std::move(file_)), start_(start_), end_(end_) {}
const std::string text() const {
return file().substr(start(), end() - start());
JIT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
JIT_ASSERT(end_line == str.size() || str[end_line] == '\n');
- size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines before the highlight line
- for(size_t i = 0; begin_highlight > 0; --begin_highlight) {
- if(str[begin_highlight - 1] == '\n')
+ size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines
+ // before the highlight line
+ for (size_t i = 0; begin_highlight > 0; --begin_highlight) {
+ if (str[begin_highlight - 1] == '\n')
++i;
- if(i >= CONTEXT)
+ if (i >= CONTEXT)
break;
}
JIT_ASSERT(begin_highlight == 0 || str[begin_highlight - 1] == '\n');
- size_t end_highlight = end_line; // end of context, CONTEXT lines after the highlight line
- for(size_t i = 0; end_highlight < str.size(); ++end_highlight) {
- if(str[end_highlight] == '\n')
+ size_t end_highlight =
+ end_line; // end of context, CONTEXT lines after the highlight line
+ for (size_t i = 0; end_highlight < str.size(); ++end_highlight) {
+ if (str[end_highlight] == '\n')
++i;
- if(i >= CONTEXT)
+ if (i >= CONTEXT)
break;
}
JIT_ASSERT(end_highlight == str.size() || str[end_highlight] == '\n');
#include <torch/csrc/jit/ivalue.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
using Stack = std::vector<IValue>;
using Operation = std::function<int(Stack&)>;
// the stack and pushes its M inputs onto the stack
// before: <other stack items> I0, I1, ... IN <- stack.back()
// after: <other stack items> O0, O1, ... OM
-// operations are defined this way so that ownership of inputs can be transferred
-// to the operation and it can incrementally drop ownership of tensors
-// when they become unneeded. For large operations, like 'run an entire subgraph',
-// this functionality is very important for minimizing gpu memory usage
-// return value is the relative 'offset' to jump to for the next operation:
-// pc += 1 + offset
+// operations are defined this way so that ownership of inputs can be
+// transferred to the operation and it can incrementally drop ownership of
+// tensors when they become unneeded. For large operations, like 'run an entire
+// subgraph', this functionality is very important for minimizing gpu memory
+// usage. The return value is the relative 'offset' to jump to for the next
+// operation:
+// pc += 1 + offset
// so a return value of 0 goes to the next instruction
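A minimal sketch of an Operation honoring this contract, using the stack helpers defined below (a return value of 0 falls through to the next instruction):

    Operation add_op = [](Stack& stack) {
      at::Tensor a, b;
      pop(stack, a, b);    // consume the 2 inputs, taking ownership
      push(stack, a + b);  // push the single output
      return 0;            // relative jump offset
    };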
// treat the last N elements of the stack as a list, looking up
// element i
-static inline IValue & peek(Stack & stack, size_t i, size_t N) {
+static inline IValue& peek(Stack& stack, size_t i, size_t N) {
return *(stack.end() - N + i);
}
// treat the last N elements of the stack as a list, looking up the
// slice starting at index i and having length len
-static inline at::ArrayRef<IValue> peekSlice(const Stack & stack, size_t i, size_t len, size_t N) {
+static inline at::ArrayRef<IValue> peekSlice(
+ const Stack& stack,
+ size_t i,
+ size_t len,
+ size_t N) {
return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
}
-static inline at::ArrayRef<IValue> last(const Stack & stack, size_t N) {
+static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
return peekSlice(stack, 0, N, N);
}
-static inline void drop(Stack & stack, size_t n) {
+static inline void drop(Stack& stack, size_t n) {
stack.erase(stack.end() - n, stack.end());
}
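The indexing convention of these helpers, spelled out on a 3-input stack (top is the last element):

    // stack (bottom -> top): ... x y z, with N = 3
    // peek(stack, 0, 3) -> x      peek(stack, 2, 3) -> z
    // last(stack, 3)    -> ArrayRef over {x, y, z}
    // drop(stack, 3)    -> erases x, y, z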
-static inline IValue pop(Stack & stack) {
+static inline IValue pop(Stack& stack) {
auto r = std::move(stack.back());
stack.pop_back();
return r;
// equivalent to:
// b = pop(stack).toTensor();
// a = pop(stack).toInt();
-template<typename... Types>
+template <typename... Types>
static inline void pop(Stack& stack, Types&... args) {
size_t i = 0;
constexpr size_t N = sizeof...(args);
int result[N] = {
- (args = std::move(peek(stack,i++, N)).template to<Types>(),0)...
- };
- (void) result;
+ (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
+ (void)result;
drop(stack, N);
}
-template<typename... Types>
+template <typename... Types>
static inline void push(Stack& stack, Types&&... args) {
constexpr size_t N = sizeof...(args);
- int result[N] = {
- (stack.emplace_back(std::forward<Types>(args)), 0)...
- };
- (void) result;
+ int result[N] = {(stack.emplace_back(std::forward<Types>(args)), 0)...};
+ (void)result;
}
// The packer here is carefully written not to make any unnecessary
// copies.
// pack takes the return values of aten functions pushes them onto the stack
-template<typename T>
-inline void pack(Stack & stack, T&& v) {
+template <typename T>
+inline void pack(Stack& stack, T&& v) {
stack.emplace_back(std::forward<T>(v));
}
-template<std::size_t remaining, typename... Args>
-struct TuplePacker
-{
+template <std::size_t remaining, typename... Args>
+struct TuplePacker {
// NB: *Not* a universal reference.
- static void execute(Stack & stack, std::tuple<Args...> && t)
- {
+ static void execute(Stack& stack, std::tuple<Args...>&& t) {
// NB: The move here does not "destroy" the entire tuple, that is
// not what std::move does; only the particular tuple index
// processed here gets stolen.
}
};
-template<typename... Args>
-struct TuplePacker<0, Args...>
-{
- static void execute(Stack & stack, std::tuple<Args...> && t) {};
+template <typename... Args>
+struct TuplePacker<0, Args...> {
+ static void execute(Stack& stack, std::tuple<Args...>&& t){};
};
-template<typename... Args>
-inline void pack(Stack & stack, std::tuple<Args...> && t) {
+template <typename... Args>
+inline void pack(Stack& stack, std::tuple<Args...>&& t) {
TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t));
}
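A sketch of packing a multi-output result (hypothetical tensors `out` and `idx`; elements are pushed in tuple order, leaving `idx` on top):

    pack(stack, std::make_tuple(std::move(out), std::move(idx)));
    // equivalent in effect to: push(stack, std::move(out), std::move(idx));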
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/symbolic_script.h>
-
-
-namespace torch { namespace jit {
- namespace {
- std::mutex lock;
- const std::vector<std::string> functions = {
- R"(
+namespace torch {
+namespace jit {
+namespace {
+std::mutex lock;
+const std::vector<std::string> functions = {
+ R"(
def mul(self, other):
def backward(grad_output):
grad_self = (grad_output * other).sum_to_size(self.size())
return grad_self, None
return torch.adaptive_avg_pool2d(self, output_size), backward
- )"
- };
- std::unordered_map<std::string, GradientPair> schema_to_graphs;
-
- // This map is a workaround to cache compiled gradient_pairs. Ideally this graph
- // should be compiled only once and saved in Operator structure.
- // This should be done along with merging into native_functions.yaml.
- std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
- } // anonymous namespace
-
- std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
- AT_CHECK(closure->node()->kind() == prim::TupleConstruct, "closure must be a literal tuple construct");
- Value* fn = closure->node()->inputs().at(0);
- Value* context = closure->node()->inputs().at(1);
-
- AT_CHECK(fn->node()->kind() == prim::Function, "closure tuple must contain a prim::Function");
- return std::make_pair(fn->node()->g(attr::Subgraph), context);
- }
-
- Argument originalReturnType(const TupleTypePtr& tup) {
- AT_CHECK(tup->elements().size() > 1);
- if(tup->elements().size() == 2)
- return Argument("", tup->elements().at(0));
- std::vector<TypePtr> types = tup->elements().vec();
- types.pop_back();
- return Argument("", TupleType::create(std::move(types)));
- }
+ )"};
+std::unordered_map<std::string, GradientPair> schema_to_graphs;
+
+// This map is a workaround to cache compiled gradient_pairs. Ideally this graph
+// should be compiled only once and saved in the Operator structure.
+// This should be done along with merging into native_functions.yaml.
+std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
+} // anonymous namespace
+
+std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
+ AT_CHECK(
+ closure->node()->kind() == prim::TupleConstruct,
+ "closure must be a literal tuple construct");
+ Value* fn = closure->node()->inputs().at(0);
+ Value* context = closure->node()->inputs().at(1);
+
+ AT_CHECK(
+ fn->node()->kind() == prim::Function,
+ "closure tuple must contain a prim::Function");
+ return std::make_pair(fn->node()->g(attr::Subgraph), context);
+}
+
+Argument originalReturnType(const TupleTypePtr& tup) {
+ AT_CHECK(tup->elements().size() > 1);
+ if (tup->elements().size() == 2)
+ return Argument("", tup->elements().at(0));
+ std::vector<TypePtr> types = tup->elements().vec();
+ types.pop_back();
+ return Argument("", TupleType::create(std::move(types)));
+}
+
+void loadModule(const std::shared_ptr<script::Module>& module) {
+ for (const auto& method_ : module->get_methods()) {
+ const auto& method = method_.value();
+ GradientPair pair;
+ pair.forward = method->graph();
+
+ // lookup the backward function
+ Node* forward_tuple = pair.forward->outputs().at(0)->node();
+
+ if (forward_tuple->kind() != prim::TupleConstruct) {
+ throw script::ErrorReport(forward_tuple->getSourceLocation())
+ << "gradient must return literal a tuple";
+ }
- void loadModule(const std::shared_ptr<script::Module>& module) {
- for(const auto& method_ : module->get_methods()) {
- const auto& method = method_.value();
- GradientPair pair;
- pair.forward = method->graph();
-
- // lookup the backward function
- Node* forward_tuple = pair.forward->outputs().at(0)->node();
-
- if (forward_tuple->kind() != prim::TupleConstruct) {
- throw script::ErrorReport(forward_tuple->getSourceLocation()) << "gradient must return literal a tuple";
- }
-
- Value* context;
- std::tie(pair.backward, context) = extractClosure(forward_tuple->inputs().back());
-
- // do surgery on the forward function to remove the closure tuple and replace it with the
- // context variable:
- // backward = (<lambda>, context_tuple)
- // return original, backward
- // -----
- // return original, context_tuple
- std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
- new_inputs.back() = context;
- Value* new_tuple = pair.forward->appendNode(pair.forward->createTuple(new_inputs))->output();
- pair.forward->eraseOutput(0);
- pair.forward->registerOutput(new_tuple);
- forward_tuple->destroy();
-
- // derive schema from original function's schema:
- const FunctionSchema& loaded_schema = method->getSchema();
- FunctionSchema actual_schema(Symbol::aten(loaded_schema.name()),
+ Value* context;
+ std::tie(pair.backward, context) =
+ extractClosure(forward_tuple->inputs().back());
+
+ // do surgery on the forward function to remove the closure tuple and
+ // replace it with the context variable:
+ // backward = (<lambda>, context_tuple)
+ // return original, backward
+ // -----
+ // return original, context_tuple
+ std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
+ new_inputs.back() = context;
+ Value* new_tuple =
+ pair.forward->appendNode(pair.forward->createTuple(new_inputs))
+ ->output();
+ pair.forward->eraseOutput(0);
+ pair.forward->registerOutput(new_tuple);
+ forward_tuple->destroy();
+
+ // derive schema from original function's schema:
+ const FunctionSchema& loaded_schema = method->getSchema();
+ FunctionSchema actual_schema(
+ Symbol::aten(loaded_schema.name()),
loaded_schema.arguments(),
- {originalReturnType(new_tuple->type()->expect<TupleType>())}
- );
- std::string key = canonicalSchemaString(actual_schema);
- schema_to_graphs[key] = std::move(pair);
- }
+ {originalReturnType(new_tuple->type()->expect<TupleType>())});
+ std::string key = canonicalSchemaString(actual_schema);
+ schema_to_graphs[key] = std::move(pair);
}
+}
- void loadFunctions() {
- for(const std::string& str : functions) {
- auto cu = std::make_shared<script::Module>();
- script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
- loadModule(cu);
- }
+void loadFunctions() {
+ for (const std::string& str : functions) {
+ auto cu = std::make_shared<script::Module>();
+ script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
+ loadModule(cu);
}
+}
- c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema) {
- std::lock_guard<std::mutex> guard(lock);
- if (schema_to_graphs.size() == 0) {
- loadFunctions();
- }
- auto cache_it = cached_gradient_pairs.find(&schema);
- if (cache_it != cached_gradient_pairs.end()) {
- return cache_it->second;
- } else {
- auto schema_str = canonicalSchemaString(schema);
- auto sym_script_it = schema_to_graphs.find(schema_str);
- if (sym_script_it != schema_to_graphs.end()) {
- cached_gradient_pairs.emplace_hint(cache_it, &schema, sym_script_it->second);
- return sym_script_it->second;
- }
+c10::optional<GradientPair> gradientInfoForSchema(
+ const FunctionSchema& schema) {
+ std::lock_guard<std::mutex> guard(lock);
+ if (schema_to_graphs.size() == 0) {
+ loadFunctions();
+ }
+ auto cache_it = cached_gradient_pairs.find(&schema);
+ if (cache_it != cached_gradient_pairs.end()) {
+ return cache_it->second;
+ } else {
+ auto schema_str = canonicalSchemaString(schema);
+ auto sym_script_it = schema_to_graphs.find(schema_str);
+ if (sym_script_it != schema_to_graphs.end()) {
+ cached_gradient_pairs.emplace_hint(
+ cache_it, &schema, sym_script_it->second);
+ return sym_script_it->second;
}
- return c10::nullopt;
}
+ return c10::nullopt;
+}
- bool hasGradientInfoForSchema(const FunctionSchema& schema) {
- return gradientInfoForSchema(schema).has_value();
- }
+bool hasGradientInfoForSchema(const FunctionSchema& schema) {
+ return gradientInfoForSchema(schema).has_value();
+}
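A sketch of the intended lookup from autodiff code (assumes the node's schema matches one of the definitions in `functions` above; `node->schema()` is the assumed accessor):

    if (auto pair = gradientInfoForSchema(node->schema())) {
      std::shared_ptr<Graph> fw = pair->forward;   // forward graph
      std::shared_ptr<Graph> bw = pair->backward;  // backward closure graph
    }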
-}}
+} // namespace jit
+} // namespace torch
#pragma once
-// This file is temporary until native_functions.yaml and derivatives.yaml are merged.
-// Ideally this should all go into native_functions.yaml
+// This file is temporary until native_functions.yaml and derivatives.yaml are
+// merged. Ideally this should all go into native_functions.yaml
#include <c10/util/Optional.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/module.h>
-#include <torch/csrc/jit/operator.h>
-namespace torch { namespace jit {
- struct GradientPair {
- std::shared_ptr<Graph> forward;
- std::shared_ptr<Graph> backward;
- };
+namespace torch {
+namespace jit {
+struct GradientPair {
+ std::shared_ptr<Graph> forward;
+ std::shared_ptr<Graph> backward;
+};
- TORCH_API c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema);
- TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
-}}
+TORCH_API c10::optional<GradientPair> gradientInfoForSchema(
+ const FunctionSchema& schema);
+TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
+} // namespace jit
+} // namespace torch
#pragma once
-#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/ir.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
struct SymbolicVariable {
SymbolicVariable() : v(nullptr) {}
- /* implicit */ SymbolicVariable(Value * v) : v(v) {}
+ /* implicit */ SymbolicVariable(Value* v) : v(v) {}
// we allow implicit conversions to/from Value since
// this type truly just provides more methods for value
operator Value*() const {
return v;
}
- static SymbolicVariable asNewInput(Graph & g, std::string name = "") {
+ static SymbolicVariable asNewInput(Graph& g, std::string name = "") {
return g.addInput(std::move(name));
}
- static SymbolicVariable asNewInput(Graph & g, TypePtr type) {
+ static SymbolicVariable asNewInput(Graph& g, TypePtr type) {
return g.addInput()->setType(std::move(type));
}
const std::vector<int64_t>& sizes() const {
void addAsOutput() const {
v->owningGraph()->registerOutput(v);
}
- static std::vector<SymbolicVariable> create(Symbol kind, ArrayRef<SymbolicVariable> inputs,
- int num_outputs = 1,
- Node** created_node = nullptr,
- Graph * g = nullptr) {
- if(g == nullptr) {
- g = inputs.at(0).value()->owningGraph();
- }
- Node* n = g->insertNode(g->create(kind, num_outputs));
- size_t max_depth = 0;
- ScopePtr s;
- for(auto n : inputs) {
- size_t d = n.value()->node()->scope()->getDepth();
- if(d > max_depth) {
- max_depth = d;
- s = n.value()->node()->scope();
- }
+ static std::vector<SymbolicVariable> create(
+ Symbol kind,
+ ArrayRef<SymbolicVariable> inputs,
+ int num_outputs = 1,
+ Node** created_node = nullptr,
+ Graph* g = nullptr) {
+ if (g == nullptr) {
+ g = inputs.at(0).value()->owningGraph();
+ }
+ Node* n = g->insertNode(g->create(kind, num_outputs));
+ size_t max_depth = 0;
+ ScopePtr s;
+ for (auto n : inputs) {
+ size_t d = n.value()->node()->scope()->getDepth();
+ if (d > max_depth) {
+ max_depth = d;
+ s = n.value()->node()->scope();
}
- n->setScope(s);
+ }
+ n->setScope(s);
- for(auto i : inputs) {
- n->addInput(i.value());
- }
- if(created_node) {
- *created_node = n;
- }
- std::vector<SymbolicVariable> out;
- for(auto v : n->outputs()) {
- out.emplace_back(v);
- }
- return out;
+ for (auto i : inputs) {
+ n->addInput(i.value());
+ }
+ if (created_node) {
+ *created_node = n;
+ }
+ std::vector<SymbolicVariable> out;
+ for (auto v : n->outputs()) {
+ out.emplace_back(v);
+ }
+ return out;
}
static bool isConstInt(at::Scalar s, int32_t i) {
// int32_t is safely convertible to both double and int64_t
- if(s.isFloatingPoint()) {
- return (double) i == s.toDouble();
+ if (s.isFloatingPoint()) {
+ return (double)i == s.toDouble();
} else {
- return (int64_t) i == s.toLong();
+ return (int64_t)i == s.toLong();
}
}
SymbolicVariable operator*(const SymbolicVariable rhs) const {
return (*this) * insertConstant(rhs);
}
SymbolicVariable operator>(at::Scalar rhs) const {
- return create(aten::gt, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::gt, {*this, insertConstant(rhs)})[0]
+ .typeLikeWithScalarType(*this, at::kByte);
}
SymbolicVariable operator>(const SymbolicVariable rhs) const {
- return create(aten::gt, {*this, rhs})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::gt, {*this, rhs})[0].typeLikeWithScalarType(
+ *this, at::kByte);
}
SymbolicVariable operator<(at::Scalar rhs) const {
- return create(aten::lt, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::lt, {*this, insertConstant(rhs)})[0]
+ .typeLikeWithScalarType(*this, at::kByte);
}
SymbolicVariable operator<(const SymbolicVariable rhs) const {
- return create(aten::lt, {*this, rhs})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::lt, {*this, rhs})[0].typeLikeWithScalarType(
+ *this, at::kByte);
}
SymbolicVariable operator>=(at::Scalar rhs) const {
- return create(aten::ge, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::ge, {*this, insertConstant(rhs)})[0]
+ .typeLikeWithScalarType(*this, at::kByte);
}
SymbolicVariable operator>=(const SymbolicVariable rhs) const {
- return create(aten::ge, {*this, rhs})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::ge, {*this, rhs})[0].typeLikeWithScalarType(
+ *this, at::kByte);
}
SymbolicVariable operator<=(at::Scalar rhs) const {
- return create(aten::le, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::le, {*this, insertConstant(rhs)})[0]
+ .typeLikeWithScalarType(*this, at::kByte);
}
SymbolicVariable operator<=(const SymbolicVariable rhs) const {
- return create(aten::le, {*this, rhs})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::le, {*this, rhs})[0].typeLikeWithScalarType(
+ *this, at::kByte);
}
SymbolicVariable operator==(at::Scalar rhs) const {
- return create(aten::eq, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::eq, {*this, insertConstant(rhs)})[0]
+ .typeLikeWithScalarType(*this, at::kByte);
}
SymbolicVariable operator!=(at::Scalar rhs) const {
- return create(aten::ne, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::ne, {*this, insertConstant(rhs)})[0]
+ .typeLikeWithScalarType(*this, at::kByte);
}
SymbolicVariable operator+(const SymbolicVariable rhs) const {
- return create(aten::add, {*this, rhs, insertConstant(1)})[0].typeLike(*this);
+ return create(aten::add, {*this, rhs, insertConstant(1)})[0].typeLike(
+ *this);
}
SymbolicVariable operator+(at::Scalar rhs) const {
return (*this) + insertConstant(rhs);
return create(aten::neg, {*this})[0].typeLike(*this);
}
SymbolicVariable operator-(const SymbolicVariable rhs) const {
- return create(aten::sub, {*this, rhs, insertConstant(1)})[0].typeLike(*this);
+ return create(aten::sub, {*this, rhs, insertConstant(1)})[0].typeLike(
+ *this);
}
SymbolicVariable operator/(at::Scalar rhs) const {
return create(aten::div, {*this, insertConstant(rhs)})[0].typeLike(*this);
}
SymbolicVariable operator%(at::Scalar rhs) const {
- return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(*this);
+ return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(
+ *this);
}
Value* size() const {
return v->owningGraph()->insert(aten::size, {v});
}
- SymbolicVariable sumToSize(Value * size) const {
+ SymbolicVariable sumToSize(Value* size) const {
return create(prim::SumToSize, {*this, size})[0];
}
- SymbolicVariable expand(Value * size) const {
+ SymbolicVariable expand(Value* size) const {
return v->owningGraph()->insert(aten::expand, {v, size});
}
SymbolicVariable isnan() const {
- return create(aten::ne, {*this, *this})[0].typeLikeWithScalarType(*this, at::kByte);
+ return create(aten::ne, {*this, *this})[0].typeLikeWithScalarType(
+ *this, at::kByte);
}
SymbolicVariable mm(const SymbolicVariable rhs) const {
return create(t("mm"), {*this, rhs})[0];
return create(aten::tanh, {*this})[0].typeLike(*this);
}
std::vector<SymbolicVariable> chunk(int64_t chunks, int dim) const {
- Node *chunk;
+ Node* chunk;
auto outputs = create(prim::ConstantChunk, {value()}, chunks, &chunk);
chunk->i_(attr::chunks, chunks)->i_(attr::dim, dim);
return outputs;
}
SymbolicVariable type_as(const SymbolicVariable rhs) const {
- return create(aten::type_as, {*this, rhs})[0].typeLikeWithRhsScalarType(*this, rhs);
+ return create(aten::type_as, {*this, rhs})[0].typeLikeWithRhsScalarType(
+ *this, rhs);
}
SymbolicVariable narrow(int dim, int64_t start, int64_t length) const {
- return create(t("narrow"), { *this, insertConstant(dim), insertConstant(start), insertConstant(length) }, 1)[0];
+ return create(
+ t("narrow"),
+ {*this,
+ insertConstant(dim),
+ insertConstant(start),
+ insertConstant(length)},
+ 1)[0];
}
static SymbolicVariable cat(ArrayRef<SymbolicVariable> inputs, Value* dim) {
- Graph *g = dim->owningGraph();
- Value * input_list;
- if (inputs.size() == 1 && inputs[0].value()->type()->isSubtypeOf(ListType::ofTensors())) {
+ Graph* g = dim->owningGraph();
+ Value* input_list;
+ if (inputs.size() == 1 &&
+ inputs[0].value()->type()->isSubtypeOf(ListType::ofTensors())) {
input_list = inputs[0];
} else {
- auto value_inputs = fmap(inputs, [](const SymbolicVariable & v) { return v.value(); });
- input_list = g->insertNode(g->createList(DynamicType::get(), value_inputs))->output();
+ auto value_inputs =
+ fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
+ input_list =
+ g->insertNode(g->createList(DynamicType::get(), value_inputs))
+ ->output();
}
return create(aten::cat, {input_list, dim})[0];
}
return SymbolicVariable::cat(inputs, inputs[0].insertConstant(dim));
}
static SymbolicVariable stack(ArrayRef<SymbolicVariable> inputs, Value* dim) {
- Graph *g = dim->owningGraph();
- auto value_inputs = fmap(inputs, [](const SymbolicVariable & v) { return v.value(); });
- Value *input_list = g->insertNode(g->createList(DynamicType::get(), value_inputs))->output();
+ Graph* g = dim->owningGraph();
+ auto value_inputs =
+ fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
+ Value* input_list =
+ g->insertNode(g->createList(DynamicType::get(), value_inputs))
+ ->output();
return create(aten::stack, {input_list, dim})[0];
}
static SymbolicVariable stack(ArrayRef<SymbolicVariable> inputs, int dim) {
JIT_ASSERT(inputs.size() > 0);
return SymbolicVariable::stack(inputs, inputs[0].insertConstant(dim));
}
- static std::vector<SymbolicVariable> broadcast_tensors(ArrayRef<SymbolicVariable> inputs) {
+ static std::vector<SymbolicVariable> broadcast_tensors(
+ ArrayRef<SymbolicVariable> inputs) {
JIT_ASSERT(inputs.size() > 0);
- Graph *g = inputs[0].value()->owningGraph();
- auto value_inputs = fmap(inputs, [](const SymbolicVariable & v) { return v.value(); });
- Value * input_list = g->insertNode(g->createList(DynamicType::get(), value_inputs))->output();
- Value * output_list = g->insert(aten::broadcast_tensors, {input_list});
- Node * unpack = g->insertNode(g->create(prim::ListUnpack, {output_list}, inputs.size()));
+ Graph* g = inputs[0].value()->owningGraph();
+ auto value_inputs =
+ fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
+ Value* input_list =
+ g->insertNode(g->createList(DynamicType::get(), value_inputs))
+ ->output();
+ Value* output_list = g->insert(aten::broadcast_tensors, {input_list});
+ Node* unpack = g->insertNode(
+ g->create(prim::ListUnpack, {output_list}, inputs.size()));
return fmap<SymbolicVariable>(unpack->outputs());
}
static SymbolicVariable zeros_like(const SymbolicVariable input) {
return create(t("sum"), {*this})[0];
}
SymbolicVariable sum(int dim, bool keepdim) const {
- return create(t("sum"), {*this, insertConstant(at::IntList{dim}), insertConstant(keepdim)})[0];
+ return create(
+ t("sum"),
+ {*this, insertConstant(at::IntList{dim}), insertConstant(keepdim)})[0];
}
SymbolicVariable squeeze(Value* dim) const {
return create(t("squeeze"), {*this, dim})[0];
return reshape(insertConstant(std::move(sizes)));
}
SymbolicVariable addmm(SymbolicVariable mat1, SymbolicVariable mat2) const {
- return create(aten::addmm, {*this, mat1, mat2, insertConstant(1), insertConstant(1)})[0];
+ return create(
+ aten::addmm,
+ {*this, mat1, mat2, insertConstant(1), insertConstant(1)})[0];
}
- Value * value() const {
+ Value* value() const {
return v;
}
-private:
- Value * insertConstant(IValue value) const {
+
+ private:
+ Value* insertConstant(IValue value) const {
return v->owningGraph()->insertConstant(std::move(value));
}
SymbolicVariable typeLike(SymbolicVariable other) const {
SymbolicVariable typeLikeWithScalarType(
SymbolicVariable other,
at::ScalarType type) const {
- if (auto other_type = other.v->type()->cast<CompleteTensorType>()){
+ if (auto other_type = other.v->type()->cast<CompleteTensorType>()) {
auto new_type = other_type->toScalarType(type)->contiguous();
v->setType(new_type);
}
SymbolicVariable rhs) const {
auto other_type = other.v->type()->cast<CompleteTensorType>();
auto rhs_type = rhs.v->type()->cast<CompleteTensorType>();
- if (other_type && rhs_type){
- auto new_type = other_type->toScalarType(rhs_type->scalarType())->contiguous();
+ if (other_type && rhs_type) {
+ auto new_type =
+ other_type->toScalarType(rhs_type->scalarType())->contiguous();
v->setType(new_type);
}
return *this;
}
- static Symbol a(const char * s_) {
+ static Symbol a(const char* s_) {
return Symbol::attr(s_);
}
- static Symbol t(const char * s_) {
+ static Symbol t(const char* s_) {
return Symbol::aten(s_);
}
- Value * v;
+ Value* v;
};
// shorter method so that toVar(v) + toVar(c) is short.
-static inline SymbolicVariable toVar(Value * v) {
+static inline SymbolicVariable toVar(Value* v) {
return {v};
}
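A sketch of expression building with the overloads above (hypothetical `Value*`s `a` and `b` from the same graph):

    SymbolicVariable va = toVar(a), vb = toVar(b);
    SymbolicVariable y = va * vb + at::Scalar(1);  // inserts the mul/add nodes
    y.addAsOutput();                               // registers y as a graph output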
-template<typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
+template <
+ typename T,
+ typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
inline SymbolicVariable operator+(T lhs, SymbolicVariable rhs) {
return rhs + at::Scalar(lhs);
}
return (lhs + (-rhs));
}
-}}
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/tracer.h>
-#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/engine.h>
+#include <torch/csrc/autograd/function.h>
+#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/remove_expands.h>
-#include <string>
-#include <sstream>
#include <memory>
+#include <sstream>
+#include <string>
-namespace torch { namespace jit { namespace tracer {
+namespace torch {
+namespace jit {
+namespace tracer {
////////////////////////////////////////////////////////////////////////////////
// Recording the traces
////////////////////////////////////////////////////////////////////////////////
namespace detail {
-template<typename T>
-void genericAddInput(Node *n, T value) {
- Value *v = n->owningGraph()->insertConstant(value);
+template <typename T>
+void genericAddInput(Node* n, T value) {
+ Value* v = n->owningGraph()->insertConstant(value);
recordSourceLocation(v->node());
n->addInput(v);
}
-template<typename T>
+template <typename T>
void badArgType(const T& v) {
AT_ERROR(
"Found an unsupported argument type in the JIT tracer: ",
} // namespace detail
-void setValueTrace(const IValue &v, Value *value) {
+void setValueTrace(const IValue& v, Value* value) {
if (v.isTensor()) {
auto var = v.toTensor();
JIT_ASSERT(var.defined());
} else if (v.isTensorList()) {
auto& outputs = v.toTensorList()->elements();
auto graph = getTracingState()->graph;
- Node * unpack_node = graph->appendNode(graph->create(prim::ListUnpack, {value}, outputs.size()));
+ Node* unpack_node = graph->appendNode(
+ graph->create(prim::ListUnpack, {value}, outputs.size()));
for (size_t i = 0; i < outputs.size(); ++i) {
setValueTrace(outputs[i], unpack_node->outputs()[i]);
}
} else if (v.isTuple()) {
auto& outputs = v.toTuple()->elements();
auto graph = getTracingState()->graph;
- Node * unpack_node = graph->appendNode(graph->create(prim::TupleUnpack, {value}, outputs.size()));
+ Node* unpack_node = graph->appendNode(
+ graph->create(prim::TupleUnpack, {value}, outputs.size()));
for (size_t i = 0; i < outputs.size(); ++i) {
setValueTrace(outputs[i], unpack_node->outputs()[i]);
}
}
}
-void addInputs(Node *n, const char * name, int64_t value) {
+void addInputs(Node* n, const char* name, int64_t value) {
using ArgumentStash = jit::tracer::ArgumentStash;
if (ArgumentStash::hasValue(name)) {
- Value * v = ArgumentStash::popValue(name);
+ Value* v = ArgumentStash::popValue(name);
n->addInput(v);
} else {
detail::genericAddInput(n, value);
}
}
-void addInputs(Node *n, const char * name, c10::optional<int64_t> value) {
- if(value) {
+void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
+ if (value) {
detail::genericAddInput(n, *value);
} else {
- Graph * g = n->owningGraph();
- Value* none =
- g->insertNode(g->createNone(IntType::get()))
- ->output();
+ Graph* g = n->owningGraph();
+ Value* none = g->insertNode(g->createNone(IntType::get()))->output();
n->addInput(none);
}
}
-void addInputs(Node *n, const char * name, bool value) { detail::genericAddInput(n, value); }
-void addInputs(Node *n, const char * name, double value) { detail::genericAddInput(n, value); }
-void addInputs(Node *n, const char * name, const at::Scalar& value) { detail::genericAddInput(n, value); }
-void addInputs(Node *n, const char * name, const c10::optional<at::Scalar>& value) {
- if(value) {
+void addInputs(Node* n, const char* name, bool value) {
+ detail::genericAddInput(n, value);
+}
+void addInputs(Node* n, const char* name, double value) {
+ detail::genericAddInput(n, value);
+}
+void addInputs(Node* n, const char* name, const at::Scalar& value) {
+ detail::genericAddInput(n, value);
+}
+void addInputs(
+ Node* n,
+ const char* name,
+ const c10::optional<at::Scalar>& value) {
+ if (value) {
detail::genericAddInput(n, *value);
} else {
- Graph * g = n->owningGraph();
- Value* none =
- g->insertNode(g->createNone(NumberType::get()))
- ->output();
+ Graph* g = n->owningGraph();
+ Value* none = g->insertNode(g->createNone(NumberType::get()))->output();
n->addInput(none);
}
}
-void addInputs(Node *n, const char * name, const std::string& value) { detail::genericAddInput(n, value); }
-void addInputs(Node *n, const char * name, const at::Tensor& value) { n->addInput(getValueTrace(value)); }
-void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { detail::badArgType(value); }
-void addInputs(Node *n, const char * name, at::Generator * value) {
+void addInputs(Node* n, const char* name, const std::string& value) {
+ detail::genericAddInput(n, value);
+}
+void addInputs(Node* n, const char* name, const at::Tensor& value) {
+ n->addInput(getValueTrace(value));
+}
+void addInputs(Node* n, const char* name, const at::SparseTensorRef& value) {
+ detail::badArgType(value);
+}
+void addInputs(Node* n, const char* name, at::Generator* value) {
if (value) {
detail::badArgType(value);
}
- Graph * g = n->owningGraph();
- Value * undef_gen = g->insertNode(g->createNone(GeneratorType::get()))->output();
+ Graph* g = n->owningGraph();
+ Value* undef_gen =
+ g->insertNode(g->createNone(GeneratorType::get()))->output();
n->addInput(undef_gen);
}
-void addInputs(Node *n, const char * name, at::Device value) {
+void addInputs(Node* n, const char* name, at::Device value) {
detail::genericAddInput(n, value);
}
-void addInputs(Node *n, const char * name, at::Layout value) {
+void addInputs(Node* n, const char* name, at::Layout value) {
detail::genericAddInput(n, static_cast<int64_t>(value));
}
-void addInputs(Node *n, const char * name, at::ScalarType value) {
+void addInputs(Node* n, const char* name, at::ScalarType value) {
detail::genericAddInput(n, static_cast<int64_t>(value));
}
-void addInputs(Node *n, const char * name, const c10::optional<at::ScalarType>& value) {
- if(value) {
+void addInputs(
+ Node* n,
+ const char* name,
+ const c10::optional<at::ScalarType>& value) {
+ if (value) {
detail::genericAddInput(n, static_cast<int64_t>(*value));
} else {
- Graph * g = n->owningGraph();
- Value* none =
- g->insertNode(g->createNone(IntType::get()))
- ->output();
+ Graph* g = n->owningGraph();
+ Value* none = g->insertNode(g->createNone(IntType::get()))->output();
n->addInput(none);
}
}
-void addInputs(Node *n, const char * name, at::TensorList value) {
- Graph *g = n->owningGraph();
- Node *list_node = g->appendNode(g->createList(DynamicType::get(), fmap(value, getValueTrace)));
+void addInputs(Node* n, const char* name, at::TensorList value) {
+ Graph* g = n->owningGraph();
+ Node* list_node = g->appendNode(
+ g->createList(DynamicType::get(), fmap(value, getValueTrace)));
n->addInput(list_node->output());
}
-void addInputs(Node* n, const char * name, const at::TensorOptions& options) {
- // [TensorOptions in script] - update this when you change how we schematize TensorOptions
+void addInputs(Node* n, const char* name, const at::TensorOptions& options) {
+ // [TensorOptions in script] - update this when you change how we schematize
+ // TensorOptions
addInputs(n, name, at::typeMetaToScalarType(options.dtype()));
addInputs(n, name, options.layout());
addInputs(n, name, options.device());
}
-void addInputs(Node *n, const char * name, at::IntList value) {
+void addInputs(Node* n, const char* name, at::IntList value) {
using ArgumentStash = jit::tracer::ArgumentStash;
- std::vector<Value*> info = ArgumentStash::hasIntList(name) ?
- ArgumentStash::popIntList(name) :
- ArgumentStash::IntListTrace(value.size());
+ std::vector<Value*> info = ArgumentStash::hasIntList(name)
+ ? ArgumentStash::popIntList(name)
+ : ArgumentStash::IntListTrace(value.size());
auto& g = getTracingState()->graph;
for (size_t i = 0; i < info.size(); ++i) {
- if (info[i] != nullptr) continue;
+ if (info[i] != nullptr)
+ continue;
info[i] = g->insertConstant(value[i]);
recordSourceLocation(info[i]->node());
}
for (jit::Value* v : info) {
if (*v->type() != *jit::IntType::get()) {
throw std::runtime_error(
- "Type mismatch in setposattr for IntList. Check that your program "
- "is valid without tracing, and please file a bug report if it is.");
+ "Type mismatch in setposattr for IntList. Check that your program "
+ "is valid without tracing, and please file a bug report if it is.");
}
}
- n->addInput(g->insertNode(g->createList(jit::IntType::get(), info))->output());
+ n->addInput(
+ g->insertNode(g->createList(jit::IntType::get(), info))->output());
}
-void addInputs(Node *n, const char * name, const ArrayRef<double>& value) {
+void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
AT_ERROR("Tracing float lists currently not supported!");
}
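A sketch of how generated tracing code strings these helpers together (hypothetical trace of a binary op; `self`, `other`, and `result` are assumed tensors, `graph` the tracing graph):

    Node* n = graph->create(aten::add);
    tracer::recordSourceLocation(n);
    tracer::addInputs(n, "self", self);
    tracer::addInputs(n, "other", other);
    graph->appendNode(n);
    tracer::addOutput(n, result);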
void addOutput(Node* node, const at::Tensor& output) {
- Value * value = node->addOutput();
+ Value* value = node->addOutput();
if (output.defined()) {
value->inferTypeFrom(output);
setValueTrace(autograd::as_variable_ref(output), value);
}
void addOutput(Node* node, const std::vector<at::Tensor>& outputs) {
- Value * value = node->addOutput()->setType(ListType::ofTensors());
- Graph * graph = node->owningGraph();
- Node * unpack_node = graph->appendNode(graph->create(prim::ListUnpack, {value}, outputs.size()));
+ Value* value = node->addOutput()->setType(ListType::ofTensors());
+ Graph* graph = node->owningGraph();
+ Node* unpack_node = graph->appendNode(
+ graph->create(prim::ListUnpack, {value}, outputs.size()));
for (size_t i = 0; i < outputs.size(); ++i) {
- Value * output_val = unpack_node->outputs()[i];
+ Value* output_val = unpack_node->outputs()[i];
output_val->inferTypeFrom(outputs[i]);
setValueTrace(outputs[i], output_val);
}
detail::tracing_state = std::move(state);
}
-TracingState::TracingState()
- : graph(new Graph()) {}
+TracingState::TracingState() : graph(new Graph()) {}
TracingState::~TracingState() = default;
autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
- auto & tracing_state = getTracingState();
- auto & graph = tracing_state->graph;
+ auto& tracing_state = getTracingState();
+ auto& graph = tracing_state->graph;
- auto size_var = autograd::make_variable(scalar_to_tensor(at::Scalar(var.size(dim))));
+ auto size_var =
+ autograd::make_variable(scalar_to_tensor(at::Scalar(var.size(dim))));
auto* value = getValueTrace(var);
- WithInsertPoint ipoint { graph->block() };
+ WithInsertPoint ipoint{graph->block()};
auto dim_val = graph->insertConstant(dim);
recordSourceLocation(dim_val->node());
auto* node = graph->insertNode(graph->create(aten::size, {value, dim_val}));
////////////////////////////////////////////////////////////////////////////////
thread_local ArgumentStash ArgumentStash::stash;
-void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, size_t idx, const Variable& var) {
+void ArgumentStash::stashIntListElem(
+ const std::string& arg_name,
+ size_t size,
+ size_t idx,
+ const Variable& var) {
// TODO: check type?
- if (!isTracing()) return;
- auto & list_trace = stash.intlists.emplace(arg_name, size).first->second;
+ if (!isTracing())
+ return;
+ auto& list_trace = stash.intlists.emplace(arg_name, size).first->second;
JIT_ASSERT(size == list_trace.size());
JIT_ASSERT(idx < list_trace.size());
JIT_ASSERT(list_trace[idx] == nullptr);
list_trace[idx] = prim;
}
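The intent of the stash, a hedged sketch (names hypothetical):

    // Tracing x.view(y.size(0), 4):
    //   stashIntListElem("size", /*size=*/2, /*idx=*/0, y_size0_var);
    // later, the IntList overload of addInputs above pops the stashed trace
    // and emits {trace-of-y.size(0), 4} instead of two baked-in constants.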
-void ArgumentStash::stashValue(const std::string& arg_name, size_t idx, const Variable& var, const TypePtr& type) {
- if (!isTracing()) return;
+void ArgumentStash::stashValue(
+ const std::string& arg_name,
+ size_t idx,
+ const Variable& var,
+ const TypePtr& type) {
+ if (!isTracing())
+ return;
Value* ten = getValueTrace(var);
WithInsertPoint guard(ten->node()->next());
auto& g = *ten->owningGraph();
if (type == IntType::get()) {
- ten = g.insert(prim::Int, { ten });
+ ten = g.insert(prim::Int, {ten});
} else if (type == FloatType::get()) {
- ten = g.insert(prim::Float, { ten });
+ ten = g.insert(prim::Float, {ten});
}
stash.values.emplace(arg_name, ten);
////////////////////////////////////////////////////////////////////////////////
// no python present so we just do not record source information
void defaultRecordSourceLocation(Node* n) {}
-std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(defaultRecordSourceLocation);
+std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(
+ defaultRecordSourceLocation);
void recordSourceLocation(Node* n) {
return record_source_location.load()(n);
}
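+// Bindings that do have source information (e.g. the Python frontend) swap in
+// their own recorder via setRecordSourceLocation(); using an atomic function
+// pointer keeps that swap safe from any thread.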
void defaultWarn(const std::string& str) {
AT_WARN(str);
}
-std::atomic<warn_fn_type> warn_callback { defaultWarn };
-
-const char * WARN_PYTHON_DATAFLOW =
- " might cause the trace to be incorrect. We can't record the data flow of "
- "Python values, so this value will be treated as a constant in the future. "
- "This means that the trace might not generalize to other inputs!";
-const char * WARN_CONSTRUCTOR =
- " results are registered as constants in the trace. You can safely ignore this "
- "warning if you use this function to create tensors out of constant variables "
- "that would be the same every time you call this function. In any other case, "
- "this might cause the trace to be incorrect.";
-const char * WARN_RESIZE =
- " can't be represented in the JIT at the moment, so we won't connect any uses of "
- "this value with its current trace. If you happen to use it again, it will show "
- "up as a constant in the graph.";
+std::atomic<warn_fn_type> warn_callback{defaultWarn};
+
+const char* WARN_PYTHON_DATAFLOW =
+ " might cause the trace to be incorrect. We can't record the data flow of "
+ "Python values, so this value will be treated as a constant in the future. "
+ "This means that the trace might not generalize to other inputs!";
+const char* WARN_CONSTRUCTOR =
+ " results are registered as constants in the trace. You can safely ignore this "
+ "warning if you use this function to create tensors out of constant variables "
+ "that would be the same every time you call this function. In any other case, "
+ "this might cause the trace to be incorrect.";
+const char* WARN_RESIZE =
+ " can't be represented in the JIT at the moment, so we won't connect any uses of "
+ "this value with its current trace. If you happen to use it again, it will show "
+ "up as a constant in the graph.";
// XXX: _kind can be nullptr
-void _do_warn(const char * _reason, const char * _kind) {
- std::string reason { _reason };
- std::string kind { _kind ? _kind : "" };
+void _do_warn(const char* _reason, const char* _kind) {
+ std::string reason{_reason};
+ std::string kind{_kind ? _kind : ""};
std::ostringstream s;
s << reason << kind;
warn_callback.load()(s.str());
warn_callback.store(fn);
}
-}}}
+} // namespace tracer
+} // namespace jit
+} // namespace torch
#pragma once
+#include <ATen/Backtrace.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/stack.h>
#include <torch/csrc/jit/tracing_state.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/utils/functional.h>
#include <torch/csrc/utils/functional.h>
#include <torch/csrc/utils/variadic.h>
-#include <torch/csrc/utils/variadic.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <ATen/Backtrace.h>
+#include <cstdint>
+#include <iostream>
#include <memory>
#include <mutex>
-#include <vector>
-#include <iostream>
-#include <cstdint>
#include <unordered_map>
+#include <vector>
-namespace torch { namespace jit { namespace tracer {
+namespace torch {
+namespace jit {
+namespace tracer {
using torch::autograd::Variable;
using variable_list = std::vector<Variable>;
TORCH_API void recordSourceLocation(Node* n);
TORCH_API void setRecordSourceLocation(void (*v)(Node*));
-// Having finished adding a new 'node' to the graph IR 'setValueTrace' associates
-// this node with an output variable, so that further operations involving this
-// variable know which node in the IR to reference.
+// Having finished adding a new 'node' to the graph IR, 'setValueTrace'
+// associates this node with an output variable, so that further operations
+// involving this variable know which node in the IR to reference.
TORCH_API void setValueTrace(const IValue& v, Value* value);
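+//
+// Illustrative use (hypothetical names): having traced `c = a + b` into an
+// aten::add node `n`, the binding code calls
+//
+//   setValueTrace(c, n->output());
+//
+// so that a later `c.sum()` wires its aten::sum node to `n->output()` rather
+// than embedding `c` as a constant.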
inline void delValueTrace(const Variable& var) {
std::shared_ptr<tracer::TracingState> state = getTracingState();
tracer::setTracingState(nullptr);
- return [state]() {
- tracer::setTracingState(state);
- };
+ return [state]() { tracer::setTracingState(state); };
}
// Given a variable 'var', return the 'node' which represents the instruction
// return Addmm.apply(output, self, matrix, 0, 1, True)
//
// Here, mm fakes up a dummy variable with uninitialized data to do an inplace
-// update on, but subsequently ignores it because the alpha scaling factor is zero.
-// This is one of the cases where a Variable can be created inside of a trace, and
-// if we treat it as a constant, everything will work out.
+// update on, but subsequently ignores it because the alpha scaling factor is
+// zero. This is one of the cases where a Variable can be created inside of a
+// trace, and if we treat it as a constant, everything will work out.
inline Value* getValueTrace(const Variable& var) {
- auto &state = getTracingState();
+ auto& state = getTracingState();
if (!var.defined()) {
- Node *n = state->graph->createUndefined();
+ Node* n = state->graph->createUndefined();
return state->graph->appendNode(n)->output();
}
- auto & value_map = getTracingState()->value_map;
+ auto& value_map = getTracingState()->value_map;
auto it = value_map.find(var);
if (it == value_map.end()) {
- Value *constant = state->graph->insertConstant(var.data());
+ Value* constant = state->graph->insertConstant(var.data());
recordSourceLocation(constant->node());
constant->inferTypeFrom(var.data());
it = value_map.emplace_hint(it, var, constant);
// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
// One might merge getValueTrace and getNestedValueTrace after checking that
// casting to IValue instead of Variable is OK
-inline Value* getNestedValueTrace(const IValue &v) {
- auto &state = getTracingState();
+inline Value* getNestedValueTrace(const IValue& v) {
+ auto& state = getTracingState();
if (v.isTensorList()) {
- return state->graph->insertNode(state->graph->createList(
- DynamicType::get(),
- fmap(v.toTensorListRef(), [](const IValue &val) {
- return getNestedValueTrace(val);
- })))->output();
+ return state->graph
+ ->insertNode(state->graph->createList(
+ DynamicType::get(),
+ fmap(
+ v.toTensorListRef(),
+ [](const IValue& val) { return getNestedValueTrace(val); })))
+ ->output();
} else if (v.isTuple()) {
- return state->graph->insertNode(state->graph->createTuple(
- fmap(v.toTuple()->elements(), [](const IValue &val) {
- return getNestedValueTrace(val);
- })))->output();
+ return state->graph
+ ->insertNode(state->graph->createTuple(fmap(
+ v.toTuple()->elements(),
+ [](const IValue& val) { return getNestedValueTrace(val); })))
+ ->output();
}
return getValueTrace(v.toTensor());
}
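+// E.g. a Tuple[Tensor, Tensor] argument becomes a prim::TupleConstruct node
+// whose operands are the traces of the two elements; tensor lists likewise
+// become prim::ListConstruct (an illustrative summary of the cases above).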
-
-inline Value* getOutputTrace(const std::shared_ptr<TracingState>& state, const Variable& var, size_t output_no) {
+inline Value* getOutputTrace(
+ const std::shared_ptr<TracingState>& state,
+ const Variable& var,
+ size_t output_no) {
if (!var.defined()) {
- Node *n = state->graph->createUndefined();
+ Node* n = state->graph->createUndefined();
return state->graph->appendNode(n)->output();
}
- auto & value_map = getTracingState()->value_map;
+ auto& value_map = getTracingState()->value_map;
auto it = value_map.find(var);
if (it == value_map.end()) {
std::ostringstream os;
auto state = std::make_shared<TracingState>();
setTracingState(state);
// XXX: this function mutates input
- const std::function<IValue(IValue, TypePtr, Value*)> add_input = [&](IValue input, TypePtr type, Value* value) -> IValue {
+ const std::function<IValue(IValue, TypePtr, Value*)> add_input =
+ [&](IValue input, TypePtr type, Value* value) -> IValue {
value->setType(type);
if (type->isSubtypeOf(DynamicType::get())) {
auto input_tensor = input.toTensor();
state->value_map[input_tensor] = value;
return input_tensor;
} else if (auto tuple_type = type->cast<TupleType>()) {
- auto unpack_node = state->graph->insertNode(state->graph->createTupleUnpack(value));
+ auto unpack_node =
+ state->graph->insertNode(state->graph->createTupleUnpack(value));
auto elem_values = unpack_node->outputs();
auto elem_types = tuple_type->elements();
Stack elems = input.toTuple()->elements();
size_t num_elems = elems.size();
- AT_ASSERT(elem_values.size() == num_elems && elem_types.size() == num_elems);
+ AT_ASSERT(
+ elem_values.size() == num_elems && elem_types.size() == num_elems);
for (size_t i = 0; i < num_elems; ++i) {
elems[i] = add_input(elems[i], elem_types[i], elem_values[i]);
}
return Tuple::create(std::move(elems));
} else {
- AT_ERROR("Only tensors or tuples of tensors can be inputs to traced functions");
+ AT_ERROR(
+ "Only tensors or tuples of tensors can be inputs to traced functions");
}
};
for (IValue& input : inputs) {
- input = add_input(input, incompleteInferTypeFrom(input), state->graph->addInput());
+ input = add_input(
+ input, incompleteInferTypeFrom(input), state->graph->addInput());
}
return std::make_pair(state, inputs);
}
// are the variables whose values will be computed upon subsequent
// invocations of the trace.
inline void exit(const Stack& outputs) {
- auto & state = getTracingState();
+ auto& state = getTracingState();
size_t i = 0;
- std::function<Value*(const IValue&)> reduce_ivalue = [&](const IValue& iv) -> Value* {
+ std::function<Value*(const IValue&)> reduce_ivalue =
+ [&](const IValue& iv) -> Value* {
if (iv.isTensor()) {
return getOutputTrace(state, iv.toTensor(), i);
} else if (iv.isTuple()) {
- const auto & elems = iv.toTuple()->elements();
+ const auto& elems = iv.toTuple()->elements();
auto tuple_node = state->graph->createTuple(fmap(elems, reduce_ivalue));
state->graph->appendNode(tuple_node);
return tuple_node->output();
} else {
- AT_ERROR("Only tensors or tuples of tensors can be output from traced functions");
+ AT_ERROR(
+ "Only tensors or tuples of tensors can be output from traced functions");
}
};
for (auto& output : outputs) {
// NB: these serve both as intermediate steps in addInputs below,
// as well as the overloads that terminate template recursion
-TORCH_API void addInputs(Node *n, const char * name, int64_t value);
-TORCH_API void addInputs(Node *n, const char * name, c10::optional<int64_t> value);
-TORCH_API void addInputs(Node *n, const char * name, bool value);
-TORCH_API void addInputs(Node *n, const char * name, double value);
-TORCH_API void addInputs(Node *n, const char * name, const at::Scalar& value);
-TORCH_API void addInputs(Node *n, const char * name, const c10::optional<at::Scalar>& value);
-TORCH_API void addInputs(Node *n, const char * name, const at::Tensor& value);
-TORCH_API void addInputs(Node *n, const char * name, at::IntList value);
-TORCH_API void addInputs(Node *n, const char * name, at::TensorList value);
-TORCH_API void addInputs(Node *n, const char * name, const ArrayRef<double>& value);
-TORCH_API void addInputs(Node *n, const char * name, const std::string& value);
-TORCH_API void addInputs(Node *n, const char * name, const at::SparseTensorRef& value);
-TORCH_API void addInputs(Node *n, const char * name, const at::TensorOptions& value);
-TORCH_API void addInputs(Node *n, const char * name, at::Device value);
-TORCH_API void addInputs(Node *n, const char * name, at::Layout value);
-TORCH_API void addInputs(Node *n, const char * name, at::ScalarType value);
-TORCH_API void addInputs(Node *n, const char * name, const c10::optional<at::ScalarType>& value);
-TORCH_API void addInputs(Node *n, const char * name, at::Generator * value);
+TORCH_API void addInputs(Node* n, const char* name, int64_t value);
+TORCH_API void addInputs(
+ Node* n,
+ const char* name,
+ c10::optional<int64_t> value);
+TORCH_API void addInputs(Node* n, const char* name, bool value);
+TORCH_API void addInputs(Node* n, const char* name, double value);
+TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value);
+TORCH_API void addInputs(
+ Node* n,
+ const char* name,
+ const c10::optional<at::Scalar>& value);
+TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value);
+TORCH_API void addInputs(Node* n, const char* name, at::IntList value);
+TORCH_API void addInputs(Node* n, const char* name, at::TensorList value);
+TORCH_API void addInputs(
+ Node* n,
+ const char* name,
+ const ArrayRef<double>& value);
+TORCH_API void addInputs(Node* n, const char* name, const std::string& value);
+TORCH_API void addInputs(
+ Node* n,
+ const char* name,
+ const at::SparseTensorRef& value);
+TORCH_API void addInputs(
+ Node* n,
+ const char* name,
+ const at::TensorOptions& value);
+TORCH_API void addInputs(Node* n, const char* name, at::Device value);
+TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
+TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);
+TORCH_API void addInputs(
+ Node* n,
+ const char* name,
+ const c10::optional<at::ScalarType>& value);
+TORCH_API void addInputs(Node* n, const char* name, at::Generator* value);
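+// Sketch of how generated tracing code drives these overloads (hypothetical
+// call site; real generated code also handles scopes and optional outputs):
+//
+//   Node* n = graph->create(aten::addmm, /*num_outputs=*/0);
+//   recordSourceLocation(n);
+//   addInputs(n, "self", self);
+//   addInputs(n, "mat1", mat1);
+//   addInputs(n, "mat2", mat2);
+//   graph->insertNode(n);
+//   addOutput(n, result);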
-template<size_t N>
-void addInputs(Node *n, const char * name, std::array<bool, N> value) {
- throw std::runtime_error("Found an unsupported argument type in the JIT tracer. File a bug report.");
+template <size_t N>
+void addInputs(Node* n, const char* name, std::array<bool, N> value) {
+ throw std::runtime_error(
+ "Found an unsupported argument type in the JIT tracer. File a bug report.");
}
-inline void ensureUniqueIfOutOfPlaced(const char * name, const at::Tensor& tensor) {
+inline void ensureUniqueIfOutOfPlaced(
+ const char* name,
+ const at::Tensor& tensor) {
auto& state = getTracingState();
if (state && state->force_outplace == false) {
// If we're not converting in-place ops to out-of-place, this check is
std::stringstream ss;
ss << "There are " << aliases
<< " live references to the data region being modified when tracing in-place operator "
- << name << ". This might cause the trace to be incorrect, because all other views "
+ << name
+ << ". This might cause the trace to be incorrect, because all other views "
<< "that also reference this data will not not reflect this change in the trace! "
<< "On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. "
<< "are outputs of torch.split), this might still be safe.";
}
}
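+// E.g. (illustrative) tracing `b = a[0]; a.add_(1)`: `b` shares a's data
+// region, so the warning above fires because the traced add_ will not be
+// reflected in `b`'s value in the graph.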
-
template <
typename T,
typename = torch::enable_if_t<
TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);
-TORCH_API autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim);
+TORCH_API autograd::Variable getSizeOf(
+ const autograd::Variable& var,
+ int64_t dim);
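+// E.g. getSizeOf(x, 0) while tracing emits `aten::size(x, 0)` into the graph
+// and returns the result as a Variable, so the size remains a traced value
+// instead of being baked in as a constant (illustrative summary).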
-}}} // namespace torch::jit::tracer
+} // namespace tracer
+} // namespace jit
+} // namespace torch
#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/stack.h>
#include <torch/csrc/jit/type.h>
#include <torch/csrc/utils/functional.h>
-#include <torch/csrc/utils/functional.h>
-#include <torch/csrc/utils/variadic.h>
#include <torch/csrc/utils/variadic.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/Backtrace.h>
+#include <cstdint>
+#include <iostream>
#include <memory>
#include <mutex>
-#include <vector>
-#include <iostream>
-#include <cstdint>
#include <unordered_map>
+#include <vector>
-namespace torch { namespace jit { namespace tracer {
+namespace torch {
+namespace jit {
+namespace tracer {
using torch::autograd::Variable;
using variable_list = std::vector<Variable>;
-struct TORCH_API TracingState : public std::enable_shared_from_this<TracingState> {
+struct TORCH_API TracingState
+ : public std::enable_shared_from_this<TracingState> {
TracingState();
~TracingState();
}
};
- std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq> value_map;
+ std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
+ value_map;
std::shared_ptr<Graph> graph;
bool warn = true;
bool force_outplace = false;
std::function<std::string(const Variable& var)> lookup_var_name_fn =
- [](const Variable& var) {return "";};
+ [](const Variable& var) { return ""; };
};
-
// This is meant to be used as a thread local place, where we can store extra
// info that gets lost when we call into ATen from Python bindings. One example
// for when this happens is when we get an IntList argument with e.g. sizes for
// information. To prevent this, we temporarily stash it in here.
struct ArgumentStash {
struct IntListTrace : std::vector<Value*> {
- IntListTrace(int size)
- : std::vector<Value*>(size, nullptr) {}
+ IntListTrace(int size) : std::vector<Value*>(size, nullptr) {}
};
static bool empty() {
return stash.intlists.empty();
}
- TORCH_API static void stashIntListElem(const std::string& arg_name,
- size_t size,
- size_t idx,
- const Variable& var);
+ TORCH_API static void stashIntListElem(
+ const std::string& arg_name,
+ size_t size,
+ size_t idx,
+ const Variable& var);
static bool hasIntList(const std::string& arg_name) {
return stash.intlists.count(arg_name) > 0;
// Value stashing: Use these methods to stash arguments which correspond
// to regular Value*'s in the graph. i.e. they don't require special
// handling like in the case of IntLists
- TORCH_API static void stashValue(const std::string& arg_name,
- size_t idx,
- const Variable& var,
- const TypePtr& type=nullptr);
+ TORCH_API static void stashValue(
+ const std::string& arg_name,
+ size_t idx,
+ const Variable& var,
+ const TypePtr& type = nullptr);
static bool hasValue(const std::string& arg_name) {
return stash.values.count(arg_name) > 0;
return info;
}
-private:
+ private:
static thread_local ArgumentStash stash;
std::unordered_map<std::string, IntListTrace> intlists;
std::unordered_map<std::string, Value*> values;
};
-// Retrieve or set the current tracing state. Returns a nullptr if tracing is disabled.
+// Retrieve or set the current tracing state. Returns a nullptr if tracing is
+// disabled.
TORCH_API const std::shared_ptr<TracingState>& getTracingState();
TORCH_API void setTracingState(std::shared_ptr<TracingState> state);
}
using warn_fn_type = void (*)(const std::string& msg);
-TORCH_API extern const char * WARN_PYTHON_DATAFLOW;
-TORCH_API extern const char * WARN_CONSTRUCTOR;
-TORCH_API extern const char * WARN_RESIZE;
-TORCH_API void _do_warn(const char * _reason, const char * _kind);
-inline void warn(const char * _reason, const char * _kind=nullptr) {
+TORCH_API extern const char* WARN_PYTHON_DATAFLOW;
+TORCH_API extern const char* WARN_CONSTRUCTOR;
+TORCH_API extern const char* WARN_RESIZE;
+TORCH_API void _do_warn(const char* _reason, const char* _kind);
+inline void warn(const char* _reason, const char* _kind = nullptr) {
if (const auto& state = getTracingState()) {
- if (!state->warn) return;
+ if (!state->warn)
+ return;
_do_warn(_reason, _kind);
}
}
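+// Typical call shape (illustrative): warn("x.data()", WARN_PYTHON_DATAFLOW)
+// emits "x.data() might cause the trace to be incorrect. ...", since _do_warn
+// simply concatenates the reason and the kind.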
TORCH_API void setWarn(warn_fn_type fn);
struct TORCH_API NoWarn {
- NoWarn(): state(getTracingState()) {
+ NoWarn() : state(getTracingState()) {
if (state) {
prev = state->warn;
state->warn = false;
bool prev;
};
-}}} // namespace torch::jit::tracer
+} // namespace tracer
+} // namespace jit
+} // namespace torch
#include <ATen/core/jit_type.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
#define C10_USING(T) using ::c10::T;
- C10_FORALL_TYPES(C10_USING)
+C10_FORALL_TYPES(C10_USING)
#undef C10_USING
#define C10_USING(T) using ::c10::T##Ptr;
- C10_FORALL_TYPES(C10_USING)
+C10_FORALL_TYPES(C10_USING)
#undef C10_USING
using ::c10::Type;
-using ::c10::TypePtr;
using ::c10::TypeEnv;
+using ::c10::TypePtr;
using ::c10::getTypePtr;
-using ::c10::TypeKind;
using ::c10::MatchTypeReturn;
+using ::c10::TypeKind;
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
#pragma once
#include <ATen/ATen.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
// a wrapper to mark places where we expect all the at::Tensors to be
// variables
struct variable_tensor_list : public std::vector<at::Tensor> {
variable_tensor_list() = default;
- template<class InputIt>
+ template <class InputIt>
variable_tensor_list(InputIt first, InputIt last)
- : std::vector<at::Tensor>(first, last) {}
- explicit variable_tensor_list(std::vector<at::Tensor> && tensor)
- : std::vector<at::Tensor>(std::move(tensor)) {}
+ : std::vector<at::Tensor>(first, last) {}
+ explicit variable_tensor_list(std::vector<at::Tensor>&& tensor)
+ : std::vector<at::Tensor>(std::move(tensor)) {}
};
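+// Illustrative: variable_tensor_list outs(std::move(tensors)); the explicit
+// constructor keeps plain std::vector<at::Tensor> from converting silently.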
-}}
+} // namespace jit
+} // namespace torch