return grad_self, grad_mat2
return torch.bmm(self, mat2), backward
+ def AD_mat_transpose(mat):
+ dim = mat.dim()
+ if dim == 1:
+ out = mat
+ elif dim == 2:
+ out = mat.t()
+ else:
+ dims = range(dim)
+ dims[-1] = dim - 2
+ dims[-2] = dim - 1
+ out = mat.permute(dims)
+ return out
+
+ def AD_matmul_size(mat1, mat2,
+ out_size: List[int]):
+ dim1 = mat1.dim()
+ dim2 = mat2.dim()
+ dim_out = len(out_size)
+ if dim1 == 0 or dim2 == 0:
+ out = mat1 * mat2
+ elif dim1 + dim2 == dim_out:
+ if dim2 == 1:
+ target_dim2 = 0
+ else:
+ target_dim2 = -2
+ out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
+ elif dim_out == dim1 - dim2:
+ out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
+ elif dim_out == dim2 - dim1:
+ out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
+ else:
+ out = torch.matmul(mat1, mat2)
+ return out
+
+ def matmul(self, other):
+ def backward(grad_output):
+ self_size = self.size()
+ other_size = other.size()
+ grad_self = AD_matmul_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
+ grad_other = AD_matmul_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
+ return grad_self, grad_other
+
+ return torch.matmul(self, other), backward
)",
R"(
def _dim_arange(like,
return output, backward
+ def dropout(input,
+ p: float,
+ train: bool):
+ mask = torch.empty_like(input)
+ mask.bernoulli_(1 - p)
+ res = mask * input / (1.0 - p)
+
+ def backward(grad_output):
+ grad_input = grad_output * mask / (1.0 - p)
+ return grad_input, None, None
+ return res, backward
+
def embedding(weight,
indices,
padding_idx: int,
// This map is a workaround to cache compiled gradient_pairs. Ideally this graph
// should be compiled only once and saved in Operator structure.
// This should be done along with merging into native_functions.yaml.
-std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
+std::unordered_map<const FunctionSchema *, GradientPair> cached_gradient_pairs;
} // anonymous namespace
-std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
- AT_CHECK(
- closure->node()->kind() == prim::TupleConstruct,
- "closure must be a literal tuple construct");
- Value* fn = closure->node()->inputs().at(0);
- Value* context = closure->node()->inputs().at(1);
+std::pair<std::shared_ptr<Graph>, Value *> extractClosure(Value *closure) {
+ AT_CHECK(closure->node()->kind() == prim::TupleConstruct,
+ "closure must be a literal tuple construct");
+ Value *fn = closure->node()->inputs().at(0);
+ Value *context = closure->node()->inputs().at(1);
- AT_CHECK(
- fn->node()->kind() == prim::Function,
- "closure tuple must contain a prim::Function");
+ AT_CHECK(fn->node()->kind() == prim::Function,
+ "closure tuple must contain a prim::Function");
return std::make_pair(fn->node()->g(attr::Subgraph), context);
}
-Argument originalReturnType(const TupleTypePtr& tup) {
+Argument originalReturnType(const TupleTypePtr &tup) {
AT_CHECK(tup->elements().size() > 1);
if (tup->elements().size() == 2)
return Argument("", tup->elements().at(0));
// overloaded functions of `func`.
// Remove the suffix before adding the schema string to map
// schema_to_graphs.
-std::string overloadedSchemaString(const FunctionSchema& schema) {
- const auto& schema_name = schema.name();
+std::string overloadedSchemaString(const FunctionSchema &schema) {
+ const auto &schema_name = schema.name();
auto pos = schema_name.find_last_of('_');
auto schema_name_suffix = schema_name.substr(pos + 1);
std::string schema_string = canonicalSchemaString(schema);
if (!schema_name_suffix.empty() &&
schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
- schema_string.replace(
- schema_string.find(schema_name),
- schema_name.length(),
- schema_name.substr(0, pos));
+ schema_string.replace(schema_string.find(schema_name), schema_name.length(),
+ schema_name.substr(0, pos));
}
return schema_string;
}
-bool isHelperFunction(const std::string& method_name) {
+bool isHelperFunction(const std::string &method_name) {
std::string helper_prefix = "AD_";
return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
}
-void loadModule(const std::shared_ptr<script::Module>& module) {
- for (const auto& method_ : module->get_methods()) {
+void loadModule(const std::shared_ptr<script::Module> &module) {
+ for (const auto &method_ : module->get_methods()) {
if (isHelperFunction(method_.key()))
continue;
- const auto& method = method_.value();
+ const auto &method = method_.value();
GradientPair pair;
pair.forward = method->graph();
// lookup the backward function
- Node* forward_tuple = pair.forward->outputs().at(0)->node();
+ Node *forward_tuple = pair.forward->outputs().at(0)->node();
if (forward_tuple->kind() != prim::TupleConstruct) {
throw script::ErrorReport(forward_tuple->getSourceLocation())
<< "gradient must return literal a tuple";
}
- Value* context;
+ Value *context;
std::tie(pair.backward, context) =
extractClosure(forward_tuple->inputs().back());
// return original, backward
// -----
// return original, context_tuple
- std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
+ std::vector<Value *> new_inputs = forward_tuple->inputs().vec();
new_inputs.back() = context;
- Value* new_tuple =
+ Value *new_tuple =
pair.forward->appendNode(pair.forward->createTuple(new_inputs))
->output();
pair.forward->eraseOutput(0);
forward_tuple->destroy();
// derive schema from original function's schema:
- const FunctionSchema& loaded_schema = method->getSchema();
+ const FunctionSchema &loaded_schema = method->getSchema();
FunctionSchema actual_schema(
- Symbol::aten(loaded_schema.name()),
- loaded_schema.arguments(),
+ Symbol::aten(loaded_schema.name()), loaded_schema.arguments(),
{originalReturnType(new_tuple->type()->expect<TupleType>())});
// modify canonical string for function overloading
}
void loadFunctions() {
- for (const std::string& str : functions) {
+ for (const std::string &str : functions) {
auto cu = std::make_shared<script::Module>();
script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
loadModule(cu);
}
}
-c10::optional<GradientPair> gradientInfoForSchema(
- const FunctionSchema& schema) {
+c10::optional<GradientPair>
+gradientInfoForSchema(const FunctionSchema &schema) {
std::lock_guard<std::mutex> guard(lock);
if (schema_to_graphs.size() == 0) {
loadFunctions();
auto sym_script_it = schema_to_graphs.find(schema_str);
if (sym_script_it != schema_to_graphs.end()) {
- cached_gradient_pairs.emplace_hint(
- cache_it, &schema, sym_script_it->second);
+ cached_gradient_pairs.emplace_hint(cache_it, &schema,
+ sym_script_it->second);
return sym_script_it->second;
}
}
return c10::nullopt;
}
-bool hasGradientInfoForSchema(const FunctionSchema& schema) {
+bool hasGradientInfoForSchema(const FunctionSchema &schema) {
return gradientInfoForSchema(schema).has_value();
}