}
program_node& input() const { return get_dependency(0); }
program_node& indices() const { return get_dependency(1); }
- index_select_axis_name get_axis() const { return get_primitive()->axis; }
+ bool get_reverse() const { return get_primitive()->reverse; }
+ std::vector<index_select_axis_name> get_axes() const { return get_primitive()->axis; }
};
using index_select_node = typed_program_node<index_select>;
memory_impl& input() const { return dep_memory(0); }
memory_impl& indices() const { return dep_memory(1); }
- index_select_axis_name get_axis() const { return node.get_axis(); }
+ bool get_reverse() const { return node.get_reverse(); }
+ std::vector<index_select_axis_name> get_axes() const { return node.get_axes(); }
};
using index_select_inst = typed_primitive_inst<index_select>;