return op_context;
}
-// Returns an OpInfo for MatMul with unknown input shapes.
-OpContext DescribeMatMulUnknownShape() {
- OpContext op_context;
- SetCpuDevice(&op_context.op_info);
- op_context.op_info.set_op("MatMul");
-
- auto input = op_context.op_info.add_inputs();
- auto shape = input->mutable_shape();
- shape->set_unknown_rank(true);
-
- input = op_context.op_info.add_inputs();
- shape = input->mutable_shape();
- shape->set_unknown_rank(true);
-
- return op_context;
-}
-
// Wrangles the minimum number of proto fields to set up an input of
// arbitrary rank and type.
void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
// optimizations, such as removing nodes that are effectively noops.
class DependencyOptimizer : public GraphOptimizer {
public:
- DependencyOptimizer() : opt_level_(RewriterConfig::ON) {}
- explicit DependencyOptimizer(RewriterConfig::Toggle opt_level)
- : opt_level_(opt_level) {}
+ DependencyOptimizer() {}
+ explicit DependencyOptimizer(RewriterConfig::Toggle opt_level) {}
~DependencyOptimizer() override {}
string name() const override { return "dependency_optimizer"; };
// Main driver of dependency optimizations.
Status OptimizeDependencies();
- RewriterConfig::Toggle opt_level_;
bool fetch_nodes_known_;
std::unordered_set<string> nodes_to_preserve_;
std::unique_ptr<NodeMap> node_map_;
// operations to make the overall graph more efficient.
class FunctionOptimizer : public GraphOptimizer {
public:
- FunctionOptimizer(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {}
+ FunctionOptimizer(RewriterConfig::Toggle opt_level) {}
~FunctionOptimizer() override {}
string name() const override { return "function_optimizer"; };
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) override;
-
- private:
- RewriterConfig::Toggle opt_level_;
};
} // end namespace grappler
return is_compare;
}
-bool IsLogicalOp(const NodeDef& node) {
- return IsLogicalAnd(node) || IsLogicalNot(node) || IsLogicalOr(node);
-}
-
bool IsReduceOp(const NodeDef& node) {
return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
IsMin(node) || IsAll(node) || IsAny(node);