Made FuncTorchBatched decompose CompositeImplicitAutograd (#63616)
author     Horace He <horacehe2007@yahoo.com>
           Sun, 22 Aug 2021 00:13:27 +0000 (17:13 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Sun, 22 Aug 2021 00:14:39 +0000 (17:14 -0700)
Summary:
Adds DispatchKey::FuncTorchBatched to math_dispatch_keyset, so kernels registered to the alias key DispatchKey::CompositeImplicitAutograd are now also picked up when dispatching on functorch's batched-tensor key. Under vmap, such ops therefore decompose into their constituent ops rather than each needing a hand-written batching rule.

See https://github.com/facebookresearch/functorch/issues/56
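For illustration, here is the kind of kernel this change affects: a hypothetical op (not part of this PR) registered to the alias key CompositeImplicitAutograd and written purely in terms of other aten ops, so batching can fall out of the constituent ops:

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical op: implemented entirely via other aten ops, so autograd
// (and, after this change, functorch's batching) can be derived from the
// decomposition rather than from a dedicated rule.
at::Tensor silu_times_two(const at::Tensor& x) {
  return x * at::sigmoid(x) * 2;
}

TORCH_LIBRARY(myops, m) {
  m.def("silu_times_two(Tensor x) -> Tensor");
}

// Registering to the alias key CompositeImplicitAutograd installs the
// kernel for every key in math_dispatch_keyset, which this patch extends
// with DispatchKey::FuncTorchBatched.
TORCH_LIBRARY_IMPL(myops, CompositeImplicitAutograd, m) {
  m.impl("silu_times_two", silu_times_two);
}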

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63616

Reviewed By: zou3519

Differential Revision: D30438316

Pulled By: Chillee

fbshipit-source-id: e84446d9f68b87daa0cfff75b3b8a972f36ec85a

c10/core/DispatchKeySet.cpp

index 404acc7..21433d4 100644
--- a/c10/core/DispatchKeySet.cpp
+++ b/c10/core/DispatchKeySet.cpp
@@ -32,8 +32,8 @@ bool isBackendDispatchKey(DispatchKey t) {
 // math_dispatch_keyset contains all keys in backend_dispatch_keyset and
 // autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd
 // maps to math_dispatch_keyset.
-constexpr DispatchKeySet math_dispatch_keyset =
-    backend_dispatch_keyset | autograd_dispatch_keyset;
+constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
+    autograd_dispatch_keyset | DispatchKeySet({DispatchKey::FuncTorchBatched});
 
 DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
   TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
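
For intuition, here is a minimal, self-contained model of what the hunk does (simplified types and made-up key values, not PyTorch's real DispatchKeySet, which has more structure): a DispatchKeySet is a bitset of keys, the alias key CompositeImplicitAutograd expands to math_dispatch_keyset, and the patch adds FuncTorchBatched to that union:

#include <cstdint>
#include <iostream>

// Simplified stand-ins; the real enum lives in c10/core/DispatchKey.h
// and these numeric values are made up for this sketch.
enum class DispatchKey : uint8_t { CPU, CUDA, AutogradCPU, FuncTorchBatched };

struct DispatchKeySet {
  uint64_t mask;
  constexpr explicit DispatchKeySet(uint64_t raw) : mask(raw) {}
  constexpr explicit DispatchKeySet(DispatchKey k)
      : mask(uint64_t(1) << static_cast<uint8_t>(k)) {}
  constexpr DispatchKeySet operator|(DispatchKeySet other) const {
    return DispatchKeySet(mask | other.mask);
  }
  constexpr bool has(DispatchKey k) const {
    return (mask >> static_cast<uint8_t>(k)) & 1;
  }
};

constexpr auto backend_dispatch_keyset =
    DispatchKeySet(DispatchKey::CPU) | DispatchKeySet(DispatchKey::CUDA);
constexpr auto autograd_dispatch_keyset =
    DispatchKeySet(DispatchKey::AutogradCPU);

// The patched definition: the union now also covers FuncTorchBatched, so
// kernels registered to CompositeImplicitAutograd (which expands to
// math_dispatch_keyset) also serve dispatch on the batched key.
constexpr auto math_dispatch_keyset = backend_dispatch_keyset |
    autograd_dispatch_keyset | DispatchKeySet(DispatchKey::FuncTorchBatched);

int main() {
  std::cout << std::boolalpha
            << math_dispatch_keyset.has(DispatchKey::FuncTorchBatched)  // true
            << '\n';
}

In the real code, getRuntimeDispatchKeySet(DispatchKey::CompositeImplicitAutograd) returns this set, as the comment above the hunk notes.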