--- /dev/null
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+ node {
+ input: "input"
+ output: "1"
+ op_type: "RandomUniformLike"
+ attribute {
+ name: "high"
+ f: 0.333333343
+ type: FLOAT
+ }
+ attribute {
+ name: "low"
+ f: 0.125
+ type: FLOAT
+ }
+ }
+ node {
+ input: "input"
+ input: "1"
+ output: "2"
+ op_type: "PRelu"
+ }
+ name: "torch-jit-export"
+ input {
+ name: "input"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 4
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "2"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 4
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
x = torch.randn(1, 2, 3, 4)
self.assertONNX(lambda x: torch.randn(1, 2, 3, 4) + x, x)
+ def test_rrelu(self):
+ x = torch.randn(1, 2, 3, 4)
+ self.assertONNX(torch.nn.RReLU(), x)
+
+
if __name__ == '__main__':
no_onnx_dep_flag = '--no-onnx'
_onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS
def _parse_arg(value, desc):
+ if desc == 'none':
+ return value
if desc == 'v' or not _is_value(value):
return value
-
if value.node().kind() != 'onnx::Constant':
raise RuntimeError("ONNX symbolic expected a constant value in the trace")
tval = value.node()['value']
shapes_list = list(shapes)
shape = _maybe_get_const(shapes_list[0], "is")
return g.op('RandomNormal', shape_i=shape)
+
+
+@parse_args('v', 'f', 'f', 'i', 'none')
+def rrelu(g, input, lower, upper, training, generator):
+ p = g.op('RandomUniformLike', input, high_f=upper, low_f=lower)
+ return g.op('PRelu', input, p)