testConcat(2);
}
-struct Attr : public Attributes<Attr> {};
void testAttributes() {
+ Graph g;
auto one = attr::alpha;
auto two = attr::device;
auto three = attr::end;
auto four = attr::perm;
- Attr attr;
+ Node *n = g.create(Symbol::fromQualString("foo::bar"));
+ Node &attr = *n;
attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
ASSERT_EQ(attr.f(one), 3.4);
ASSERT_EQ(attr.s(three), "what");
attr.ss_(two, {"hi", "now"});
ASSERT_EQ(attr.ss(two).at(1), "now");
- Attr attr2;
+ Node *n2 = g.create(Symbol::fromQualString("foo::baz"));
+ Node &attr2 = *n2;
attr2.copyAttributes(attr);
ASSERT_EQ(attr2.s(one), "no");
attr2.f_(one, 5);
private:
std::string msg;
};
-
-// CRTP so that Node which inherits Attributes can be return for
-// method chaining e.g:
-// Node * n = g->create(kSelect)->i_(kOffset,3)->f_(kValue,3.5);
-// we return Derived* pointers because Nodes are normally held as pointers.
-template <typename Derived>
-struct Attributes {
- Attributes() = default;
- void copyAttributes(const Attributes& rhs) {
- values_.clear();
- for (auto& i : rhs.values_) {
- values_.push_back(i->clone());
- }
- }
- bool hasAttribute(Symbol name) const {
- JIT_ASSERT(name.is_attr());
- return find(name, false) != values_.end();
- }
- // We want direct string accessors, as it is nicer to use than
- // hasAttribute(Symbol::attr("blah"))
- //
- // For some reason, &Attributes<Node>::hasAttribute in pybind11 is able to
- // give the pybind11 metaprogramming machinery "the right type", but
- // the equivalent looking lambda [](Attributes<Node>& a, const std::string&)
- // doesn't work! So instead we define the methods on the class so we can
- // continue using the old idiom.
- bool hasAttributeS(const std::string& name) const {
- return hasAttribute(Symbol::attr(name));
- }
- AttributeKind kindOf(Symbol name) const {
- JIT_ASSERT(name.is_attr());
- return (*find(name, true))->kind();
- }
- AttributeKind kindOfS(const std::string& name) const {
- return kindOf(Symbol::attr(name));
- }
- Derived* removeAttribute(Symbol name) {
- JIT_ASSERT(name.is_attr());
- values_.erase(find(name, true));
- return This();
- }
- Derived* removeAttributeS(const std::string& name) {
- return removeAttribute(Symbol::attr(name));
- }
- bool hasAttributes() const {
- return values_.size() > 0;
- }
- size_t numAttributes() const {
- return values_.size();
- }
- // The names are returned in order, since name actually is the index.
- std::vector<Symbol> attributeNames() const {
- std::vector<Symbol> names;
- for (auto& a : values_)
- names.push_back(a->name);
- return names;
- }
- std::vector<const char*> attributeNamesS() const {
- std::vector<const char*> names;
- for (auto& a : values_)
- names.push_back(a->name.toUnqualString());
- return names;
- }
-
-#define CREATE_ACCESSOR(Kind, method) \
- Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
- return set<Kind##Attr>( \
- name, std::forward<Kind##Attr::ConstructorType>(v)); \
- } \
- const Kind##Attr::ValueType& method(Symbol name) const { \
- return get<Kind##Attr>(name); \
- }
-
- CREATE_ACCESSOR(Float, f)
- CREATE_ACCESSOR(Floats, fs)
- CREATE_ACCESSOR(String, s)
- CREATE_ACCESSOR(Strings, ss)
- CREATE_ACCESSOR(Int, i)
- CREATE_ACCESSOR(Ints, is)
- CREATE_ACCESSOR(Graph, g)
- CREATE_ACCESSOR(Graphs, gs)
-
-#undef CREATE_ACCESSOR
-
- // Our Graphs are not very const-correct, so we need to allow returning
- // non-const references too
- GraphAttr::ValueType& g(Symbol name) {
- return get<GraphAttr>(name);
- }
-
- // does not use CREATE_ACCESSOR because we need additional asserts
- Derived* t_(Symbol name, TensorAttr::ConstructorType v) {
- JIT_ASSERT(!v.defined() || !v.is_variable());
- return set<TensorAttr>(name, std::forward<TensorAttr::ConstructorType>(v));
- }
- const TensorAttr::ValueType& t(Symbol name) const {
- return get<TensorAttr>(name);
- }
-
- Derived* ts_(Symbol name, TensorsAttr::ConstructorType v) {
- for (auto& t : v) {
- JIT_ASSERT(!t.defined() || !t.is_variable());
- }
- return set<TensorsAttr>(
- name, std::forward<TensorsAttr::ConstructorType>(v));
- }
- const TensorsAttr::ValueType& ts(Symbol name) const {
- return get<TensorsAttr>(name);
- }
-
- template <typename T>
- static void printPrimList(std::ostream& out, const std::vector<T>& items) {
- out << "[";
- int i = 0;
- for (auto& item : items) {
- if (i++ > 0)
- out << ", ";
- out << item;
- }
- out << "]";
- }
-
- static std::string escapeString(std::string s) {
- std::vector<char> search = {'\n', '\t', '\v'};
- std::vector<std::string> replace = {"\\n", "\\t", "\\v"};
- for (size_t i = 0; i < search.size(); i++) {
- size_t pos = s.find(search[i]);
- while (pos != std::string::npos) {
- s.replace(pos, 1, replace[i]);
- pos = s.find(search[i], pos + 1);
- }
- }
- return s;
- }
-
- void printValue(std::ostream& out, const Symbol& name) const {
- switch (kindOf(name)) {
- case AttributeKind::f:
- out << f(name);
- break;
- case AttributeKind::fs:
- printPrimList(out, fs(name));
- break;
- case AttributeKind::i:
- out << i(name);
- break;
- case AttributeKind::is:
- printPrimList(out, is(name));
- break;
- case AttributeKind::s:
- out << "\"" << escapeString(s(name)) << "\"";
- break;
- case AttributeKind::ss:
- printPrimList(out, ss(name));
- break;
- case AttributeKind::t: {
- at::Tensor tensor = t(name);
- // 1-elem tensors are usually boxed scalars, so print them like it
- if (tensor.numel() == 1) {
- auto scalar_tensor = tensor.view({}).item();
- out << "{";
- if (scalar_tensor.isFloatingPoint()) {
- out << scalar_tensor.toDouble();
- } else {
- out << scalar_tensor.toLong();
- }
- out << "}";
- } else if (tensor.numel() <= max_tensor_display_size) {
- // TODO: This is awful code. Also it doesn't work on Windows.
- std::ostringstream tensor_ss;
- tensor_ss << tensor;
- std::string tensor_s{tensor_ss.str()};
- // Remove newlines
- std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
- out << tensor_s;
- } else {
- out << "<Tensor>";
- }
- break;
- }
- case AttributeKind::ts:
- out << "[<Tensors>]";
- break;
- case AttributeKind::g:
- out << "<Graph>";
- break;
- case AttributeKind::gs:
- out << "[<Graphs>]";
- break;
- }
- }
-
- private:
- // UBSAN error: https://github.com/pytorch/pytorch/issues/9055
- Derived* This() __ubsan_ignore_vptr__ {
- return static_cast<Derived*>(this);
- }
- template <typename T>
- Derived* set(Symbol name, typename T::ConstructorType v) {
- JIT_ASSERT(name.is_attr());
- auto it = find(name, false);
- auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
- if (it == values_.end()) {
- values_.push_back(std::move(nv));
- } else {
- *it = std::move(nv);
- }
- return This();
- }
- template <typename T>
- typename T::ValueType& get(Symbol name) const {
- JIT_ASSERT(name.is_attr());
- auto it = find(name, true);
- auto* child = dynamic_cast<T*>(it->get());
- if (child == nullptr) {
- throw AttributeError(name, true);
- }
- return child->value();
- }
- using AVPtr = AttributeValue::Ptr;
- // NB: For determinism, we use a vector rather than a hash map. This does
- // mean that lookups are O(n), so you shouldn't use Attributes to store
- // a big pile of messages.
- std::vector<AVPtr> values_;
- using iterator = std::vector<AVPtr>::iterator;
- iterator find(Symbol name, bool required) {
- JIT_ASSERT(name.is_attr());
- auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
- return v->name == name;
- });
- if (required && it == values_.end()) {
- throw AttributeError(name, false);
- }
- JIT_ASSERT(!required || it != values_.end());
- return it;
- }
- using const_iterator = std::vector<AVPtr>::const_iterator;
- const_iterator find(Symbol name, bool required) const {
- JIT_ASSERT(name.is_attr());
- auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
- return v->name == name;
- });
- if (required && it == values_.end()) {
- throw AttributeError(name, false);
- }
- JIT_ASSERT(!required || it != values_.end());
- return it;
- }
-};
-
} // namespace jit
} // namespace torch
return out;
}
-void printAttributes(
- std::ostream& out,
- const Node* n,
- bool ignore_subgraph = false) {
+template <typename T>
+static void printPrimList(std::ostream& out, const std::vector<T>& items) {
out << "[";
- auto names = n->attributeNames();
+ int i = 0;
+ for (auto& item : items) {
+ if (i++ > 0)
+ out << ", ";
+ out << item;
+ }
+ out << "]";
+}
+
+static std::string escapeString(std::string s) {
+ std::vector<char> search = {'\n', '\t', '\v'};
+ std::vector<std::string> replace = {"\\n", "\\t", "\\v"};
+ for (size_t i = 0; i < search.size(); i++) {
+ size_t pos = s.find(search[i]);
+ while (pos != std::string::npos) {
+ s.replace(pos, 1, replace[i]);
+ pos = s.find(search[i], pos + 1);
+ }
+ }
+ return s;
+}
+
+void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
+ switch (kindOf(name)) {
+ case AttributeKind::f:
+ out << f(name);
+ break;
+ case AttributeKind::fs:
+ printPrimList(out, fs(name));
+ break;
+ case AttributeKind::i:
+ out << i(name);
+ break;
+ case AttributeKind::is:
+ printPrimList(out, is(name));
+ break;
+ case AttributeKind::s:
+ out << "\"" << escapeString(s(name)) << "\"";
+ break;
+ case AttributeKind::ss:
+ printPrimList(out, ss(name));
+ break;
+ case AttributeKind::t: {
+ at::Tensor tensor = t(name);
+ // 1-elem tensors are usually boxed scalars, so print them like it
+ if (tensor.numel() == 1) {
+ auto scalar_tensor = tensor.view({}).item();
+ out << "{";
+ if (scalar_tensor.isFloatingPoint()) {
+ out << scalar_tensor.toDouble();
+ } else {
+ out << scalar_tensor.toLong();
+ }
+ out << "}";
+ } else if (tensor.numel() <= max_tensor_display_size) {
+ // TODO: This is awful code. Also it doesn't work on Windows.
+ std::ostringstream tensor_ss;
+ tensor_ss << tensor;
+ std::string tensor_s{tensor_ss.str()};
+ // Remove newlines
+ std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
+ out << tensor_s;
+ } else {
+ out << "<Tensor>";
+ }
+ break;
+ }
+ case AttributeKind::ts:
+ out << "[<Tensors>]";
+ break;
+ case AttributeKind::g:
+ out << "<Graph>";
+ break;
+ case AttributeKind::gs:
+ out << "[<Graphs>]";
+ break;
+ }
+}
+
+void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
+ const {
+ out << "[";
+ auto names = attributeNames();
int i = 0;
for (auto name : names) {
if (ignore_subgraph && name == attr::Subgraph)
// bug by printing it out.
out << name.toUnqualString() << "=";
- n->printValue(out, name);
+ printAttrValue(out, name);
}
out << "]";
}
return out;
}
-std::ostream& printNode(
+std::ostream& Node::print(
std::ostream& out,
size_t level,
- const Node* n,
- std::vector<const Node*>* groups) {
- auto outputs = n->outputs();
- indent(out, level) << const_value_list_with_types(outputs);
+ std::vector<const Node*>* groups) const {
+ auto outs = outputs();
+ indent(out, level) << const_value_list_with_types(outs);
out << " = ";
- IR_IFM_CONST(n, PythonOp)
- out << "^" << value->name();
- value->writeScalars(out);
- IR_ELSE()
- if (n->hasAttribute(attr::Subgraph) && groups) {
- out << n->kind().toQualString() << "_" << groups->size();
- if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) {
- printAttributes(out, n, /*ignore_subgraph=*/true);
- }
- groups->push_back(n);
+ if (kind() == prim::PythonOp) {
+ auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
+ out << "^" << pyOp->name();
+ pyOp->writeScalars(out);
} else {
- out << n->kind().toQualString();
- if (n->hasAttributes()) {
- printAttributes(out, n);
+ if (hasAttribute(attr::Subgraph) && groups) {
+ out << kind().toQualString() << "_" << groups->size();
+ if (numAttributes() > 1 && kind() != prim::DifferentiableGraph) {
+ printAttributes(out, /*ignore_subgraph=*/true);
+ }
+ groups->push_back(this);
+ } else {
+ out << kind().toQualString();
+ if (hasAttributes()) {
+ printAttributes(out);
+ }
}
}
- IR_END()
- out << "(" << n->inputs() << ")";
- std::string scopeName = n->scopeName();
- if (scopeName.empty()) {
+
+ out << "(" << inputs() << ")";
+ std::string scName = scopeName();
+ if (scName.empty()) {
out << "\n";
} else {
out << ", ";
- out << "scope: " << scopeName << "\n";
+ out << "scope: " << scName << "\n";
}
- for (size_t i = 0; i < n->blocks().size(); ++i) {
- auto b = n->blocks()[i];
+ for (size_t i = 0; i < blocks().size(); ++i) {
+ auto b = blocks()[i];
indent(out, level + 1) << "block" << i << "("
<< const_value_list_with_types(b->inputs(), false)
<< ") {\n";
- for (auto n : b->nodes()) {
- printNode(out, level + 2, n, groups);
+ for (auto nested : b->nodes()) {
+ nested->print(out, level + 2, groups);
}
indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
indent(out, level + 1) << "}\n";
}
std::ostream& operator<<(std::ostream& out, const Node& n) {
- return printNode(out, 0, &n, nullptr);
+ return n.print(out, 0, nullptr);
}
std::ostream& operator<<(std::ostream& out, const Graph& g) {
out << "graph(" << const_value_list_with_types(g.inputs(), true) << ") {\n";
std::vector<const Node*> groups;
for (auto n : g.nodes()) {
- printNode(out, 1, n, &groups);
+ n->print(out, 1, &groups);
}
out << " return (" << g.outputs() << ");\n}\n";
size_t i = 0;
TORCH_API Value* copyMetadata(Value* from);
};
-struct Node : public Attributes<Node> {
+struct Node {
TH_DISALLOW_COPY_AND_ASSIGN(Node);
friend struct Graph;
friend struct Block;
void dump() const;
+ std::ostream& print(
+ std::ostream& out,
+ size_t level,
+ std::vector<const Node*>* groups) const;
+
virtual ~Node() = default;
+ // Methods for accessing attributes
+ void copyAttributes(const Node& rhs) {
+ values_.clear();
+ for (auto& i : rhs.values_) {
+ values_.push_back(i->clone());
+ }
+ }
+ bool hasAttribute(Symbol name) const {
+ JIT_ASSERT(name.is_attr());
+ return findAttr(name, false) != values_.end();
+ }
+ bool hasAttributeS(const std::string& name) const {
+ return hasAttribute(Symbol::attr(name));
+ }
+ AttributeKind kindOf(Symbol name) const {
+ JIT_ASSERT(name.is_attr());
+ return (*findAttr(name, true))->kind();
+ }
+ AttributeKind kindOfS(const std::string& name) const {
+ return kindOf(Symbol::attr(name));
+ }
+ Node* removeAttribute(Symbol name) {
+ JIT_ASSERT(name.is_attr());
+ values_.erase(findAttr(name, true));
+ return this;
+ }
+ Node* removeAttributeS(const std::string& name) {
+ return removeAttribute(Symbol::attr(name));
+ }
+ bool hasAttributes() const {
+ return values_.size() > 0;
+ }
+ size_t numAttributes() const {
+ return values_.size();
+ }
+ // The names are returned in order, since name actually is the index.
+ std::vector<Symbol> attributeNames() const {
+ std::vector<Symbol> names;
+ for (auto& a : values_)
+ names.push_back(a->name);
+ return names;
+ }
+ std::vector<const char*> attributeNamesS() const {
+ std::vector<const char*> names;
+ for (auto& a : values_)
+ names.push_back(a->name.toUnqualString());
+ return names;
+ }
+
+#define CREATE_ACCESSOR(Kind, method) \
+ Node* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
+ return setAttr<Kind##Attr>( \
+ name, std::forward<Kind##Attr::ConstructorType>(v)); \
+ } \
+ const Kind##Attr::ValueType& method(Symbol name) const { \
+ return getAttr<Kind##Attr>(name); \
+ }
+
+ CREATE_ACCESSOR(Float, f)
+ CREATE_ACCESSOR(Floats, fs)
+ CREATE_ACCESSOR(String, s)
+ CREATE_ACCESSOR(Strings, ss)
+ CREATE_ACCESSOR(Int, i)
+ CREATE_ACCESSOR(Ints, is)
+ CREATE_ACCESSOR(Graph, g)
+ CREATE_ACCESSOR(Graphs, gs)
+
+#undef CREATE_ACCESSOR
+
+ // Our Graphs are not very const-correct, so we need to allow returning
+ // non-const references too
+ GraphAttr::ValueType& g(Symbol name) {
+ return getAttr<GraphAttr>(name);
+ }
+
+ // does not use CREATE_ACCESSOR because we need additional asserts
+ Node* t_(Symbol name, TensorAttr::ConstructorType v) {
+ JIT_ASSERT(!v.defined() || !v.is_variable());
+ return setAttr<TensorAttr>(name, std::forward<TensorAttr::ConstructorType>(v));
+ }
+ const TensorAttr::ValueType& t(Symbol name) const {
+ return getAttr<TensorAttr>(name);
+ }
+
+ Node* ts_(Symbol name, TensorsAttr::ConstructorType v) {
+ for (auto& t : v) {
+ JIT_ASSERT(!t.defined() || !t.is_variable());
+ }
+ return setAttr<TensorsAttr>(
+ name, std::forward<TensorsAttr::ConstructorType>(v));
+ }
+ const TensorsAttr::ValueType& ts(Symbol name) const {
+ return getAttr<TensorsAttr>(name);
+ }
+
private:
+ void printAttrValue(std::ostream& out, const Symbol& name) const;
+ void printAttributes(std::ostream& out, bool ignore_subgraph) const;
+
+ template <typename T>
+ Node* setAttr(Symbol name, typename T::ConstructorType v) {
+ JIT_ASSERT(name.is_attr());
+ auto it = findAttr(name, false);
+ auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
+ if (it == values_.end()) {
+ values_.push_back(std::move(nv));
+ } else {
+ *it = std::move(nv);
+ }
+ return this;
+ }
+ template <typename T>
+ typename T::ValueType& getAttr(Symbol name) const {
+ JIT_ASSERT(name.is_attr());
+ auto it = findAttr(name, true);
+ auto* child = dynamic_cast<T*>(it->get());
+ if (child == nullptr) {
+ throw AttributeError(name, true);
+ }
+ return child->value();
+ }
+ using AVPtr = AttributeValue::Ptr;
+ // NB: For determinism, we use a vector rather than a hash map. This does
+ // mean that lookups are O(n), so you shouldn't use Attributes to store
+ // a big pile of messages.
+ std::vector<AVPtr> values_;
+ std::vector<AVPtr>::iterator findAttr(Symbol name, bool required) {
+ JIT_ASSERT(name.is_attr());
+ auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
+ return v->name == name;
+ });
+ if (required && it == values_.end()) {
+ throw AttributeError(name, false);
+ }
+ JIT_ASSERT(!required || it != values_.end());
+ return it;
+ }
+ std::vector<AVPtr>::const_iterator findAttr(Symbol name, bool required) const {
+ JIT_ASSERT(name.is_attr());
+ auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
+ return v->name == name;
+ });
+ if (required && it == values_.end()) {
+ throw AttributeError(name, false);
+ }
+ JIT_ASSERT(!required || it != values_.end());
+ return it;
+ }
+
enum class MoveSide { BEFORE, AFTER };
bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
})
.NS(addBlock)
-#define AS(name) def(#name, &Attributes<Node>::name)
+#define AS(name) def(#name, &Node::name)
// methods from Attributes
.AS(copyAttributes)
.AS(hasAttributes)
#undef AS
-#define AS(name) def(#name, &Attributes<Node>::name##S)
+#define AS(name) def(#name, &Node::name##S)
// The default method names take Symbol, but the string conversion for
// Symbol you to qualify with attr::. This is not very user friendly
// for attributes, so expose the string variants instead.