"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
import numpy as np
from extensions.middle.FusePermutesSequence import FusePermutesSequence
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
enabled = True
def run_after(self):
- return []
+ from extensions.middle.pass_separator import MiddleStart
+ return [MiddleStart]
def run_before(self):
return [
edges=[]
)
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
weights_node = match['lstm'].in_node(3)
biases_node = match['lstm'].in_node(4)
node = match['lstm']
hidden_size = node.in_node(1).shape[1]
weights = weights_node.value
biases = biases_node.value
- assert weights.shape[0] == input_size + hidden_size, "weights.shape={} input_size={} hidden_size={}".format(
- weights.shape, input_size, hidden_size)
- assert weights.shape[1] == biases.shape[0] == 4 * hidden_size,\
+ assert weights.shape[0] == input_size + hidden_size, \
+ "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
+ assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
"weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
weights = weights.reshape([