def test_expand_propagate_qinfo(self):
pass
- def test_expand_insert_observers(self):
- pass
+ def test_insert_observers(self):
+ x1 = torch.tensor([0.4, 0.3])
+ y1 = torch.tensor([0.7, 0.5])
+ x2 = torch.tensor([0.1, 0.9])
+ y2 = torch.tensor([1.1, 1.9])
+
+ # Function that we will use as a graph
+ def fn(x, y):
+ p = x + y
+ z = x - y
+ return p * z
+
+ # Custom observer function
+ value_stats = {}
+
+ def observe(x, name):
+ if name not in value_stats:
+ value_stats[name] = []
+ value_stats[name].append(x)
+ return x
+
+ m = torch.jit.script(fn)
+ # Insert observers
+ torch._C._jit_pass_insert_observers(m.graph, observe)
+
+ # Collect statistics
+ m.forward(x1, y1)
+
+ # Check what we collected
+ self.assertTrue('p' in value_stats and 'z' in value_stats)
+ self.assertEqual(len(value_stats['p']), 1)
+ self.assertEqual(len(value_stats['z']), 1)
+ self.assertEqual(value_stats['p'][0], x1 + y1)
+ self.assertEqual(value_stats['z'][0], x1 - y1)
+
+ # Run one more time and check the updated statistics
+ m.forward(x2, y2)
+ self.assertEqual(len(value_stats['p']), 2)
+ self.assertEqual(len(value_stats['z']), 2)
+ self.assertEqual(value_stats['p'][1], x2 + y2)
+ self.assertEqual(value_stats['z'][1], x2 - y2)
def test_expand_insert_fakequant(self):
pass
[](std::shared_ptr<Graph>& g) { return PropagateQuantInfo(g); })
.def(
"_jit_pass_insert_observers",
- [](std::shared_ptr<Graph>& g) { return InsertObserverNodes(g); })
+ [](std::shared_ptr<Graph>& g, py::function pyObserverFunction) {
+ // Create a new node that would be used in the insert observer pass:
+ // all observer nodes will be cloned from this one.
+ Node* new_node = g->createPythonOp(
+ THPObjectPtr(pyObserverFunction.release().ptr()), "dd", {});
+ InsertObserverNodes(g, new_node);
+ // We don't need this node anymore, don't forget to remove it.
+ new_node->destroy();
+ })
.def(
"_jit_pass_insert_fakequant",
[](std::shared_ptr<Graph>& g) { return InsertFakeQuantNodes(g); })
#include <torch/csrc/jit/node_hashing.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
-#include <unordered_map>
+#include <stack>
namespace torch {
namespace jit {
-namespace {} // namespace
void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph) {
throw std::runtime_error("Pass not implemented yet!");
throw std::runtime_error("Pass not implemented yet!");
}
-void InsertObserverNodes(std::shared_ptr<Graph>& graph) {
- throw std::runtime_error("Pass not implemented yet!");
+static void addObserverFor(Value* v, Node* original_observer_node) {
+ Node* def = v->node();
+ WithInsertPoint ins(def);
+
+ // We need to pass the value name to observer function - create a constant
+ // holding this name.
+ Value* vname = def->owningGraph()->insertConstant(v->uniqueName());
+
+ // Create a new observer node. We just need to clone the original one.
+ Node* observerNode =
+ def->owningGraph()
+ ->createClone(
+ &*original_observer_node, [&](Value* v) { return v; }, false)
+ ->insertAfter(def);
+
+ // Set the type and the name of the output of the new observer node. It will
+ // be used instead of the original value v.
+ Value* observedValue = observerNode->addOutput();
+ observedValue->setType(v->type());
+ observedValue->setUniqueName(v->uniqueName() + ".observed");
+
+ // Replace the uses of v with observedValue. We need to do it *before* we add
+ // the inputs - otherwise we would replace the newly added inputs as well.
+ v->replaceAllUsesWith(observedValue);
+
+ // Now we can add the inputs.
+ observerNode->addInput(v);
+ observerNode->addInput(vname);
+}
+
+static bool outputsNeedToBeObserved(Node* n) {
+ return n->kind().toQualString() != std::string("prim::Constant");
+}
+
+void InsertObserverNodes(std::shared_ptr<Graph>& graph, Node* observer_node) {
+ // For storing all values that need to be instrumented with an observer call.
+ std::vector<Value*> values_to_observe;
+
+ // For traversing all blocks in the graph including subblocks.
+ std::stack<Block*> blocks_to_visit;
+
+ blocks_to_visit.push(graph->block());
+ while (!blocks_to_visit.empty()) {
+ Block* b = blocks_to_visit.top();
+ blocks_to_visit.pop();
+ for (Node* n : b->nodes()) {
+ // Skip nodes that we don't need to observe, e.g. 'prim::Constant'.
+ if (!outputsNeedToBeObserved(n)) {
+ continue;
+ }
+
+ // Record all outputs in the values_to_observe - we'll later add observers
+ // for all values from it.
+ for (Value* v : n->outputs()) {
+ values_to_observe.push_back(v);
+ }
+
+ // Schedule subblocks (if any) for visiting.
+ for (Block* subblock : n->blocks()) {
+ blocks_to_visit.push(subblock);
+ }
+ }
+ }
+
+ // Actually add observer nodes.
+ for (Value* v : values_to_observe) {
+ addObserverFor(v, observer_node);
+ }
}
void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph) {