[TENSORFLOW] StatefulPartitionedCall/PartitionedCall Ops support added (#5617)
author Deepak <59532278+deepakbabel23@users.noreply.github.com>
Thu, 4 Jun 2020 04:33:42 +0000 (10:03 +0530)
committer GitHub <noreply@github.com>
Thu, 4 Jun 2020 04:33:42 +0000 (10:03 +0530)
* Implemented function invocation unit test for the StatefulPartitionedCall operator (working) and initial changes for placeholder (not working as of now)

* Placeholder exercises with tvm

* placeholder interim

* SPOP Test cases structure

* New test cases for spop

* miscellaneous test cases for spop

* Placeholder samples, working with shapes explicitly passed

* Variables test case. Works with the same fix of shape_dict

* SPOP Positive test cases first iteration

* support output tensors as function args, multiple functions

* Corrected Indentation

* File writer is only for debug purposes

* support variables in function args

* First working iteration of positive spop test cases

* Removed commented code, simplified code

* Code Reorganization - First working iteration of positive spop test cases

* corrected variable name after refactor

* move code inside mapped operator function

* Removed extra line

* Function invocation more test cases

* Simplified & Merged different Function Invocation Test cases

* support invocation of nested callables

no need to explicitly handle the PartitionedCall and StatefulPartitionedCall conditions in the convert_operator function

* Simplified and Uniform testcases

* removed duplicate and renamed testcase

* Negative scenario added for testing operator statefulness. The only exceptions among stateful operators are PartitionedCall & StatefulPartitionedCall, which can execute even stateless operators within them

* Miscellaneous reorganization changes for spop scenarios

* Corrected import of TensorFlow modules to be done safely using try/except, and other code reorganization

* Negative scenario for resource variables handled

* Documentation update for code

* SPOP change in function handling

* handle nested subgraph

* refactor

* Made op def lookup compatible with TF 1.x & 2.x

* Fixed linting issues

* Added docstrings and a few nits

* Merged changes for positive test cases and negative test cases

* Moved StatefulPartitionedCall test case to the end of the TC list

* Fixed some typos and semantics

* dmlc-core

* fixes

* Addressing Review comments in the PR for SPOP support

* Fixed pylint errors

* Corrected tensorflow import syntax

* Placed the op_def_registry module import outside of for loop

* Removed the new stateful operators list and combined these operators with the missing operators to display as a single list. Also removed throwing a separate exception for stateful ops

Co-authored-by: Prashant Sail <psail4444@gmail.com>
Co-authored-by: maheshambule <mahesh_ambule@persistent.com>
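
For context, a minimal usage sketch of the new support. This is hedged: it mirrors the placeholder unit tests added in this PR, the graph, names and shapes are illustrative only, and the test helper compare_tf_with_tvm handles the session and shape plumbing in the actual tests.

    import tensorflow as tf
    from tvm import relay

    with tf.Graph().as_default() as graph:
        t1 = tf.placeholder(tf.int32, (2, 2), name="t1")
        t2 = tf.placeholder(tf.int32, (2, 2), name="t2")

        @tf.function  # lowered to a PartitionedCall/StatefulPartitionedCall node in the GraphDef
        def add(x, y):
            return tf.add(x, y, "add_t1_t2")

        t3 = add(t1, t2)
        graph_def = graph.as_graph_def(add_shapes=True)

    # The frontend now converts the library function behind the call node into a
    # global Relay function and invokes it from "main".
    mod, params = relay.frontend.from_tensorflow(
        graph_def, shape={"t1": (2, 2), "t2": (2, 2)})
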
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index d4b73f9..201c6ba 100644 (file)
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
-# pylint: disable=import-outside-toplevel
+# pylint: disable=import-outside-toplevel, redefined-builtin
 """TF: Tensorflow frontend."""
 import warnings
 from collections import defaultdict
@@ -1927,7 +1927,6 @@ def _add_n():
         return  _res
     return _impl
 
-
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -2717,8 +2716,9 @@ class GraphProto(object):
         self._loop_var_order = {}
         self._hash2tfnode = {}
         self._while_loop_name_set = set()
+        self._main_graph_proto = self
 
-    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
+    def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct relay nodes from tensorflow graph definition - GraphDef.
 
         Follow the tensorflow graph definition to parse and convert it to Relay.
@@ -2885,6 +2885,13 @@ class GraphProto(object):
 
         out = out[0] if len(out) == 1 else _expr.Tuple(out)
         func = _function.Function(analysis.free_vars(out), out)
+        return func
+
+    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
+        """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function
+        which is used as main function for the Relay module
+        """
+        func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
         self._mod["main"] = func
         return self._mod, self._params
 
@@ -2895,16 +2902,24 @@ class GraphProto(object):
                 which are not supported
         """
         missing_operators = set()
+        from tensorflow.python.framework import op_def_registry
         for node in graph.node:
+            getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,\
+                        "_registered_ops") else op_def_registry.get
+            op_def = getOpDef(node.op)
             if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
                 pass
             elif node.op == "Const":
                 pass
+            elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
+                pass
             else:
                 if any([node.op in t for t in [_identity_list, _convert_map,
                                                _convert_map_rnn,
                                                _control_flow_nodes]]):
                     pass
+                elif op_def is not None and op_def.is_stateful:
+                    missing_operators.add(node.op)
                 else:
                     missing_operators.add(node.op)
 
@@ -3149,6 +3164,91 @@ class GraphProto(object):
 
         return op
 
+    def _partition_call_operator(self, inputs, attr):
+        """
+        Convert the Relay Partition call ops into Relay Function calls and
+        function definitions from Tensorflow graph library attribute to Relay global
+        functions
+
+        Parameters
+        ----------
+        inputs : List[tvm.relay.Expr]
+            List of input symbols.
+
+        attr : Dict[tvm.Attrs]
+            Dict of operator attributes.
+
+        Returns
+        -------
+        op : tvm.relay.Expr
+            Converted relay expression.
+        """
+
+        try:
+            from tensorflow.python.framework import function_def_to_graph
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import tensorflow which is required {}".format(e))
+
+        main_graph_proto = self._main_graph_proto
+        outer_graph_def = main_graph_proto._graph
+
+        node_func_name = attr.get('f').name
+        func = next((f for f in outer_graph_def.library.function
+                     if f.signature.name == node_func_name), None)
+        if func:
+            devices = set(node.device for node in func.node_def)
+            if len(devices) > 1:
+                raise Exception("Found inconsistent Device assignment in the "\
+                                "Stateful Partitioned SubGraph. Rejecting "\
+                                "the subgraph ")
+            # Convert function definition to graph
+            func_input_shapes = func.attr["_input_shapes"].list.shape
+            subgraph, _ = function_def_to_graph.\
+                function_def_to_graph_def(func, func_input_shapes)
+
+            # Computing subgraph's input shape dictionary
+            subgraph_shape_dict, input_expr_dict = {}, {}
+            for f_arg, input in zip(func.signature.input_arg, inputs):
+                input_expr_dict[f_arg.name] = input
+                subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod)
+
+            func_name = 'func_{}'.format(func.signature.name)
+            try:
+                global_func = main_graph_proto._mod[func_name]
+                sub_func = global_func
+                sub_params = main_graph_proto._params
+            except ValueError:
+                # Construct relay nodes from the subgraph
+                g1 = SubGraphProto(main_graph_proto)
+                sub_func, sub_params = g1.from_tensorflow(subgraph, shape=subgraph_shape_dict)
+                main_graph_proto._params.update(sub_params)
+                func_expr = _function.Function(sub_func.params, sub_func.body)
+                global_func = tvm.relay.GlobalVar(func_name)
+                main_graph_proto._mod[global_func] = func_expr
+
+            param_exprs = []
+            for param_expr in sub_func.params:
+                # sub_params is subset of sub_func.params
+                param_name = param_expr.vid.name_hint
+                if param_name in input_expr_dict.keys():
+                    param_exprs.append(input_expr_dict[param_name])
+                elif param_name in sub_params.keys():
+                    param_exprs.append(param_expr)
+                else:
+                    raise Exception("Input parameter {} not found".format(param_name))
+
+            sb = tvm.relay.scope_builder.ScopeBuilder()
+            loop_ret = global_func(*param_exprs)
+            sb.ret(loop_ret)
+            ret = sb.get()
+        else:
+            raise Exception("Function not found - {}".format(node_func_name))
+        return ret
+
     def _convert_operator(self, op_name, inputs, attrs,
                           graph, identity_list=None, convert_map=None):
         """Convert from Tensorflow operator to relay operator.
@@ -3190,6 +3290,9 @@ class GraphProto(object):
             sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                              self._params, graph,
                                              convert_map_rnn)
+
+        elif op_name in ["PartitionedCall", "StatefulPartitionedCall"]:
+            sym = self._partition_call_operator(inputs, attrs)
         else:
             raise NotImplementedError("Operator {} not implemented.".format(op_name))
         return sym
@@ -3253,6 +3356,22 @@ class GraphProto(object):
 
         return out[0]
 
+
+class SubGraphProto(GraphProto):
+    """ A helper class for handling relay subgraph copying from Tensorflow GraphDef.
+    """
+    def __init__(self, main_graph_proto):
+        super().__init__()
+        self._main_graph_proto = main_graph_proto  # holds main graph proto object
+
+    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
+        """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function.
+        Return Relay function and params
+        """
+        func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
+        return func, self._params
+
+
 def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     """Load tensorflow graph which is a python tensorflow graph object into relay.
     The companion parameters will be handled automatically.
@@ -3279,6 +3398,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     params : dict of str to tvm.nd.NDArray
         Dict of converted parameters stored in tvm.nd.NDArray format
     """
+
     g = GraphProto()
     mod, params = g.from_tensorflow(graph, layout, shape, outputs)
     return mod, params
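
As a side note, the op-def lookup in the missing-operator check above works on both TF 1.x and 2.x because it probes for the private registry dict before falling back to the public getter. A standalone sketch of the same pattern follows; graph_def here stands for any frozen tf.GraphDef and is only an illustration:

    from tensorflow.python.framework import op_def_registry

    # TF 1.x exposes a _registered_ops dict; TF 2.x exposes op_def_registry.get().
    get_op_def = (op_def_registry._registered_ops.get
                  if hasattr(op_def_registry, "_registered_ops")
                  else op_def_registry.get)

    for node in graph_def.node:
        op_def = get_op_def(node.op)
        if op_def is not None and op_def.is_stateful:
            print("stateful op:", node.op)  # rejected unless it is a *PartitionedCall
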
index 89a0335..f9fc5dd 100644 (file)
@@ -36,6 +36,10 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import init_ops
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_functional_ops
 from distutils.version import LooseVersion
 import tvm
 from tvm import te
@@ -176,6 +180,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
         if init_global_variables:
             sess.run(variables.global_variables_initializer())
         final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
+
         tf_output = run_tf_graph(sess, in_data, in_name, out_name)
 
         for device in ["llvm", "cuda"]:
@@ -1138,13 +1143,13 @@ def test_read_variable_op():
         tf_output = run_tf_graph(sess, in_data, in_name, out_name)
 
         shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
-        with pytest.raises(Exception) as exexcinfo:
+        with pytest.raises(Exception) as execinfo:
             mod, params = relay.frontend.from_tensorflow(final_graph_def,
                                                          layout=None,
                                                          shape=shape_dict,
                                                          outputs=None)
 
-        assert exexcinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.")
+        assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph")
 
         # Now convert the variables to constant and run inference on the converted graph
         final_graph_def = tf.graph_util.convert_variables_to_constants(
@@ -3195,10 +3200,342 @@ def test_forward_isfinite():
     _verify_infiniteness_ops(tf.is_finite, "isfinite")
 
 
+def _test_spop_placeholder_without_shape_info():
+    with tf.Graph().as_default():
+
+        @function.Defun(*[tf.int32]*2)
+        def Forward(x,y):
+            print(x.name)
+            print(y.name)
+            b = tf.add(x, y)
+            return b
+        pl1 = tf.placeholder(tf.int32,name="pl1")
+        pl2 = tf.placeholder(tf.int32,name="pl2")
+        pl3 = tf.placeholder(tf.int32, name="pl3")
+        data = np.array([[-1, 1], [2, -2]], dtype=np.int32)
+        data2 = np.array([[-2, 3], [4, -6]], dtype=np.int32)
+        data3 = np.array([[-2, 3], [4, -6]], dtype=np.int32)
+        z1 = gen_functional_ops.StatefulPartitionedCall(args=[pl1,pl2], Tout=[tf.int32],f=Forward)
+        z2 = z1 + pl3
+        compare_tf_with_tvm([data, data2, data3], ['pl1:0', 'pl2:0', 'pl3:0'],
+                            ['StatefulPartitionedCall:0',z2.name],  mode='vm', init_global_variables=True)
+
+
+def _test_spop_placeholder_with_shape_and_default_value():
+    with tf.Graph().as_default():
+        data = np.ones([1], dtype=int).astype(np.int32)
+        dataVar = tf.Variable(data, shape=data.shape)
+        pl1 = array_ops.placeholder_with_default(dataVar,shape=data.shape,name="pl1")
+        tpl = tf.convert_to_tensor(pl1, dtype=tf.int32)
+
+        @function.Defun(*[tf.int32])
+        def pl_with_default(pl):
+            return tf.expand_dims(tf.multiply(pl, pl), 0)
+
+        z = gen_functional_ops.StatefulPartitionedCall(args=[tpl], Tout=[tf.int32], f=pl_with_default)
+        compare_tf_with_tvm(data, ['pl1:0'], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_placeholder_numpy_arange_feed():
+    with tf.Graph().as_default():
+        t1 = tf.placeholder(tf.int32, (3, 3, 3), "t1")
+        t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
+        t2 = tf.placeholder(tf.int32, (3, 3, 3), "t2")
+        t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
+
+        @tf.function
+        def add(x, y):
+            return tf.add(x, y, "add_t1_t2")
+
+        t3 = add(t1, t2)
+        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_placeholder_numpy_array_feed():
+    with tf.Graph().as_default():
+        t1_data = np.array([[-1, 1, 3], [2, -2, 4], [2, -3, 14]], dtype=np.int32)
+        t2_data = np.array([[-2, 1, 2], [12, -2, 14], [12, -3, 4]], dtype=np.int32)
+        t1 = tf.placeholder(tf.int32, name="t1")
+        t2 = tf.placeholder(tf.int32, name="t2")
+
+        @tf.function
+        def add(x, y):
+            return tf.add(x, y, "add_t1_t2")
+
+        t3 = add(t1, t2)
+        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_function_invocation_basic():
+    with tf.Graph().as_default():
+
+        def fun1(a):
+            return tf.multiply(a,a)
+
+        def fun2(b):
+            return tf.multiply(b,10)
+
+        @tf.function
+        def fun3(x,y):
+            x = fun2(x)
+            y = fun1(y)
+            z = tf.add(x,y)
+            return z
+
+        t3 = fun3(tf.constant(10.5), tf.constant(20.4))
+
+        compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_function_invocation_nested():
+    with tf.Graph().as_default():
+        t1 = tf.placeholder(tf.int32, (3, 3, 3), name="t1")
+        t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
+        t2 = tf.placeholder(tf.int32, name="t2")
+        t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
+
+        @tf.function
+        def myfunc(x, y):
+            return tf.add(x, y, "myfunc")
+
+        @tf.function
+        def myfunc2(x, y):
+            z = myfunc(x, y)
+            l = myfunc(z, y)
+            m = myfunc(l,z)
+            return tf.add(l, m, "myfunc2")
+
+        res1 = myfunc(t1, t2)
+        res2 = myfunc2(res1, t1)
+
+        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [res2.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_function_invocation_no_autograph():
+    with tf.Graph().as_default():
+
+        @tf.function(autograph=False)
+        def fun1(a):
+            return tf.multiply(a,a)
+
+        @tf.function(autograph=False)
+        def fun2(b):
+            return tf.multiply(b,10)
+
+        @tf.function
+        def fun3(x,y):
+            x = fun2(x)
+            y = fun1(y)
+            z = tf.add(x,y)
+            return z
+
+        t3 = fun3(tf.constant(10.5), tf.constant(20.4))
+
+        compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_function_invocation_defun():
+    with tf.Graph().as_default():
+
+        def fun1(a):
+            return tf.multiply(a,a)
+
+        def fun2(b):
+            return tf.multiply(b,b)
+
+        @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3")
+        def fun3(x,y):
+            x = fun2(x)
+            y = fun1(y)
+            z = tf.add(x,y)
+            return z
+
+        op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)],
+                                                        Tout=[dtypes.float32], f=fun3, name="SpopFnInvocation")
+        compare_tf_with_tvm([],[], 'SpopFnInvocation:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_arithmetic():
+    with tf.Graph().as_default():
+        @function.Defun(*[dtypes.int32]*3)
+        def arithmetic(m,x,c):
+            z = tf.add(tf.multiply(m, x), c)
+            return z
+
+        m = tf.constant(10)
+        x = tf.constant(20)
+        c = tf.constant(2)
+        spopFn = gen_functional_ops.StatefulPartitionedCall(args=[m,x,c],Tout=[tf.int32], f=arithmetic)
+
+        compare_tf_with_tvm([],[],'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_control_flow():
+    with tf.Graph().as_default():
+
+        @function.Defun(*[dtypes.float32] * 2)
+        def Body1(x, y):
+            with ops.device("/job:localhost/replica:0/task:0/device:CPU:0"):
+                z = math_ops.multiply(x, y)
+                i = 0
+                while i<10 :
+                    i +=1
+                    if i == 5:
+                        continue
+                    z = math_ops.multiply(x, y*i)
+            return z
+
+        op = gen_functional_ops.StatefulPartitionedCall(
+            args=[constant_op.constant(32.), constant_op.constant(100.)],
+            Tout=[dtypes.float32], f=Body1)
+        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_variables():
+    with tf.Graph().as_default():
+        const1 = tf.constant(10)
+        const2 = tf.constant(20)
+        var1 = tf.Variable(const1, dtype=tf.int32)
+        var2 = tf.Variable(const2, dtype=tf.int32)
+
+        @function.Defun(tf.int32,tf.int32)
+        def Forward(x,y):
+            return tf.multiply(x,y)
+
+        z = gen_functional_ops.StatefulPartitionedCall(args=[var1,var2],Tout=[tf.int32], f=Forward)
+        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', init_global_variables=True, mode="vm")
+
+
+def _test_spop_constants():
+    with tf.Graph().as_default():
+        @function.Defun(*[dtypes.int32] * 2)
+        def constantsFn(x, y):
+            vv = tf.constant([2, 3, 4], name="vv")
+            z = tf.add(vv + x, y)
+            return z
+
+        a = tf.constant(20000, name = "a")
+        b = tf.constant(40000, name = "b")
+        spopFn = gen_functional_ops.StatefulPartitionedCall(args=[a, b], Tout=[tf.int32], f=constantsFn)
+
+        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_stateful():
+    # This test case is to test that TVM rejects any TF stateful operations
+    # (including Resource Variables) except StatefulPartitionedCall/PartitionedCall
+    # (as these two operators can still be used as container graphs to execute
+    # "stateless" operations internally.
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+
+        @tf.function
+        def FunctionWithStatefulOp_One(i):
+            b = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed=10)
+            y = tf.multiply(b, i)
+            return y
+
+        @tf.function
+        def FunctionWithStatefulOp(m, n):
+            a = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed = 10)
+            x = tf.multiply(a,m)
+            y = FunctionWithStatefulOp_One(n)
+            z = tf.multiply(x,y)
+            return z
+
+        op = FunctionWithStatefulOp(constant_op.constant(1.), constant_op.constant(2.))
+        with pytest.raises(Exception) as execinfo:
+            compare_tf_with_tvm([], [], [op.name], init_global_variables=True, mode="vm")
+        assert execinfo.value.args[0].startswith(
+            "The following operators are not implemented")
+
+
+def _test_spop_device_assignment():
+    # This test case is to test that TVM rejects inconsistent device assignment
+    # while using StatefulPartitionedCall/PartitionedCall operators which in case of TVM will
+    # be used as container graphs to internally execute "stateless" operations.
+
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+
+        def fun1(a):
+            with ops.device("/GPU:0"):
+                return tf.multiply(a,a)
+
+        def fun2(b):
+            with ops.device("/job:localhost/replica:0/task:0/device:CPU:1"):
+                return tf.multiply(b,b)
+
+        @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3")
+        def fun3(x,y):
+            with ops.device("/CPU:0"):
+                x = fun2(x)
+            with ops.device("/job:localhost/replica:0/task:0/device:CPU:2"):
+                y = fun1(y)
+            with ops.device("/job:localhost/replica:0/task:0/device:CPU:3"):
+                z = tf.add(x,y)
+                return z
+
+        op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)],
+                                                        Tout=[dtypes.float32], f=fun3)
+        with pytest.raises(Exception) as execinfo:
+            compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0',
+                                mode='vm', init_global_variables=True)
+        assert execinfo.value.args[0].startswith("Found inconsistent Device assignment")
+
+
+def _test_spop_resource_variables():
+    # This test case is to test that TVM rejects any graph containing
+    # resource variables with StatefulPartitionedOp.
+
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+
+        const1 = tf.constant(10)
+        const2 = tf.constant(20)
+        var1 = tf.Variable(const1, dtype=tf.int32, use_resource=True)
+        var2 = tf.Variable(const2, dtype=tf.int32, use_resource=True)
+
+        @tf.function
+        def resourceVariablesTest(x, y):
+            return tf.multiply(x, y)
+
+        op = resourceVariablesTest(var1,var2)
+        with pytest.raises(Exception) as execinfo:
+            compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0',
+                                mode='vm', init_global_variables=True)
+        assert execinfo.value.args[0].startswith("Graph is not frozen."
+                                                 " Provide a frozen graph")
+
+def test_forward_spop():
+    _test_spop_stateful()
+    _test_spop_device_assignment()
+    _test_spop_resource_variables()
+
+    #Placeholder test cases
+    _test_spop_placeholder_without_shape_info()
+    _test_spop_placeholder_with_shape_and_default_value()
+    _test_spop_placeholder_numpy_arange_feed()
+    _test_spop_placeholder_numpy_array_feed()
+
+    #Function Invocation test cases
+    _test_spop_function_invocation_basic()
+    _test_spop_function_invocation_nested()
+    _test_spop_function_invocation_no_autograph()
+    _test_spop_function_invocation_defun()
+
+    #Test cases for various other TF constructs
+    _test_spop_arithmetic()
+    _test_spop_control_flow()
+    _test_spop_variables()
+    _test_spop_constants()
+
+
 #######################################################################
 # Main
 # ----
 if __name__ == '__main__':
+
     # Transforms
     test_forward_slice()
     test_forward_transpose()
@@ -3323,3 +3660,6 @@ if __name__ == '__main__':
 
     # Sharing params case using Mean ops
     test_sharing_node()
+
+    # StatefulPartitionedCall
+    test_forward_spop()
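
To exercise only the new SPOP coverage locally, something along these lines should work, assuming a TVM build with the TensorFlow frontend and pytest installed:

    # run just the new test entry point from the repository root
    python -m pytest tests/python/frontend/tensorflow/test_forward.py::test_forward_spop -v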