convolution/matmul/dropout (#17523)
authorAiling Zhang <ailzhang@fb.com>
Thu, 28 Feb 2019 05:36:37 +0000 (21:36 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Feb 2019 05:41:59 +0000 (21:41 -0800)
Summary:
* Add AD formulas for _convolution, matmul, and dropout
* Add a prim::range builtin; fixes #17483
Example:
```
dim = 3
x = range(dim)
```
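For context, a minimal usage sketch (assuming this change is applied; `make_range` is an illustrative name, not part of the PR): inside a `torch.jit.script` function, `range(n)` resolves to the new prim::range builtin and evaluates to the int list [0, 1, ..., n - 1].
```python
import torch

# Illustrative only: `make_range` is a hypothetical function name.
# With prim::range, `range(dim)` inside TorchScript produces a plain
# int list, which the AD helpers below use to build permutation indices.
@torch.jit.script
def make_range(dim: int):
    x = range(dim)
    return x

print(make_range(3))  # expected: [0, 1, 2]
```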
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17523

Differential Revision: D14254002

Pulled By: ailzhang

fbshipit-source-id: ba60d77b047db347929b72beca2623fb26aec957

aten/src/ATen/core/interned_strings.h
aten/src/ATen/native/Dropout.cpp
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/symbolic_script.cpp

aten/src/ATen/core/interned_strings.h
index ce88112..89fda46 100644 (file)
@@ -74,6 +74,7 @@ namespace c10 {
   _(prim, MMBatchSide)             \
   _(prim, min)                     \
   _(prim, max)                     \
+  _(prim, range)                   \
   _(aten, _grad_sum_to_size)       \
   _(aten, _ncf_unsqueeze)          \
   _(aten, warn)                    \
aten/src/ATen/native/Dropout.cpp
index dedfbc4..52181f1 100644 (file)
@@ -5,6 +5,9 @@ namespace at { namespace native {
 
 namespace {
 
+template<bool inplace>
+using Ctype = typename std::conditional<inplace, Tensor&, Tensor>::type;
+
 Tensor make_feature_noise(const Tensor& input) {
   auto input_sizes = input.sizes();
   AT_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
@@ -36,8 +39,7 @@ Tensor multiply(const Tensor& input, const Tensor& noise) {
 }
 
 template<bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
-typename std::conditional<inplace, Tensor&, Tensor>::type
-_dropout_impl(T& input, double p, bool train) {
+Ctype<inplace> _dropout_impl(T& input, double p, bool train) {
   AT_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);
   if (p == 0 || !train) {
     return input;
@@ -66,10 +68,9 @@ _dropout_impl(T& input, double p, bool train) {
   }
 }
 
-#define ALIAS_SPECIALIZATION(ALIAS_NAME, IS_FEATURE, IS_ALPHA)                 \
-template <bool inplace, typename... Args>                                      \
-typename std::conditional<inplace, Tensor&, Tensor>::type                      \
-ALIAS_NAME(Args&&... args) {                                                   \
+#define ALIAS_SPECIALIZATION(ALIAS_NAME, IS_FEATURE, IS_ALPHA)                      \
+template <bool inplace, typename... Args>                                           \
+Ctype<inplace> ALIAS_NAME(Args&&... args) {                                         \
   return _dropout_impl<IS_FEATURE, IS_ALPHA, inplace>(std::forward<Args>(args)...); \
 }
 
torch/csrc/jit/autodiff.cpp
index 2d4f41a..a92f982 100644 (file)
@@ -128,6 +128,12 @@ bool isDifferentiable(Node* n) {
   if (differentiable_ops.find(n))
     return true;
 
+  if (n->matches(
+          "aten::dropout(Tensor input, float p, bool train) -> Tensor")) {
+    auto train = n->get<bool>(attr::train).value();
+    return train;
+  }
+
   auto schema = n->maybeSchema();
   if (schema && hasGradientInfoForSchema(*schema)) {
     return true;
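
A quick eager-mode illustration of why the guard above checks `train` (not part of the patch): in eval mode dropout is the identity, so only the training path needs a symbolic gradient formula.
```python
import torch
import torch.nn.functional as F

x = torch.randn(4, requires_grad=True)

# Eval mode: dropout passes the input through unchanged, so there is
# nothing extra to differentiate.
assert torch.equal(F.dropout(x, p=0.5, training=False), x)

# Train mode: surviving elements are rescaled by 1 / (1 - p); this is
# the case covered by the new symbolic formula.
y = F.dropout(x, p=0.5, training=True)
```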
torch/csrc/jit/register_prim_ops.cpp
index acc94c4..c1f425d 100644 (file)
@@ -122,6 +122,18 @@ RegisterOperators reg({
           };
         }),
     Operator(
+        "prim::range(int n) -> int[]",
+        [](Stack& stack) {
+          int64_t n;
+          pop(stack, n);
+          std::vector<int64_t> elems(n);
+          for (int i = 0; i < n; i++) {
+            elems[i] = i;
+          }
+          push(stack, jit::IntList::create(elems));
+          return 0;
+        }),
+    Operator(
         "prim::Bool(Tensor a) -> bool",
         [](Stack& stack) {
           at::Tensor a;
@@ -1113,7 +1125,7 @@ int listSelect<Shared<BoolList>>(Stack& stack) {
   pop(stack, list, idx);
 
   auto element = getBoolItem(list->elements(), idx);
-  push(stack, std::move(element));
+  push(stack, element);
   return 0;
 }
 
torch/csrc/jit/script/compiler.cpp
index 6c56d6a..13ad911 100644 (file)
@@ -401,6 +401,7 @@ struct Environment {
           {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
           {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
           {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
+          {"range", std::make_shared<BuiltinFunction>(prim::range, at::nullopt)},
       };
       auto it = globals.find(ident);
       if (it != globals.end())
torch/csrc/jit/symbolic_script.cpp
index b309ccb..d6ee9fb 100644 (file)
@@ -331,6 +331,49 @@ const std::vector<std::string> functions = {
                 return grad_self, grad_mat2
             return torch.bmm(self, mat2), backward
 
+        def AD_mat_transpose(mat):
+            dim = mat.dim()
+            if dim == 1:
+                out = mat
+            elif dim == 2:
+                out = mat.t()
+            else:
+                dims = range(dim)
+                dims[-1] = dim - 2
+                dims[-2] = dim - 1
+                out = mat.permute(dims)
+            return out
+
+        def AD_matmul_size(mat1, mat2,
+                           out_size: List[int]):
+            dim1 = mat1.dim()
+            dim2 = mat2.dim()
+            dim_out = len(out_size)
+            if dim1 == 0 or dim2 == 0:
+                out = mat1 * mat2
+            elif dim1 + dim2 == dim_out:
+                if dim2 == 1:
+                    target_dim2 = 0
+                else:
+                    target_dim2 = -2
+                out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
+            elif dim_out == dim1 - dim2:
+                out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
+            elif dim_out == dim2 - dim1:
+                out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
+            else:
+                out = torch.matmul(mat1, mat2)
+            return out
+
+        def matmul(self, other):
+            def backward(grad_output):
+                self_size = self.size()
+                other_size = other.size()
+                grad_self = AD_matmul_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
+                grad_other = AD_matmul_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
+                return grad_self, grad_other
+
+            return torch.matmul(self, other), backward
     )",
     R"(
         def _dim_arange(like,
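
The matmul formula above follows the usual rule grad_self = grad_output @ other^T and grad_other = self^T @ grad_output, with AD_matmul_size and _grad_sum_to_size handling batching and broadcasting. A sanity-check sketch for the plain 2-D case (illustrative only, not part of the patch):
```python
import torch

a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)
out = torch.matmul(a, b)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# 2-D specialization of the symbolic formula:
# grad_a = grad_out @ b.t(), grad_b = a.t() @ grad_out
assert torch.allclose(a.grad, grad_out.matmul(b.detach().t()))
assert torch.allclose(b.grad, a.detach().t().matmul(grad_out))
```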
@@ -604,6 +647,18 @@ const std::vector<std::string> functions = {
 
             return output, backward
 
+        def dropout(input,
+                    p: float,
+                    train: bool):
+            mask = torch.empty_like(input)
+            mask.bernoulli_(1 - p)
+            res = mask * input / (1.0 - p)
+
+            def backward(grad_output):
+                grad_input = grad_output * mask / (1.0 - p)
+                return grad_input, None, None
+            return res, backward
+
         def embedding(weight,
                       indices,
                       padding_idx: int,
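
The dropout formula above is inverted dropout: a Bernoulli(1 - p) mask scales both the forward output and the incoming gradient by 1 / (1 - p). An eager-mode sketch of the same computation (illustrative only, not part of the patch):
```python
import torch

p = 0.3
x = torch.randn(5, requires_grad=True)

# Same recipe as the symbolic forward: keep each element with
# probability 1 - p and rescale the survivors by 1 / (1 - p).
mask = torch.empty_like(x).bernoulli_(1 - p)
out = mask * x / (1.0 - p)
out.sum().backward()

# The backward closure reuses the saved mask in exactly this way.
assert torch.allclose(x.grad, mask / (1.0 - p))
```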
@@ -730,23 +785,21 @@ std::unordered_map<std::string, GradientPair> schema_to_graphs;
 // This map is a workaround to cache compiled gradient_pairs. Ideally this graph
 // should be compiled only once and saved in Operator structure.
 // This should be done along with merging into native_functions.yaml.
-std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
+std::unordered_map<const FunctionSchema *, GradientPair> cached_gradient_pairs;
 } // anonymous namespace
 
-std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
-  AT_CHECK(
-      closure->node()->kind() == prim::TupleConstruct,
-      "closure must be a literal tuple construct");
-  Value* fn = closure->node()->inputs().at(0);
-  Value* context = closure->node()->inputs().at(1);
+std::pair<std::shared_ptr<Graph>, Value *> extractClosure(Value *closure) {
+  AT_CHECK(closure->node()->kind() == prim::TupleConstruct,
+           "closure must be a literal tuple construct");
+  Value *fn = closure->node()->inputs().at(0);
+  Value *context = closure->node()->inputs().at(1);
 
-  AT_CHECK(
-      fn->node()->kind() == prim::Function,
-      "closure tuple must contain a prim::Function");
+  AT_CHECK(fn->node()->kind() == prim::Function,
+           "closure tuple must contain a prim::Function");
   return std::make_pair(fn->node()->g(attr::Subgraph), context);
 }
 
-Argument originalReturnType(const TupleTypePtr& tup) {
+Argument originalReturnType(const TupleTypePtr &tup) {
   AT_CHECK(tup->elements().size() > 1);
   if (tup->elements().size() == 2)
     return Argument("", tup->elements().at(0));
@@ -759,44 +812,42 @@ Argument originalReturnType(const TupleTypePtr& tup) {
 // overloaded functions of `func`.
 // Remove the suffix before adding the schema string to map
 // schema_to_graphs.
-std::string overloadedSchemaString(const FunctionSchema& schema) {
-  const auto& schema_name = schema.name();
+std::string overloadedSchemaString(const FunctionSchema &schema) {
+  const auto &schema_name = schema.name();
   auto pos = schema_name.find_last_of('_');
   auto schema_name_suffix = schema_name.substr(pos + 1);
   std::string schema_string = canonicalSchemaString(schema);
   if (!schema_name_suffix.empty() &&
       schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
-    schema_string.replace(
-        schema_string.find(schema_name),
-        schema_name.length(),
-        schema_name.substr(0, pos));
+    schema_string.replace(schema_string.find(schema_name), schema_name.length(),
+                          schema_name.substr(0, pos));
   }
   return schema_string;
 }
 
-bool isHelperFunction(const std::string& method_name) {
+bool isHelperFunction(const std::string &method_name) {
   std::string helper_prefix = "AD_";
   return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
 }
 
-void loadModule(const std::shared_ptr<script::Module>& module) {
-  for (const auto& method_ : module->get_methods()) {
+void loadModule(const std::shared_ptr<script::Module> &module) {
+  for (const auto &method_ : module->get_methods()) {
     if (isHelperFunction(method_.key()))
       continue;
 
-    const auto& method = method_.value();
+    const auto &method = method_.value();
     GradientPair pair;
     pair.forward = method->graph();
 
     // lookup the backward function
-    Node* forward_tuple = pair.forward->outputs().at(0)->node();
+    Node *forward_tuple = pair.forward->outputs().at(0)->node();
 
     if (forward_tuple->kind() != prim::TupleConstruct) {
       throw script::ErrorReport(forward_tuple->getSourceLocation())
           << "gradient must return literal a tuple";
     }
 
-    Value* context;
+    Value *context;
     std::tie(pair.backward, context) =
         extractClosure(forward_tuple->inputs().back());
 
@@ -806,9 +857,9 @@ void loadModule(const std::shared_ptr<script::Module>& module) {
     //  return original, backward
     //  -----
     //  return original, context_tuple
-    std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
+    std::vector<Value *> new_inputs = forward_tuple->inputs().vec();
     new_inputs.back() = context;
-    Value* new_tuple =
+    Value *new_tuple =
         pair.forward->appendNode(pair.forward->createTuple(new_inputs))
             ->output();
     pair.forward->eraseOutput(0);
@@ -816,10 +867,9 @@ void loadModule(const std::shared_ptr<script::Module>& module) {
     forward_tuple->destroy();
 
     // derive schema from original function's schema:
-    const FunctionSchema& loaded_schema = method->getSchema();
+    const FunctionSchema &loaded_schema = method->getSchema();
     FunctionSchema actual_schema(
-        Symbol::aten(loaded_schema.name()),
-        loaded_schema.arguments(),
+        Symbol::aten(loaded_schema.name()), loaded_schema.arguments(),
         {originalReturnType(new_tuple->type()->expect<TupleType>())});
 
     // modify canonical string for function overloading
@@ -831,15 +881,15 @@ void loadModule(const std::shared_ptr<script::Module>& module) {
 }
 
 void loadFunctions() {
-  for (const std::string& str : functions) {
+  for (const std::string &str : functions) {
     auto cu = std::make_shared<script::Module>();
     script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
     loadModule(cu);
   }
 }
 
-c10::optional<GradientPair> gradientInfoForSchema(
-    const FunctionSchema& schema) {
+c10::optional<GradientPair>
+gradientInfoForSchema(const FunctionSchema &schema) {
   std::lock_guard<std::mutex> guard(lock);
   if (schema_to_graphs.size() == 0) {
     loadFunctions();
@@ -852,15 +902,15 @@ c10::optional<GradientPair> gradientInfoForSchema(
     auto sym_script_it = schema_to_graphs.find(schema_str);
 
     if (sym_script_it != schema_to_graphs.end()) {
-      cached_gradient_pairs.emplace_hint(
-          cache_it, &schema, sym_script_it->second);
+      cached_gradient_pairs.emplace_hint(cache_it, &schema,
+                                         sym_script_it->second);
       return sym_script_it->second;
     }
   }
   return c10::nullopt;
 }
 
-bool hasGradientInfoForSchema(const FunctionSchema& schema) {
+bool hasGradientInfoForSchema(const FunctionSchema &schema) {
   return gradientInfoForSchema(schema).has_value();
 }