caffe2/caffe2/contrib/script (#15007)
authorJerry Zhang <jerryzh@fb.com>
Mon, 10 Dec 2018 22:17:43 +0000 (14:17 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 10 Dec 2018 22:23:31 +0000 (14:23 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14979

att

Reviewed By: dzhulgakov

Differential Revision: D13286191

fbshipit-source-id: b8a6bc7aea44487aea4dcf7f44c858fd30c6293c

15 files changed:
caffe2/contrib/CMakeLists.txt
caffe2/contrib/script/CMakeLists.txt [deleted file]
caffe2/contrib/script/caffe2_script_test.py [deleted file]
caffe2/contrib/script/compiler.cc [deleted file]
caffe2/contrib/script/compiler.h [deleted file]
caffe2/contrib/script/error_report.h [deleted file]
caffe2/contrib/script/examples/example_beam_search.c2s [deleted file]
caffe2/contrib/script/examples/example_post_eos_penalty.c2s [deleted file]
caffe2/contrib/script/examples/run_examples.py [deleted file]
caffe2/contrib/script/lexer.cc [deleted file]
caffe2/contrib/script/lexer.h [deleted file]
caffe2/contrib/script/parser.h [deleted file]
caffe2/contrib/script/tree.h [deleted file]
caffe2/contrib/script/tree_views.h [deleted file]
caffe2/python/pybind_state.cc

index be8c0bd..6034e4d 100644 (file)
@@ -4,7 +4,6 @@ add_subdirectory(nccl)
 add_subdirectory(opencl)
 add_subdirectory(prof)
 add_subdirectory(shm_mutex)
-add_subdirectory(script)
 if (USE_TENSORRT)
 add_subdirectory(tensorrt)
 endif()
diff --git a/caffe2/contrib/script/CMakeLists.txt b/caffe2/contrib/script/CMakeLists.txt
deleted file mode 100644 (file)
index fb38787..0000000
+++ /dev/null
@@ -1,16 +0,0 @@
-# ---[ CPU files.
-file(GLOB tmp *.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
-# exclude test files and gpu files
-file(GLOB tmp *_test.cc)
-exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${tmp})
-exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_GPU_SRCS})
-
-# ---[ CPU test files
-file(GLOB tmp *_test.cc)
-set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
-exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_GPU_TEST_SRCS})
-
-# ---[ Send the lists to the parent scope.
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
-set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
diff --git a/caffe2/contrib/script/caffe2_script_test.py b/caffe2/contrib/script/caffe2_script_test.py
deleted file mode 100644 (file)
index d9f0b65..0000000
+++ /dev/null
@@ -1,520 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-from hypothesis import given
-
-from caffe2.python import core, workspace
-from caffe2.proto import caffe2_pb2
-import caffe2.python.hypothesis_test_util as hu
-import hypothesis.strategies as st
-
-import numpy as np
-
-
-def feed_inputs(inputs):
-    for name, value in inputs.items():
-        workspace.FeedBlob(name, value)
-
-
-def assert_proto_equals(proto, expected):
-    proto_lines = proto.strip().split('\n')
-    expected_lines = expected.strip().split('\n')
-    assert len(proto_lines) == len(expected_lines), \
-        '{} != {}'.format(proto, expected)
-    for left, right in zip(proto_lines, expected_lines):
-        assert left.strip() == right.strip(), \
-            '{} != {}'.format(proto, expected)
-
-
-class TestCaffe2Script(hu.HypothesisTestCase):
-    test_program = """
-          def foo(a,b,X,W) -> (c):
-              t = a + b*b
-              c = FC(X,W,t)
-          def testIf(c0,c1,t,f) -> (r):
-              if c0 < c1:
-                  r = t
-              else:
-                  r = f
-              r = Add(r,3f,broadcast=1)
-          def testWhile(r) -> (r):
-              m = 0
-              while m < 4:
-                  # Plus operator automatically broadcasts, and we cannot
-                  # do in-place B and C arguments when we broadcast, so use
-                  # an explicit Add op.
-                  r = Add(r, r)
-                  m = m + 1
-      """
-
-    @given(firstdim=st.integers(min_value=1, max_value=4096),
-           seconddim=st.integers(min_value=1, max_value=4096),
-           seed=st.integers(min_value=0, max_value=65536),
-           **hu.gcs)
-    def test_foo(self, firstdim, seconddim, seed, gc, dc):
-        np.random.seed(int(seed))
-        inputs = {}
-        a = inputs['a'] = np.random.rand(seconddim).astype(np.float32)
-        b = inputs['b'] = np.random.rand(seconddim).astype(np.float32)
-        X = inputs['X'] = np.random.rand(firstdim, firstdim).astype(np.float32)
-        W = inputs['W'] = np.random.rand(seconddim, firstdim).astype(np.float32)
-
-        feed_inputs(inputs)
-
-        CU = core.C.CompilationUnit()
-        CU.define(self.test_program)
-        CU.create_net('foo').run()
-
-        ref_t = a + b * b
-        ref_c = np.matmul(X, W.transpose()) + ref_t
-        actual_c = workspace.FetchBlob('c')
-
-        np.testing.assert_allclose(actual_c, ref_c, rtol=1e-05)
-
-    def test_trinary(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def foo(c) -> (d):
-                d = 1 + (2 if c else 4)
-        """)
-        workspace.FeedBlob('c', np.ones((1), dtype=bool))
-        net = CU.create_net('foo')
-        net.run()
-        assert(3 == workspace.FetchBlob('d'))
-        workspace.FeedBlob('c', np.zeros((1), dtype=bool))
-        net.run()
-        assert(5 == workspace.FetchBlob('d'))
-
-    def test_bool_literal(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def foo() -> (a,b):
-                a = True
-                b = False
-        """)
-        net = CU.create_net('foo')
-        net.run()
-        assert(workspace.FetchBlob('a'))
-        assert(not workspace.FetchBlob('b'))
-
-    def test_bool_operators(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def foo() -> (a, b, c, d, e):
-                a = True and False
-                b = True or False
-                c = not b
-                d = not False or True
-                e = not (1 if a else 0) == (1 if b else 0)
-        """)
-        net = CU.create_net('foo')
-        net.run()
-        assert(not workspace.FetchBlob('a'))
-        assert(workspace.FetchBlob('b'))
-        assert(not workspace.FetchBlob('c'))
-        assert(workspace.FetchBlob('d'))
-        assert(workspace.FetchBlob('e'))
-
-    def expect_fail(self, fn, msg):
-        try:
-            fn()
-        except RuntimeError as r:
-            if msg not in str(r):
-                raise RuntimeError(
-                    "Failed wrong: expected string '{}' ".format(msg) +
-                    "in error message but found\n{}".format(str(r)))
-
-    def test_fails(self):
-        def fail_inputs():
-            CU = core.C.CompilationUnit()
-            CU.define("""
-                def foo() -> ():
-                    Print(1,4)
-            """)
-        self.expect_fail(fail_inputs, "expects 1 inputs but found 2")
-
-        def fail_undef():
-            CU = core.C.CompilationUnit()
-            CU.define("""
-                def foo(a) -> (b):
-                    a = what()
-            """)
-        self.expect_fail(fail_undef, "attempting to call unknown operation")
-
-        def fail_schema():
-            CU = core.C.CompilationUnit()
-            CU.define("""
-                def foo(a) -> (b):
-                    a = FC(a,a,a)
-            """)
-        self.expect_fail(fail_schema, "failed schema checking")
-
-    def test_print(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def foo() -> ():
-                a = 1
-                Print(a)
-                Print(a+1)
-                _ = 4
-                Print(_) # verify in print this isn't _ but some temorary
-                Print(1)
-                Print(1.f)
-                Print(3.0)
-        """)
-        net = CU.create_net('foo')
-        net.run()
-
-    def test_method(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def foo() -> (a):
-                a = (3+1).Add(4).Add(1)
-        """)
-        net = CU.create_net('foo')
-        net.run()
-        assert(9 == workspace.FetchBlob('a'))
-
-    def test_plus_eq(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def foo() -> (a):
-                a = 4
-                a += 1
-        """)
-        net = CU.create_net('foo')
-        net.run()
-        assert(5 == workspace.FetchBlob('a'))
-
-    def test_cast(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def foo() -> (a):
-                a = int(4.5f)
-        """)
-        net = CU.create_net('foo')
-        net.run()
-        assert(4 == workspace.FetchBlob('a'))
-
-    def test_global(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def foo() -> (a):
-                global m
-                m.a = 4
-                m.b = 5
-                a = m.a + m.b
-        """)
-        net = CU.create_net('foo')
-        net.run()
-        assert(9 == workspace.FetchBlob('a'))
-
-    def test_module_as_arg_ret(self):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-            def bar(a,c) -> (b):
-                b = Module()
-                temp = a.second
-                b.first = temp
-                b.second = a.first + c
-            def foo() -> (a,b):
-                x = Module()
-                x.first = 1
-                x.second = 2
-                x.y = bar(x,4)
-                a = x.y.first
-                b = x.y.second
-        """)
-        net = CU.create_net('foo')
-        net.run()
-        assert(2 == workspace.FetchBlob('a'))
-        assert(5 == workspace.FetchBlob('b'))
-
-    def test_call_extern(self):
-        CU = core.C.CompilationUnit()
-        net = caffe2_pb2.NetDef()
-        net.op.extend([
-            core.CreateOperator(
-                'Mul',
-                ['i', 'i'],
-                ['o'],
-            )
-        ])
-        net.external_input.append('i')
-        net.external_output.append('o')
-
-        CU.extern("myActualExtern", net)
-        CU.define("""
-            def myExtern(x) -> (y):
-                t = x
-                if t > 1:
-                    y = t * t
-                else:
-                    y = 5
-            def foo() -> (b):
-                a = 4
-                a += 1
-                b = 2 + myExtern(a) + myExtern(a, rename=False) + myActualExtern(a)
-        """)
-        net = CU.create_net('foo')
-        net.run()
-        assert(77 == workspace.FetchBlob('b'))
-
-    @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs)
-    def test_if(self, seed, gc, dc):
-        np.random.seed(int(seed))
-        inputs = {}
-        c0 = inputs['c0'] = np.random.rand(1).astype(np.float32)
-        c1 = inputs['c1'] = np.random.rand(1).astype(np.float32)
-        t = inputs['t'] = np.random.rand(3, 3).astype(np.float32)
-        f = inputs['f'] = np.random.rand(3, 3).astype(np.float32)
-
-        feed_inputs(inputs)
-
-        CU = core.C.CompilationUnit()
-        CU.define(self.test_program)
-        CU.create_net('testIf').run()
-
-        if c0 < c1:
-            ref_r = t + 3
-        else:
-            ref_r = f + 3
-        actual_r = workspace.FetchBlob('r')
-
-        np.testing.assert_allclose(actual_r, ref_r)
-
-    @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs)
-    def test_while(self, seed, gc, dc):
-        np.random.seed(int(seed))
-        inputs = {}
-        r = inputs['r'] = np.ones([3, 3]).astype(np.float32)
-
-        feed_inputs(inputs)
-
-        CU = core.C.CompilationUnit()
-        CU.define(self.test_program)
-        CU.create_net('testWhile').run()
-
-        m = 0
-        while m < 4:
-            r = r + r
-            m = m + 1
-
-        actual_r = workspace.FetchBlob('r')
-
-        np.testing.assert_allclose(actual_r, r)
-
-    @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs)
-    def test_gather(self, seed, gc, dc):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-        def easy(tensor, indices) -> (output):
-            output = tensor[indices]
-        def hard(tensor, i, j, k) -> (output):
-            output = tensor[i][j][k]
-        """)
-
-        # First check that the generated proto is as expected. This tests that
-        # we desugar the gather syntax correctly and emit the right code.
-        proto = CU.get_proto('easy')
-        assert_proto_equals(proto, """
-            name: "easy"
-            op {
-              input: "tensor"
-              input: "indices"
-              output: "output"
-              type: "Gather"
-            }""")
-
-        proto = CU.get_proto('hard')
-        assert_proto_equals(proto, """
-            name: "hard"
-            op {
-              input: "tensor"
-              input: "i"
-              output: "$t1"
-              type: "Gather"
-            }
-            op {
-              input: "$t1"
-              input: "j"
-              output: "$t0"
-              type: "Gather"
-            }
-            op {
-              input: "$t0"
-              input: "k"
-              output: "output"
-              type: "Gather"
-            }""")
-
-        # Now just test that the effect of the generated code is as expected.
-        np.random.seed(int(seed))
-        tensor = np.random.rand(5, 4, 3).astype(np.float32)
-        indices = np.random.randint(len(tensor), size=(5, 5))
-
-        feed_inputs(dict(tensor=tensor, indices=indices))
-
-        net = CU.create_net('easy')
-        net.run()
-
-        output = workspace.FetchBlob('output')
-        expected_output = [tensor[sample] for sample in indices]
-        np.testing.assert_allclose(output, expected_output)
-
-    @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs)
-    def test_slice(self, seed, gc, dc):
-        CU = core.C.CompilationUnit()
-        CU.define("""
-        def slice_from_tensor(tensor, start, end) -> (output):
-            output = tensor[start:end]
-        def slice_from_vector(vector, start, end) -> (a, b, c, d):
-            a = vector[start:end]
-            b = vector[start:]
-            c = vector[:end]
-            d = vector[:]
-        """)
-
-        # slice_from_tensor
-        proto = CU.get_proto('slice_from_tensor')
-        assert_proto_equals(proto, """
-            name: "slice_from_tensor"
-            op {
-              input: "tensor"
-              input: "start"
-              input: "end"
-              output: "output"
-              type: "Slice"
-            }""")
-
-        np.random.seed(int(seed))
-        tensor = np.random.rand(5, 4, 3).astype(np.float32)
-        start = np.array([0, 1, 0], dtype=np.int32)
-        end = np.array([-1, 2, -1], dtype=np.int32)
-
-        feed_inputs(dict(tensor=tensor, start=start, end=end))
-
-        net = CU.create_net('slice_from_tensor')
-        net.run()
-
-        output = workspace.FetchBlob('output')
-        np.testing.assert_allclose(output, tensor[:, 1:2])
-
-        # slice_from_vector
-        proto = CU.get_proto('slice_from_vector')
-        assert_proto_equals(proto, """
-            name: "slice_from_vector"
-            op {
-              input: "vector"
-              input: "start"
-              input: "end"
-              output: "a"
-              type: "Slice"
-            }
-            op {
-              output: "$t0"
-              type: "ConstantFill"
-              arg {
-                name: "dtype"
-                i: 2
-              }
-              arg {
-                name: "value"
-                i: -1
-              }
-              arg {
-                name: "shape"
-                ints: 1
-              }
-            }
-            op {
-              input: "vector"
-              input: "start"
-              input: "$t0"
-              output: "b"
-              type: "Slice"
-            }
-            op {
-              output: "$t1"
-              type: "ConstantFill"
-              arg {
-                name: "dtype"
-                i: 2
-              }
-             arg {
-                name: "value"
-                i: 0
-              }
-              arg {
-                name: "shape"
-                ints: 1
-              }
-            }
-            op {
-              input: "vector"
-              input: "$t1"
-              input: "end"
-              output: "c"
-              type: "Slice"
-            }
-            op {
-              output: "$t2"
-              type: "ConstantFill"
-              arg {
-                name: "dtype"
-                i: 2
-              }
-             arg {
-                name: "value"
-                i: 0
-              }
-              arg {
-                name: "shape"
-                ints: 1
-              }
-            }
-            op {
-              output: "$t3"
-              type: "ConstantFill"
-              arg {
-                name: "dtype"
-                i: 2
-              }
-             arg {
-                name: "value"
-                i: -1
-              }
-              arg {
-                name: "shape"
-                ints: 1
-              }
-            }
-            op {
-              input: "vector"
-              input: "$t2"
-              input: "$t3"
-              output: "d"
-              type: "Slice"
-            }""")
-
-        vector = np.random.rand(10).astype(np.float32)
-        start = np.array([2], dtype=np.int32)
-        end = np.array([6], dtype=np.int32)
-        feed_inputs(dict(vector=vector, start=start, end=end))
-
-        net = CU.create_net('slice_from_vector')
-        net.run()
-
-        output = workspace.FetchBlob('a')
-        np.testing.assert_allclose(output, vector[2:6])
-
-        output = workspace.FetchBlob('b')
-        np.testing.assert_allclose(output, vector[2:])
-
-        output = workspace.FetchBlob('c')
-        np.testing.assert_allclose(output, vector[:6])
-
-        output = workspace.FetchBlob('d')
-        np.testing.assert_allclose(output, vector)
diff --git a/caffe2/contrib/script/compiler.cc b/caffe2/contrib/script/compiler.cc
deleted file mode 100644 (file)
index 16a7657..0000000
+++ /dev/null
@@ -1,793 +0,0 @@
-#include "caffe2/core/net.h"
-#include "caffe2/utils/proto_utils.h"
-
-#include "compiler.h"
-#include "parser.h"
-
-namespace caffe2 {
-namespace script {
-
-namespace {
-
-static std::unordered_set<std::string> ops_containing_nets = {
-    "If",
-    "While",
-    "RecurrentNetwork",
-};
-// record of defined function
-// NetDef + metadata
-struct FunctionDefinition {
-  explicit FunctionDefinition(Def tree)
-      : tree(new Def(tree)), net_def(new NetDef()) {}
-
-  explicit FunctionDefinition(std::unique_ptr<NetDef> def)
-      : tree(nullptr), net_def(std::move(def)) {
-    // we coop extern_inputs/extern_outputs to be the inputs/outputs to
-    // this net as a function
-    // but we _dont_ set these when creating the net in the workspace
-    // because they require the net to have valid inputs/outputs
-    inputs.insert(
-        inputs.begin(),
-        net_def->external_input().begin(),
-        net_def->external_input().end());
-    outputs.insert(
-        outputs.begin(),
-        net_def->external_output().begin(),
-        net_def->external_output().end());
-    net_def->clear_external_output();
-    net_def->clear_external_input();
-  }
-
-  bool isExtern() const {
-    return tree == nullptr;
-  }
-  std::unique_ptr<Def> tree;
-  std::unique_ptr<NetDef> net_def;
-  std::vector<std::string> inputs;
-  std::vector<std::string> outputs;
-};
-
-} // namespace
-
-using SymbolTable = std::unordered_map<std::string, FunctionDefinition>;
-
-struct DefCompiler {
-  DefCompiler(FunctionDefinition& def, SymbolTable& symbol_table)
-      : def(def),
-        net_def_stack({def.net_def.get()}),
-        symbol_table(symbol_table) {}
-  void run() {
-    auto& tree = *def.tree;
-    cur().set_name(tree.name().name());
-    for (auto input : tree.params()) {
-      auto& name = input.ident().name();
-      map(name, name);
-      def.inputs.push_back(name);
-    }
-    for (auto output : tree.returns()) {
-      auto& name = output.ident().name();
-      map(name, name);
-      def.outputs.push_back(name);
-    }
-    emitStatements(tree.statements());
-  }
-  void emitExpressionStatement(TreeRef stmt) {
-    // expression with no used outputs
-    emit(stmt, {});
-  }
-  void emitStatements(const ListView<TreeRef>& statements) {
-    for (auto stmt : statements) {
-      switch (stmt->kind()) {
-        case TK_IF:
-          emitIf(If(stmt));
-          break;
-        case TK_WHILE:
-          emitWhile(While(stmt));
-          break;
-        case TK_ASSIGN:
-          emitAssignment(Assign(stmt));
-          break;
-        case TK_GLOBAL:
-          for (auto ident : stmt->trees()) {
-            auto name = Ident(ident).name();
-            map(name, name);
-          }
-          break;
-        default:
-          emitExpressionStatement(stmt);
-          break;
-      }
-    }
-  }
-  void map(const std::string& name, const std::string& value) {
-    env[name] = value;
-  }
-  const std::string& lookup(const Ident& ident) {
-    if (env.count(ident.name()) == 0)
-      throw ErrorReport(ident) << "undefined value " << ident.name();
-    return env[ident.name()];
-  }
-  void emitAssignment(const Assign& stmt) {
-    std::vector<std::string> outputs;
-    for (auto lhs : stmt.lhs()) {
-      std::string name = getLHS(lhs);
-      // use of "_" gets renamed in Caffe2 graphs so that two uses
-      // don't unintentionally interfere with each other
-      if (name == "_") {
-        name = fresh();
-      }
-      outputs.push_back(name);
-    }
-    if (stmt.reduction() != '=') {
-      if (stmt.lhs().size() != 1) {
-        throw ErrorReport(stmt)
-            << "reductions are only allow when there is a single variable "
-            << "on the left-hand side.";
-      }
-      auto lhs = stmt.lhs()[0];
-      auto expr =
-          Compound::create(stmt.reduction(), stmt.range(), {lhs, stmt.rhs()});
-      emit(expr, outputs);
-    } else {
-      emit(stmt.rhs(), outputs);
-    }
-    int i = 0;
-    for (auto ident : stmt.lhs()) {
-      if (ident->kind() == TK_IDENT)
-        map(Ident(ident).name(), outputs.at(i));
-      i++;
-    }
-  }
-  void emitIf(const If& stmt) {
-    auto cond = getValue(stmt.cond());
-    auto op = cur().add_op();
-    op->set_type("If");
-    op->add_input(cond);
-    auto true_branch = op->add_arg();
-    true_branch->set_name("then_net");
-    auto nd = true_branch->mutable_n();
-    net_def_stack.push_back(nd);
-    emitStatements(stmt.trueBranch());
-    net_def_stack.pop_back();
-    if (stmt.falseBranch().size() > 0) {
-      auto false_branch = op->add_arg();
-      false_branch->set_name("else_net");
-      auto nd = false_branch->mutable_n();
-      net_def_stack.push_back(nd);
-      emitStatements(stmt.falseBranch());
-      net_def_stack.pop_back();
-    }
-  }
-  void emitWhile(const While& stmt) {
-    std::string loop_var = fresh();
-    emitConst(0, loop_var, "i"); // it needs a definition before loop
-    auto op = cur().add_op();
-    op->set_type("While");
-    auto cond = op->add_arg();
-    cond->set_name("cond_net");
-    auto cond_net = cond->mutable_n();
-
-    net_def_stack.push_back(cond_net);
-    emit(stmt.cond(), {loop_var});
-    net_def_stack.pop_back();
-
-    op->add_input(loop_var);
-    auto body = op->add_arg();
-    body->set_name("loop_net");
-    auto body_net = body->mutable_n();
-
-    net_def_stack.push_back(body_net);
-    emitStatements(stmt.body());
-    net_def_stack.pop_back();
-  }
-  std::string getLHS(const TreeRef& tree) {
-    switch (tree->kind()) {
-      case TK_IDENT: {
-        return Ident(tree).name();
-      } break;
-      case '.': {
-        auto sel = Select(tree);
-        std::string lhs = getValue(sel.value());
-        // TODO: check whether this subname exists in object lhs
-        return lhs + "/" + sel.selector().name();
-      } break;
-      default: {
-        throw ErrorReport(tree)
-            << "This expression cannot appear on the left-hand size of an assignment";
-      } break;
-    }
-  }
-  std::string getValue(const TreeRef& tree) {
-    switch (tree->kind()) {
-      case TK_IDENT: {
-        return lookup(Ident(tree));
-      } break;
-      case '.': {
-        auto sel = Select(tree);
-        std::string lhs = getValue(sel.value());
-        // TODO: check whether this subname exists in object lhs
-        return lhs + "/" + sel.selector().name();
-      } break;
-      default: {
-        std::string name = fresh();
-        emit(tree, {name});
-        return name;
-      } break;
-    }
-  }
-  std::string fresh(std::string prefix = "$t") {
-    return std::string(prefix) + c10::to_string(next_fresh++);
-  }
-  const char* operatorName(int kind, int ninputs) {
-    switch (kind) {
-      case '+':
-        return "Add";
-      case '-':
-        if (ninputs == 1)
-          return "Negative";
-        else
-          return "Sub";
-      case '*':
-        return "Mul";
-      case '/':
-        return "Div";
-      case TK_NE:
-        return "NE";
-      case TK_EQ:
-        return "EQ";
-      case '<':
-        return "LT";
-      case '>':
-        return "GT";
-      case TK_LE:
-        return "LE";
-      case TK_GE:
-        return "GE";
-      case TK_IF_EXPR:
-        return "Conditional";
-      case TK_AND:
-        return "And";
-      case TK_OR:
-        return "Or";
-      case TK_NOT:
-        return "Not";
-      default:
-        throw std::runtime_error("unknown kind " + c10::to_string(kind));
-    }
-  }
-  void fillArg(Argument* arg, const Attribute& attr) {
-    std::string name = attr.name().name();
-    arg->set_name(name);
-    auto value = attr.value();
-    // TODO: handle non-float attributes
-    switch (value->kind()) {
-      case TK_CONST: {
-        auto v = value->tree(0)->doubleValue();
-        auto f = value->tree(1)->stringValue();
-        if (f == "f")
-          arg->set_f(v);
-        else
-          arg->set_i(v);
-      } break;
-      case TK_LIST:
-        for (auto t : value->trees()) {
-          auto v = t->tree(0)->doubleValue();
-          auto f = t->tree(1)->stringValue();
-          if (f == "f")
-            arg->add_floats(v);
-          else
-            arg->add_ints(v);
-        }
-        break;
-    }
-  }
-  template <typename Trees>
-  std::vector<std::string> getValues(const Trees& trees) {
-    std::vector<std::string> result;
-    for (const auto& tree : trees) {
-      result.push_back(getValue(tree));
-    }
-    return result;
-  }
-
-  bool renameLookup(
-      std::unordered_map<std::string, std::string>& rename_map,
-      const std::string& name,
-      std::string& rename) {
-    // first look for name in the map directly
-    auto it = rename_map.find(name);
-    if (it != rename_map.end()) {
-      rename = it->second;
-      return true;
-    }
-    // otherwise if we have a rename entry like a => b and a name "a/foo/bar"
-    // then replace it with "b/foo/bar"
-    auto p = name.find("/");
-    if (p == std::string::npos)
-      return false;
-    it = rename_map.find(name.substr(0, p));
-    if (it != rename_map.end()) {
-      rename = it->second + name.substr(p);
-      return true;
-    }
-    return false;
-  }
-  void renameOp(
-      std::unordered_map<std::string, std::string>& rename_map,
-      const Apply& apply,
-      const std::string& prefix,
-      bool isExtern,
-      OperatorDef* new_op) {
-    for (size_t i = 0; i < new_op->input().size(); i++) {
-      auto& name = new_op->input(i);
-      std::string renamed;
-      bool defined = renameLookup(rename_map, name, renamed);
-      if (!isExtern && !defined) {
-        throw ErrorReport(apply)
-            << " unexpected undefined name '" << name
-            << "' while attempting to inline '" << apply.name().name() << "'";
-      } else if (!defined) {
-        // extern function using a global name, assign it an identity mapping
-        rename_map[name] = name;
-      }
-      new_op->set_input(i, renamed);
-    }
-    for (size_t i = 0; i < new_op->output().size(); i++) {
-      auto& name = new_op->output(i);
-      std::string renamed;
-      if (!renameLookup(rename_map, name, renamed)) {
-        renamed = prefix + name;
-        rename_map[name] = renamed;
-      }
-      new_op->set_output(i, renamed);
-    }
-    // handle control flow inside the op as well
-    if (ops_containing_nets.count(new_op->type()) > 0) {
-      for (size_t i = 0; i < new_op->arg_size(); i++) {
-        auto* arg = new_op->mutable_arg(i);
-        if (arg->has_n()) {
-          auto* n = arg->mutable_n();
-          for (size_t j = 0; j < n->op_size(); j++) {
-            renameOp(rename_map, apply, prefix, isExtern, n->mutable_op(j));
-          }
-        }
-      }
-    }
-  }
-
-  bool hasBypassRename(const Apply& apply) {
-    for (auto attr : apply.attributes()) {
-      if (attr.name().name() == "rename") {
-        if (attr.value()->kind() != TK_CONST) {
-          throw ErrorReport(attr.value()) << "expected a single constant";
-        }
-        return attr.value()->tree(0)->doubleValue() == 0;
-      }
-    }
-    return false;
-  }
-
-  // emit a function call by inlining the function's NetDef into our
-  // net def, renaming temporaries func_name<unique_id>/orig_name
-  // renaming only happens for values defined by the function
-  // that are not marked outputs
-
-  // inputs/outputs are passed by reference
-  void emitFunctionCall(Apply& apply, const std::vector<std::string>& outputs) {
-    std::string fname = apply.name().name();
-    std::string prefix = fresh(fname) + "/";
-    auto& fn = symbol_table.at(apply.name().name());
-    bool isExtern = fn.isExtern();
-    auto inputs = getValues(apply.inputs());
-    std::unordered_map<std::string, std::string> rename_map;
-    if (inputs.size() != fn.inputs.size()) {
-      throw ErrorReport(apply) << fname << " expected " << fn.inputs.size()
-                               << " values but received " << inputs.size();
-    }
-    for (size_t i = 0; i < inputs.size(); i++) {
-      rename_map[fn.inputs[i]] = inputs[i];
-    }
-    if (outputs.size() != fn.outputs.size()) {
-      throw ErrorReport(apply) << fname << " expected " << fn.outputs.size()
-                               << " values but received " << outputs.size();
-    }
-    for (size_t i = 0; i < outputs.size(); i++) {
-      rename_map[fn.outputs[i]] = outputs[i];
-    }
-    for (auto& op : fn.net_def->op()) {
-      auto new_op = cur().add_op();
-      new_op->CopyFrom(op);
-      if (hasBypassRename(apply)) {
-        prefix = "";
-      }
-      renameOp(rename_map, apply, prefix, isExtern, new_op);
-    }
-  }
-  void expectOutputs(
-      const TreeRef& tree,
-      const std::vector<std::string>& outputs,
-      size_t size) {
-    if (outputs.size() != size) {
-      throw ErrorReport(tree)
-          << "expected operator to produce " << outputs.size()
-          << " outputs but it produced " << size;
-    }
-  }
-  void appendOutputs(
-      const TreeRef& tree,
-      OperatorDef* op,
-      const std::vector<std::string>& outputs,
-      size_t size) {
-    expectOutputs(tree, outputs, size);
-    for (size_t i = 0; i < size; i++) {
-      op->add_output(outputs[i]);
-    }
-  }
-  void emitOperator(
-      const Apply& apply,
-      const OpSchema* schema,
-      const std::vector<std::string>& outputs) {
-    // must be before add_op
-    auto values = getValues(apply.inputs());
-    if (values.size() < schema->min_input() ||
-        values.size() > schema->max_input()) {
-      if (schema->min_input() == schema->max_input()) {
-        throw ErrorReport(apply) << "operator expects " << schema->min_input()
-                                 << " inputs but found " << values.size();
-      } else {
-        throw ErrorReport(apply)
-            << "operator takes between " << schema->min_input() << " and "
-            << schema->max_input() << " inputs but found " << values.size()
-            << ".";
-      }
-    }
-    auto numActualOutputs = schema->CalculateOutput(values.size());
-    if (numActualOutputs != kCannotComputeNumOutputs &&
-        outputs.size() != numActualOutputs) {
-      throw ErrorReport(apply)
-          << "operator produces " << numActualOutputs
-          << " outputs but matched to " << outputs.size() << " outputs";
-    }
-    auto op = cur().add_op();
-    op->set_type(apply.name().name());
-    for (auto& v : values) {
-      op->add_input(v);
-    }
-    // assume 1 output unless matched to more
-    appendOutputs(apply, op, outputs, outputs.size());
-    for (auto attribute : apply.attributes()) {
-      fillArg(op->add_arg(), attribute);
-    }
-    // Ok, we checked the stuff where we can easily give a friendly error
-    // message, now verify against the schema and report the error at the line
-    if (!schema->Verify(*op)) {
-      throw ErrorReport(apply) << "failed schema checking";
-    }
-  }
-
-  // Emit an operation, writing results into 'outputs'.
-  // This will _always_ compute something, unlike 'getValue' which simply
-  // returns an already computed reference if possible.
-  // So if 'tree' is an identifier or nested identifier (foo.bar)
-  // this will cause it to be _copied_ into outputs.
-  void emit(const TreeRef& tree, const std::vector<std::string>& outputs) {
-    switch (tree->kind()) {
-      case TK_IDENT:
-      case '.': {
-        auto op = cur().add_op();
-        op->set_type("Copy");
-        op->add_input(getValue(tree));
-        appendOutputs(tree, op, outputs, 1);
-      } break;
-      case TK_NE:
-      case TK_EQ:
-      case '<':
-      case '>':
-      case TK_LE:
-      case TK_GE:
-      case '-':
-      case '*':
-      case '/':
-      case '+':
-      case TK_AND:
-      case TK_OR:
-      case TK_NOT:
-      case TK_IF_EXPR: {
-        // must be before add_op
-        auto values = getValues(tree->trees());
-        auto op = cur().add_op();
-        op->set_type(operatorName(tree->kind(), tree->trees().size()));
-        for (auto& v : values) {
-          op->add_input(v);
-        }
-        appendOutputs(tree, op, outputs, 1);
-        auto broadcast = op->add_arg();
-        broadcast->set_name("broadcast");
-        broadcast->set_i(1);
-      } break;
-      case TK_APPLY: {
-        auto apply = Apply(tree);
-        // Handle built-ins like zeros, ones, etc
-        if (builtins.count(apply.name().name()) > 0) {
-          builtins[apply.name().name()](this, apply, outputs);
-          break;
-        }
-        if (symbol_table.count(apply.name().name()) > 0) {
-          emitFunctionCall(apply, outputs);
-          break;
-        }
-        auto schema = OpSchemaRegistry::Schema(apply.name().name());
-        if (schema) {
-          emitOperator(apply, schema, outputs);
-          break;
-        }
-        throw ErrorReport(apply)
-            << "attempting to call unknown operation or function '"
-            << apply.name().name() << "'";
-      } break;
-      case TK_CAST: {
-        auto cast = Cast(tree);
-        auto c2type = getType(cast.type());
-        auto input = getValue(cast.input());
-        auto op = cur().add_op();
-        op->set_type("Cast");
-        op->add_input(input);
-        appendOutputs(tree, op, outputs, 1);
-        auto arg = op->add_arg();
-        arg->set_name("to");
-        arg->set_i(c2type);
-      } break;
-      case TK_CONST: {
-        expectOutputs(tree, outputs, 1);
-        emitConst(
-            tree->tree(0)->doubleValue(),
-            outputs[0],
-            tree->tree(1)->stringValue());
-      } break;
-      case TK_GATHER: {
-        const auto gather = Gather(tree);
-        desugarAndEmitOperator(
-            "Gather",
-            gather.range(),
-            {gather.value(), gather.indices()},
-            outputs);
-        break;
-      }
-      case TK_SLICE: {
-        const auto slice = Slice(tree);
-        desugarAndEmitOperator(
-            "Slice",
-            slice.range(),
-            {slice.value(), slice.startOr(0), slice.endOr(-1)},
-            outputs);
-        break;
-      }
-      default:
-        throw ErrorReport(tree) << "NYI: " << tree;
-        break;
-    }
-  }
-
-  // Desugars constructs that are syntactic sugar and emits the corresponding
-  // operator invocation, e.g. tensor[indices] -> tensor.Gather(indices).
-  void desugarAndEmitOperator(
-      const std::string& operatorName,
-      const SourceRange& range,
-      TreeList&& inputs,
-      const std::vector<std::string>& outputs) {
-    const auto applyName = Ident::create(range, operatorName);
-    const auto applyInputs =
-        Compound::create(TK_LIST, range, std::move(inputs));
-    const auto applyAttributes = Compound::create(TK_LIST, range, {});
-    const auto apply =
-        Apply::create(range, applyName, applyInputs, applyAttributes);
-    const auto schema = OpSchemaRegistry::Schema(operatorName);
-    assert(schema != nullptr);
-    emitOperator(Apply(apply), schema, outputs);
-  }
-
-  TensorProto_DataType getType(int type) {
-    switch (type) {
-      case TK_INT:
-        return TensorProto_DataType_INT32;
-      case TK_FLOAT:
-        return TensorProto_DataType_FLOAT;
-      case TK_LONG:
-        return TensorProto_DataType_INT64;
-      case TK_BOOL:
-        return TensorProto_DataType_BOOL;
-      default:
-        throw std::runtime_error(
-            "expected type token: " + c10::to_string(type));
-    }
-  }
-
-  OperatorDef* emitConst(
-      double v,
-      const std::string& output,
-      const std::string& type_ident) {
-    auto op = cur().add_op();
-    op->set_type("ConstantFill");
-    auto dtype = op->add_arg();
-    dtype->set_name("dtype");
-    auto value = op->add_arg();
-    value->set_name("value");
-    if (type_ident == "f") {
-      dtype->set_i(TensorProto_DataType_FLOAT);
-      value->set_f(v);
-    } else if (type_ident == "LL") {
-      dtype->set_i(TensorProto_DataType_INT64);
-      value->set_i(v);
-    } else if (type_ident == "b") {
-      dtype->set_i(TensorProto_DataType_BOOL);
-      value->set_i(v != 0);
-    } else if (type_ident == "i") {
-      dtype->set_i(TensorProto_DataType_INT32);
-      value->set_i(v);
-    } else {
-      throw std::runtime_error("unknown type_ident " + type_ident);
-    }
-    auto shape = op->add_arg();
-    shape->set_name("shape");
-    shape->add_ints(1);
-    op->add_output(output);
-    return op;
-  }
-  NetDef& cur() {
-    return *net_def_stack.back();
-  }
-  FunctionDefinition& def; // the def being constructed
-  std::unordered_map<std::string, std::string>
-      env; // map from name in Def to name in NetDef
-  std::vector<NetDef*> net_def_stack;
-  SymbolTable& symbol_table;
-  int next_fresh = 0;
-
- private:
-  void emitFillOp(const Apply& apply, const std::vector<std::string>& outputs) {
-    auto builtin_type = apply.name().name();
-    auto values = getValues(apply.inputs());
-    if (values.size() > 1) {
-      throw ErrorReport(apply)
-          << "Built-in " << builtin_type << " accepts 0 or 1 inputs.";
-    }
-    bool has_shape = false;
-    for (const auto& attribute : apply.attributes()) {
-      if (attribute.name().name() == "shape") {
-        has_shape = true;
-      } else {
-        throw ErrorReport(apply)
-            << "Unrecognized attribute " << attribute.name().name()
-            << " for built-in " << builtin_type;
-      }
-    }
-    if (builtin_type == "zeros" || builtin_type == "ones") {
-      if ((values.size() != 1) && !has_shape) {
-        throw ErrorReport(apply)
-            << "Built-in " << builtin_type
-            << " requires either 1 input or 1 shape attribute";
-      }
-    } else {
-      // zeros_like or ones_like
-      if (values.size() != 1) {
-        throw ErrorReport(apply)
-            << "Built-in " << builtin_type << " requires 1 input";
-      }
-    }
-
-    auto op = cur().add_op();
-    op->set_type("ConstantFill");
-    if (values.size()) {
-      op->add_input(values[0]);
-      auto* input_as_shape = op->add_arg();
-      input_as_shape->set_name("input_as_shape");
-      if (builtin_type.find("_like") != std::string::npos) {
-        // zeros_like, ones_like take the shape of the input as constant
-        // tensor shape
-        input_as_shape->set_i(0);
-      } else {
-        // zeros, ones take the values in the tensor as constant tensor
-        // shape
-        input_as_shape->set_i(1);
-      }
-    } else {
-      fillArg(op->add_arg(), apply.attributes()[0]);
-    }
-
-    auto value = op->add_arg();
-    value->set_name("value");
-    if (builtin_type.find("ones") != std::string::npos) {
-      value->set_f(1.0f);
-    } else {
-      value->set_f(0.0f);
-    }
-    appendOutputs(apply, op, outputs, 1);
-  }
-  // emitModule doesn't actually do anything except for allow
-  // statements like a = Module() to register 'a' as a valid identifier
-  // so that a.b = ... will work
-  void emitModule(const Apply& apply, const std::vector<std::string>& outputs) {
-    expectOutputs(apply, outputs, 1);
-  }
-  std::unordered_map<
-      std::string,
-      std::function<void(
-          DefCompiler*,
-          const Apply&,
-          const std::vector<std::string>& outputs)>>
-      builtins{{"zeros", &DefCompiler::emitFillOp},
-               {"zeros_like", &DefCompiler::emitFillOp},
-               {"ones", &DefCompiler::emitFillOp},
-               {"ones_like", &DefCompiler::emitFillOp},
-               {"Module", &DefCompiler::emitModule}};
-};
-
-struct CompilationUnitImpl {
-  void defineFunction(const Def& def) {
-    if (functions.count(def.name().name()) > 0) {
-      throw ErrorReport(def) << def.name().name() << " already defined.";
-    }
-    DefCompiler c(
-        functions.emplace(def.name().name(), FunctionDefinition(def))
-            .first->second,
-        functions);
-    c.run();
-  }
-
-  void define(const std::string& str) {
-    Parser p(str);
-    while (p.lexer().cur().kind != TK_EOF) {
-      defineFunction(Def(p.parseFunction()));
-    }
-  }
-
-  std::unique_ptr<NetBase> createNet(Workspace* ws, const std::string& str) {
-    if (functions.count(str) == 0)
-      throw ErrorReport() << "undefined function: " << str << "\n";
-    auto& def = functions.at(str);
-    return caffe2::CreateNet(*def.net_def, ws);
-  }
-
-  void defineExtern(const std::string& name, std::unique_ptr<NetDef> net_def) {
-    // TODO: unify extern and function namespaces
-    if (functions.count(name) > 0) {
-      throw ErrorReport() << "function '" << name << "' already defined.";
-    }
-    functions.emplace(name, FunctionDefinition(std::move(net_def)));
-  }
-
-  std::string getProto(const std::string& functionName) {
-    return functions.at(functionName).net_def->DebugString();
-  }
-
- private:
-  friend struct DefCompiler;
-  SymbolTable functions;
-};
-
-CompilationUnit::CompilationUnit() : pImpl(new CompilationUnitImpl()) {}
-
-void CompilationUnit::define(const std::string& str) {
-  return pImpl->define(str);
-}
-
-void CompilationUnit::defineExtern(
-    const std::string& name,
-    std::unique_ptr<NetDef> nd) {
-  pImpl->defineExtern(name, std::move(nd));
-}
-
-std::unique_ptr<NetBase> CompilationUnit::createNet(
-    Workspace* ws,
-    const std::string& str) {
-  return pImpl->createNet(ws, str);
-}
-
-std::string CompilationUnit::getProto(const std::string& functionName) const {
-  return pImpl->getProto(functionName);
-}
-
-CompilationUnit::~CompilationUnit() {}
-
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/compiler.h b/caffe2/contrib/script/compiler.h
deleted file mode 100644 (file)
index 0a15c33..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-#pragma once
-#include <memory>
-#include <string>
-#include "caffe2/core/net.h"
-
-namespace caffe2 {
-namespace script {
-
-struct CompilationUnitImpl;
-
-struct CAFFE2_API CompilationUnit {
-  CompilationUnit();
-  void define(const std::string& str);
-  void defineExtern(const std::string& str, std::unique_ptr<NetDef> netdef);
-  std::unique_ptr<NetBase> createNet(Workspace* ws, const std::string& name);
-  std::string getProto(const std::string& functionName) const;
-  ~CompilationUnit();
-
- private:
-  std::unique_ptr<CompilationUnitImpl> pImpl;
-};
-
-} // namespace script
-}; // namespace caffe2
diff --git a/caffe2/contrib/script/error_report.h b/caffe2/contrib/script/error_report.h
deleted file mode 100644 (file)
index cecc0f3..0000000
+++ /dev/null
@@ -1,51 +0,0 @@
-#pragma once
-
-#include "caffe2/contrib/script/tree.h"
-
-namespace caffe2 {
-namespace script {
-
-struct ErrorReport : public std::exception {
-  ErrorReport(const ErrorReport& e)
-      : ss(e.ss.str()), context(e.context), the_message(e.the_message) {}
-
-  ErrorReport() : context(nullptr) {}
-  explicit ErrorReport(const SourceRange& r)
-      : context(std::make_shared<SourceRange>(r)) {}
-  explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {}
-  explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {}
-  virtual const char* what() const noexcept override {
-    std::stringstream msg;
-    msg << "\n" << ss.str();
-    if (context != nullptr) {
-      msg << ":\n";
-      context->highlight(msg);
-    } else {
-      msg << ".\n";
-    }
-    the_message = msg.str();
-    return the_message.c_str();
-  }
-
- private:
-  template <typename T>
-  friend const ErrorReport& operator<<(const ErrorReport& e, const T& t);
-
-  mutable std::stringstream ss;
-  std::shared_ptr<SourceRange> context;
-  mutable std::string the_message;
-};
-
-template <typename T>
-const ErrorReport& operator<<(const ErrorReport& e, const T& t) {
-  e.ss << t;
-  return e;
-}
-
-#define C2S_ASSERT(ctx, cond)                                              \
-  if (!(cond)) {                                                           \
-    throw ::caffe2::script::ErrorReport(ctx)                               \
-        << __FILE__ << ":" << __LINE__ << ": assertion failed: " << #cond; \
-  }
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/examples/example_beam_search.c2s b/caffe2/contrib/script/examples/example_beam_search.c2s
deleted file mode 100644 (file)
index 2e081ee..0000000
+++ /dev/null
@@ -1,76 +0,0 @@
-[["log_probs", [6, 1, 44463], "float32"], ["attentions", [6, 1, 21], "float32"], ["inputs", [21], "float32"]]
-beam_search
-["scores_t"]
-
-def beam_search(inputs, log_probs, attentions) -> ():
-    beam_size = 6LL
-    length = 20LL
-    beam_output_shape, _ = Concat(length + 1LL, beam_size, axis=0)
-    output_token_beam_list = int(zeros(beam_output_shape))
-    output_prev_index_beam_list = int(zeros(beam_output_shape))
-    output_score_beam_list = zeros(beam_output_shape)
-
-    input_length = inputs.Size().ExpandDims(dims=[0])
-
-    attention_beam_output_shape, _ = Concat(
-        input_length, beam_output_shape, axis=0)
-    output_attention_weights_beam_list = zeros(attention_beam_output_shape)
-
-    attention_step_output_shape, _ = Concat(beam_size, input_length, axis=0)
-    attention_t = zeros(attention_step_output_shape)
-
-    scores_t = zeros(shape=[1, 6])
-    hypo_t = int(zeros(shape=[6]))
-    tokens_t = int(ones(shape=[6])) * 99
-
-    output_token_beam_list = output_token_beam_list.ScatterAssign(0, tokens_t)
-    output_token_beam_list = output_token_beam_list.ExpandDims(dims=[2])
-    output_prev_index_beam_list = output_prev_index_beam_list.ScatterAssign(
-        0, hypo_t)
-    output_prev_index_beam_list = output_prev_index_beam_list.ExpandDims(dims=[2])
-    output_score_beam_list = output_score_beam_list.ScatterAssign(0, scores_t)
-    output_score_beam_list = output_score_beam_list.ExpandDims(dims=[2])
-    output_attention_weights_beam_list = output_attention_weights_beam_list\
-        .ScatterAssign(0, attention_t)
-
-    length_32 = int(length)
-
-    timestep = 0
-    not_finished = True
-    while not_finished:
-        # TODO: once we have a metaprogramming facility we need to insert the
-        # body of the post_eos_penalty here programmatically
-
-        best_scores_per_hypo, best_tokens_per_hypo = log_probs.TopK(k=6)
-
-        # Add the best score in each hypothesis to the cumulative score so far
-        output_scores = best_scores_per_hypo + scores_t.Squeeze(dims=[0])
-
-        # Flatten scores so we can find the best overall out of all hypotheses
-        output_scores_flattened_slice, _ = output_scores.FlattenToVec()\
-            .Slice(0, 6 if timestep == 0 else -1).Reshape(shape=[1, -1])
-
-        # Find top K out of all
-        scores_t, best_indices = output_scores_flattened_slice.TopK(k=6)
-
-        # Integer floor divide on indices finds the association back to original
-        #  hypotheses. Use this to reorder states
-        hypo_t_int64 = best_indices / 6LL
-
-        # Reorder attentions
-        attention_t, _ = attentions.Gather(hypo_t_int64)\
-            .Reshape(shape=[1, 6, -1])
-        tokens_t_int64 = best_tokens_per_hypo.FlattenToVec()\
-            .Gather(best_indices).Cast(to=2)
-
-        timestep += 1
-        not_finished = timestep < length_32
-
-        output_token_beam_list = output_token_beam_list\
-            .ScatterAssign(timestep, tokens_t)
-        output_prev_index_beam_list = output_prev_index_beam_list\
-            .ScatterAssign(timestep, hypo_t)
-        output_score_beam_list = output_score_beam_list\
-            .ScatterAssign(timestep, scores_t)
-        output_attention_weights_beam_list = output_attention_weights_beam_list\
-            .ScatterAssign(timestep, attention_t)
diff --git a/caffe2/contrib/script/examples/example_post_eos_penalty.c2s b/caffe2/contrib/script/examples/example_post_eos_penalty.c2s
deleted file mode 100644 (file)
index 9988913..0000000
+++ /dev/null
@@ -1,13 +0,0 @@
-[["tokens_t", [1, 6], "int32"], ["hypo_t", [1, 6], "int32"], ["log_probs", [6, 1, 44463], "float32"], ["on_initial_step", [1], "bool_"]]
-post_eos_penalty
-["log_probs"]
-
-def post_eos_penalty(tokens_t, hypo_t, log_probs, on_initial_step) \
-  -> (log_probs):
-  eos_token = 1
-  finished_penalty = 0f if on_initial_step else 0.5f
-  predecessor_tokens = tokens_t.FlattenToVec().Gather(hypo_t.FlattenToVec())
-  predecessor_is_eos = float(predecessor_tokens == eos_token)
-  log_probs = log_probs.Add(
-      predecessor_is_eos * finished_penalty, broadcast=1, axis=0
-  )
diff --git a/caffe2/contrib/script/examples/run_examples.py b/caffe2/contrib/script/examples/run_examples.py
deleted file mode 100644 (file)
index 26f2db0..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-from caffe2.python import core, workspace
-import glob
-import json
-import numpy as np
-
-example_files = glob.glob('example_*.c2s')
-
-for ex in example_files:
-    print('Running example file', ex)
-    with open(ex, 'r') as f:
-        inits = json.loads(f.readline())
-        net_name = f.readline().strip()
-        outputs = json.loads(f.readline())
-
-        CU = core.C.CompilationUnit()
-        CU.define(f.read())
-
-    # Initialize workspace with required inputs
-    for name, shape, dt in inits:
-        workspace.FeedBlob(name, np.random.rand(*shape).astype(np.dtype(dt)))
-
-    net = CU.create_net(net_name)
-    net.run()
-
-    print('Success! Interesting outputs:')
-    for output in outputs:
-        print(output, workspace.FetchBlob(output))
diff --git a/caffe2/contrib/script/lexer.cc b/caffe2/contrib/script/lexer.cc
deleted file mode 100644 (file)
index 9dafea9..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-#include "caffe2/contrib/script/lexer.h"
-#include "caffe2/core/common.h"
-
-namespace caffe2 {
-namespace script {
-
-std::string kindToString(int kind) {
-  if (kind < 256)
-    return std::string(1, kind);
-  switch (kind) {
-#define DEFINE_CASE(tok, str, _) \
-  case tok:                      \
-    return str;
-    TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
-#undef DEFINE_CASE
-    default:
-      throw std::runtime_error("unknown kind: " + c10::to_string(kind));
-  }
-}
-
-SharedParserData& sharedParserData() {
-  static SharedParserData data; // safely handles multi-threaded init
-  return data;
-}
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/lexer.h b/caffe2/contrib/script/lexer.h
deleted file mode 100644 (file)
index b298809..0000000
+++ /dev/null
@@ -1,527 +0,0 @@
-#pragma once
-#include <assert.h>
-#include <algorithm>
-#include <iostream>
-#include <memory>
-#include <sstream>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "caffe2/core/common.h"
-
-namespace caffe2 {
-namespace script {
-
-// single character tokens are just the character itself '+'
-// multi-character tokens need an entry here
-// if the third entry is not the empty string, it is used
-// in the lexer to match this token.
-
-// These kinds are also used in Tree.h as the kind of the AST node.
-// Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the
-// lexer.
-
-#define TC_FORALL_TOKEN_KINDS(_)                 \
-  _(TK_EOF, "eof", "")                           \
-  _(TK_WHITESPACE, "whitespace", "")             \
-  _(TK_NUMBER, "number", "")                     \
-  _(TK_NEWLINE, "newline", "")                   \
-  _(TK_INDENT, "indent", "")                     \
-  _(TK_DEDENT, "dedent", "")                     \
-  _(TK_WHERE, "where", "where")                  \
-  _(TK_FLOAT, "float", "float")                  \
-  _(TK_DOUBLE, "double", "double")               \
-  _(TK_LONG, "long", "long")                     \
-  _(TK_INT, "int", "int")                        \
-  _(TK_DEF, "def", "def")                        \
-  _(TK_ARROW, "arrow", "->")                     \
-  _(TK_EQUIVALENT, "equivalent", "<=>")          \
-  _(TK_IDENT, "ident", "")                       \
-  _(TK_STRING, "string", "")                     \
-  _(TK_CONST, "const", "")                       \
-  _(TK_LIST, "list", "")                         \
-  _(TK_OPTION, "option", "")                     \
-  _(TK_APPLY, "apply", "")                       \
-  _(TK_COMPREHENSION, "comprehension", "")       \
-  _(TK_TENSOR_TYPE, "tensor_type", "")           \
-  _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
-  _(TK_PARAM, "param", "")                       \
-  _(TK_INFERRED, "inferred", "")                 \
-  _(TK_BOOL, "bool", "")                         \
-  _(TK_ACCESS, "access", "")                     \
-  _(TK_ASSIGN, "assign", "")                     \
-  _(TK_ATTRIBUTE, "attribute", "")               \
-  _(TK_IF, "if", "if")                           \
-  _(TK_ELSE, "else", "else")                     \
-  _(TK_ELIF, "elif", "elif")                     \
-  _(TK_WHILE, "while", "while")                  \
-  _(TK_NE, "ne", "!=")                           \
-  _(TK_EQ, "eq", "==")                           \
-  _(TK_LE, "le", "<=")                           \
-  _(TK_GE, "ge", ">=")                           \
-  _(TK_IF_EXPR, "if", "")                        \
-  _(TK_TRUE, "True", "True")                     \
-  _(TK_FALSE, "False", "False")                  \
-  _(TK_AND, "and", "and")                        \
-  _(TK_OR, "or", "or")                           \
-  _(TK_NOT, "not", "not")                        \
-  _(TK_CAST, "cast", "")                         \
-  _(TK_PLUS_EQ, "+=", "+=")                      \
-  _(TK_MINUS_EQ, "-=", "-=")                     \
-  _(TK_TIMES_EQ, "*=", "*=")                     \
-  _(TK_DIV_EQ, "/=", "/=")                       \
-  _(TK_GLOBAL, "global", "global")               \
-  _(TK_BUILT_IN, "built-in", "")                 \
-  _(TK_SLICE, "slice", "")                       \
-  _(TK_GATHER, "gather", "")
-static const char* valid_single_char_tokens = "+-*/()[]:,={}><.";
-
-enum TokenKind {
-  // we use characters to represent themselves so skip all valid characters
-  // before
-  // assigning enum values to multi-char tokens.
-  TK_DUMMY_START = 256,
-#define DEFINE_TOKEN(tok, _, _2) tok,
-  TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN)
-#undef DEFINE_TOKEN
-};
-
-std::string kindToString(int kind);
-
-// nested hash tables that indicate char-by-char what is a valid token.
-struct TokenTrie;
-using TokenTrieRef = std::unique_ptr<TokenTrie>;
-struct TokenTrie {
-  TokenTrie() : kind(0) {}
-  void insert(const char* str, int tok) {
-    if (*str == '\0') {
-      assert(kind == 0);
-      kind = tok;
-      return;
-    }
-    auto& entry = children[*str];
-    if (entry == nullptr) {
-      entry.reset(new TokenTrie());
-    }
-    entry->insert(str + 1, tok);
-  }
-  int kind; // 0 == invalid token
-  std::unordered_map<char, TokenTrieRef> children;
-};
-
-// stuff that is shared against all TC lexers/parsers and is initialized only
-// once.
-struct SharedParserData {
-  SharedParserData() : head(new TokenTrie()) {
-    // listed in increasing order of precedence
-    std::vector<std::vector<int>> binary_ops = {
-        {TK_IF},
-        {TK_AND, TK_OR},
-        {}, // reserve a level for unary not
-        {'<', '>', TK_EQ, TK_LE, TK_GE, TK_NE},
-        {'+', '-'},
-        {'*', '/'},
-    };
-    std::vector<std::vector<int>> unary_ops = {
-        {'-'},
-    };
-
-    std::stringstream ss;
-    for (const char* c = valid_single_char_tokens; *c; c++) {
-      const char str[] = {*c, '\0'};
-      head->insert(str, *c);
-    }
-
-#define ADD_CASE(tok, _, tokstring) \
-  if (*tokstring != '\0') {         \
-    head->insert(tokstring, tok);   \
-  }
-    TC_FORALL_TOKEN_KINDS(ADD_CASE)
-#undef ADD_CASE
-
-    // precedence starts at 1 so that there is always a 0 precedence
-    // less than any other precedence
-    int prec = 1;
-    for (auto& group : binary_ops) {
-      for (auto& element : group) {
-        binary_prec[element] = prec;
-      }
-      prec++;
-    }
-    // unary ops
-    for (auto& group : unary_ops) {
-      for (auto& element : group) {
-        unary_prec[element] = prec;
-      }
-      prec++;
-    }
-    // add unary not separately because it slots into the precedence of
-    // binary operators
-    unary_prec[TK_NOT] = binary_prec[TK_AND] + 1;
-  }
-  // 1. skip whitespace
-  // 2. handle comment or newline
-  //
-  bool isNumber(const std::string& str, size_t start, size_t* len) {
-    char first = str[start];
-    // strtod allows numbers to start with + or - or nan or inf
-    // http://en.cppreference.com/w/cpp/string/byte/strtof
-    // but we want only the number part, otherwise 1+3 will turn into two
-    // adjacent numbers in the lexer
-    if (first == '-' || first == '+' || isalpha(first))
-      return false;
-    const char* startptr = str.c_str() + start;
-    char* endptr;
-    std::strtod(startptr, &endptr);
-    *len = endptr - startptr;
-    return *len > 0;
-  }
-  bool isblank(int n) {
-    return isspace(n) && n != '\n';
-  }
-  // find the longest match of str.substring(pos) against a token, return true
-  // if successful
-  // filling in kind, start,and len
-  bool match(
-      const std::string& str,
-      size_t pos,
-      bool continuation, // are we inside a scope where newlines don't count
-                         // (e.g. inside parens)
-      bool whitespace_token, // should we treat whitespace as a token
-      int* kind,
-      size_t* start,
-      size_t* len) {
-    *start = pos;
-    // skip whitespace
-    while (pos < str.size() && isblank(str[pos]))
-      pos++;
-
-    // special handling
-    if (pos < str.size()) {
-      if (str[pos] == '#') {
-        // skip comments
-        while (pos < str.size() && str[pos] != '\n')
-          pos++;
-        // tail call, handle whitespace and more comments
-        return match(
-            str, pos, continuation, whitespace_token, kind, start, len);
-      }
-      if (str[pos] == '\\' && pos + 1 < str.size() && str[pos + 1] == '\n' &&
-          !whitespace_token) {
-        return match(str, pos + 2, continuation, false, kind, start, len);
-      }
-      if (str[pos] == '\n') {
-        return match(
-            str, pos + 1, continuation, !continuation, kind, start, len);
-      }
-    }
-    if (pos == str.size()) {
-      *kind = TK_EOF;
-      *start = pos;
-      *len = 0;
-      return true;
-    }
-    // invariant: the next token is not whitespace or newline
-    if (whitespace_token) {
-      *kind = TK_WHITESPACE;
-      *len = pos - *start;
-      return true;
-    }
-    *start = pos;
-    // check for a valid number
-    if (isNumber(str, pos, len)) {
-      *kind = TK_NUMBER;
-      return true;
-    }
-    // check for either an ident or a token
-    // ident tracks whether what we have scanned so far could be an identifier
-    // matched indicates if we have found any match.
-    bool matched = false;
-    bool ident = true;
-    TokenTrie* cur = head.get();
-    for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); i++) {
-      ident = ident && validIdent(i, str[pos + i]);
-      if (ident) {
-        matched = true;
-        *len = i + 1;
-        *kind = TK_IDENT;
-      }
-      // check for token second, so that e.g. 'max' matches the token TK_MAX
-      // rather the
-      // identifier 'max'
-      if (cur) {
-        auto it = cur->children.find(str[pos + i]);
-        cur = (it == cur->children.end()) ? nullptr : it->second.get();
-        if (cur && cur->kind != 0) {
-          matched = true;
-          *len = i + 1;
-          *kind = cur->kind;
-        }
-      }
-    }
-    return matched;
-  }
-  bool isUnary(int kind, int* prec) {
-    auto it = unary_prec.find(kind);
-    if (it != unary_prec.end()) {
-      *prec = it->second;
-      return true;
-    }
-    return false;
-  }
-  bool isBinary(int kind, int* prec) {
-    auto it = binary_prec.find(kind);
-    if (it != binary_prec.end()) {
-      *prec = it->second;
-      return true;
-    }
-    return false;
-  }
-  bool isRightAssociative(int kind) {
-    switch (kind) {
-      case '?':
-        return true;
-      default:
-        return false;
-    }
-  }
-
- private:
-  bool validIdent(size_t i, char n) {
-    return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
-  }
-  TokenTrieRef head;
-  std::unordered_map<int, int>
-      unary_prec; // map from token to its unary precedence
-  std::unordered_map<int, int>
-      binary_prec; // map from token to its binary precedence
-};
-
-SharedParserData& sharedParserData();
-
-// a range of a shared string 'file_' with functions to help debug by highlight
-// that
-// range.
-struct SourceRange {
-  SourceRange(
-      const std::shared_ptr<std::string>& file_,
-      size_t start_,
-      size_t end_)
-      : file_(file_), start_(start_), end_(end_) {}
-  const std::string text() const {
-    return file().substr(start(), end() - start());
-  }
-  size_t size() const {
-    return end() - start();
-  }
-  void highlight(std::ostream& out) const {
-    const std::string& str = file();
-    size_t begin = start();
-    size_t end = start();
-    while (begin > 0 && str[begin - 1] != '\n')
-      --begin;
-    while (end < str.size() && str[end] != '\n')
-      ++end;
-    out << str.substr(0, end) << "\n";
-    out << std::string(start() - begin, ' ');
-    size_t len = std::min(size(), end - start());
-    out << std::string(len, '~')
-        << (len < size() ? "...  <--- HERE" : " <--- HERE");
-    out << str.substr(end);
-    if (str.size() > 0 && str.back() != '\n')
-      out << "\n";
-  }
-  const std::string& file() const {
-    return *file_;
-  }
-  const std::shared_ptr<std::string>& file_ptr() const {
-    return file_;
-  }
-  size_t start() const {
-    return start_;
-  }
-  size_t end() const {
-    return end_;
-  }
-
- private:
-  std::shared_ptr<std::string> file_;
-  size_t start_;
-  size_t end_;
-};
-
-struct Token {
-  int kind;
-  SourceRange range;
-  Token(int kind, const SourceRange& range) : kind(kind), range(range) {}
-  double doubleValue() {
-    assert(TK_NUMBER == kind);
-    size_t idx;
-    double r = ::c10::stod(text(), &idx);
-    assert(idx == range.size());
-    return r;
-  }
-  std::string text() {
-    return range.text();
-  }
-  std::string kindString() const {
-    return kindToString(kind);
-  }
-};
-
-struct Lookahead {
-  Lookahead(const Token& t) : t(t) {}
-  Token t;
-  bool valid = false;
-  size_t repeat = 0;
-};
-
-struct Lexer {
-  std::shared_ptr<std::string> file;
-  explicit Lexer(const std::string& str)
-      : file(std::make_shared<std::string>(str)),
-        pos(0),
-        cur_(TK_EOF, SourceRange(file, 0, 0)),
-        lookahead_(cur_),
-        repeat(0),
-        nesting(0),
-        shared(sharedParserData()) {
-    auto first_indent = lexRaw(true);
-    indent_stack.push_back(first_indent.range.size());
-    next();
-  }
-  Token next() {
-    Token r = cur_;
-    if (repeat > 0) {
-      repeat--;
-    } else if (lookahead_.valid) {
-      lookahead_.valid = false;
-      repeat = lookahead_.repeat;
-      cur_ = lookahead_.t;
-    } else {
-      std::tie(cur_, repeat) = lex();
-    }
-    return r;
-  }
-  bool nextIf(int kind) {
-    if (cur_.kind != kind)
-      return false;
-    next();
-    return true;
-  }
-
-  [[noreturn]] void reportError(const std::string& what) {
-    reportError(what, cur_);
-  }
-  [[noreturn]] void reportError(const std::string& what, const Token& t) {
-    std::stringstream ss;
-    ss << what << ":\n";
-    t.range.highlight(ss);
-    throw std::runtime_error(ss.str());
-  }
-  [[noreturn]] void expected(const std::string& what, const Token& t) {
-    std::stringstream ss;
-    ss << "expected " << what << " but found '" << t.kindString()
-       << "' here:\n";
-    t.range.highlight(ss);
-    throw std::runtime_error(ss.str());
-  }
-  [[noreturn]] void expected(const std::string& what) {
-    expected(what, cur_);
-  }
-  Token expect(int kind) {
-    if (cur_.kind != kind) {
-      expected(kindToString(kind));
-    }
-    return next();
-  }
-  Token& lookahead() {
-    if (!lookahead_.valid) {
-      lookahead_.valid = true;
-      std::tie(lookahead_.t, lookahead_.repeat) = lex();
-    }
-    return lookahead_.t;
-  }
-  Token& cur() {
-    return cur_;
-  }
-
- private:
-  // token, number of times to repeat it
-  std::pair<Token, int> lex() {
-    auto r = lexRaw();
-    int repeat = 0;
-    switch (r.kind) {
-      case '(':
-      case '[':
-      case '{':
-        nesting++;
-        break;
-      case ')':
-      case ']':
-      case '}':
-        nesting--;
-        break;
-      case TK_WHITESPACE: {
-        size_t depth = r.range.size();
-        if (depth > indent_stack.back()) {
-          indent_stack.push_back(depth);
-          r.kind = TK_INDENT;
-        } else if (depth == indent_stack.back()) {
-          r.kind = TK_NEWLINE;
-        } else {
-          while (indent_stack.back() != depth) {
-            indent_stack.pop_back();
-            repeat++;
-            if (indent_stack.size() == 0) {
-              reportError("invalid ident level", r);
-            }
-          }
-          repeat--; // first repeat is this return
-          r.kind = TK_DEDENT;
-        }
-      } break;
-      case TK_EOF:
-        if (indent_stack.size() > 1) {
-          r.kind = TK_DEDENT;
-          indent_stack.pop_back();
-        }
-        break;
-      default:
-        break;
-    }
-    return std::make_pair(r, repeat);
-  }
-  Token lexRaw(bool whitespace_token = false) {
-    int kind;
-    size_t start;
-    size_t length;
-    assert(file);
-    if (!shared.match(
-            *file,
-            pos,
-            nesting > 0,
-            whitespace_token,
-            &kind,
-            &start,
-            &length)) {
-      expected(
-          "a valid token",
-          Token((*file)[start], SourceRange(file, start, start + 1)));
-    }
-    auto t = Token(kind, SourceRange(file, start, start + length));
-    pos = start + length;
-    return t;
-  }
-  size_t pos;
-  Token cur_;
-  Lookahead lookahead_;
-  size_t repeat; // how many times to repeat the current token until we continue
-
-  size_t nesting; // depth of ( [ { nesting...
-  std::vector<int> indent_stack; // stack of identation level of blocks
-  SharedParserData& shared;
-};
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/parser.h b/caffe2/contrib/script/parser.h
deleted file mode 100644 (file)
index 4b68b8d..0000000
+++ /dev/null
@@ -1,418 +0,0 @@
-#pragma once
-#include "lexer.h"
-#include "tree.h"
-#include "tree_views.h"
-
-namespace caffe2 {
-namespace script {
-
-struct Parser {
-  explicit Parser(const std::string& str)
-      : L(str), shared(sharedParserData()) {}
-
-  TreeRef parseIdent() {
-    auto t = L.expect(TK_IDENT);
-    // whenever we parse something that has a TreeView type we always
-    // use its create method so that the accessors and the constructor
-    // of the Compound tree are in the same place.
-    return Ident::create(t.range, t.text());
-  }
-  TreeRef createApply(TreeRef ident, TreeList& inputs) {
-    TreeList attributes;
-    auto range = L.cur().range;
-    parseOperatorArguments(inputs, attributes);
-    return Apply::create(
-        range,
-        ident,
-        List(range, std::move(inputs)),
-        List(range, std::move(attributes)));
-  }
-  // things like a 1.0 or a(4) that are not unary/binary expressions
-  // and have higher precedence than all of them
-  TreeRef parseBaseExp() {
-    TreeRef prefix;
-    switch (L.cur().kind) {
-      case TK_NUMBER:
-      case TK_TRUE:
-      case TK_FALSE: {
-        prefix = parseConst();
-      } break;
-      case '(': {
-        L.next();
-        prefix = parseExp();
-        L.expect(')');
-      } break;
-      case TK_FLOAT:
-      case TK_INT:
-      case TK_LONG: {
-        auto r = L.cur().range;
-        auto type = c(L.next().kind, r, {});
-        L.expect('(');
-        auto exp = parseExp();
-        L.expect(')');
-        prefix = Cast::create(r, type, exp);
-      } break;
-      default: {
-        prefix = parseIdent();
-        if (L.cur().kind == '(') {
-          TreeList inputs;
-          prefix = createApply(prefix, inputs);
-        }
-      } break;
-    }
-    while (true) {
-      if (L.nextIf('.')) {
-        const auto name = parseIdent();
-        if (L.cur().kind == '(') {
-          TreeList inputs = {prefix};
-          prefix = createApply(name, inputs);
-        } else {
-          prefix = Select::create(name->range(), prefix, name);
-        }
-      } else if (L.cur().kind == '[') {
-        prefix = parseSliceOrGather(prefix);
-      } else {
-        break;
-      }
-    }
-    return prefix;
-  }
-  TreeRef parseOptionalReduction() {
-    auto r = L.cur().range;
-    switch (L.cur().kind) {
-      case TK_PLUS_EQ:
-      case TK_MINUS_EQ:
-      case TK_TIMES_EQ:
-      case TK_DIV_EQ: {
-        int modifier = L.next().text()[0];
-        return c(modifier, r, {});
-      } break;
-      default: {
-        L.expect('=');
-        return c('=', r, {}); // no reduction
-      } break;
-    }
-  }
-  TreeRef
-  parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) {
-    auto cond = parseExp();
-    L.expect(TK_ELSE);
-    auto false_branch = parseExp(binary_prec);
-    return c(TK_IF_EXPR, range, {cond, true_branch, false_branch});
-  }
-  // parse the longest expression whose binary operators have
-  // precedence strictly greater than 'precedence'
-  // precedence == 0 will parse _all_ expressions
-  // this is the core loop of 'top-down precedence parsing'
-  TreeRef parseExp(int precedence = 0) {
-    TreeRef prefix = nullptr;
-    int unary_prec;
-    if (shared.isUnary(L.cur().kind, &unary_prec)) {
-      auto kind = L.cur().kind;
-      auto pos = L.cur().range;
-      L.next();
-      prefix = c(kind, pos, {parseExp(unary_prec)});
-    } else {
-      prefix = parseBaseExp();
-    }
-    int binary_prec;
-    while (shared.isBinary(L.cur().kind, &binary_prec)) {
-      if (binary_prec <= precedence) // not allowed to parse something which is
-        // not greater than 'precedenc'
-        break;
-
-      int kind = L.cur().kind;
-      auto pos = L.cur().range;
-      L.next();
-      if (shared.isRightAssociative(kind))
-        binary_prec--;
-
-      // special case for trinary operator
-      if (kind == TK_IF) {
-        prefix = parseTrinary(prefix, pos, binary_prec);
-        continue;
-      }
-
-      prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
-    }
-    return prefix;
-  }
-  TreeRef
-  parseList(int begin, int sep, int end, std::function<TreeRef(int)> parse) {
-    auto r = L.cur().range;
-    L.expect(begin);
-    TreeList elements;
-    if (L.cur().kind != end) {
-      int i = 0;
-      do {
-        elements.push_back(parse(i++));
-      } while (L.nextIf(sep));
-    }
-    L.expect(end);
-    return c(TK_LIST, r, std::move(elements));
-  }
-  TreeRef parseNonEmptyList(int sep, std::function<TreeRef(int)> parse) {
-    TreeList elements;
-    int i = 0;
-    do {
-      elements.push_back(parse(i++));
-    } while (L.nextIf(sep));
-    return c(TK_LIST, elements[0]->range(), std::move(elements));
-  }
-  TreeRef parseExpList() {
-    return parseList('(', ',', ')', [&](int i) { return parseExp(); });
-  }
-  TreeRef parseConst() {
-    // 'b' - boolean
-    // 'LL' 64-bit integer
-    // 'f' single-precision float
-    // 'i' 32-bit integer
-    // 'f' is default if '.' appears in the number
-    auto range = L.cur().range;
-    if (L.nextIf(TK_TRUE)) {
-      return c(TK_CONST, range, {d(1), s("b")});
-    } else if (L.nextIf(TK_FALSE)) {
-      return c(TK_CONST, range, {d(0), s("b")});
-    }
-    float mult = 1.0f;
-    while (L.nextIf('-')) {
-      mult *= -1.0f;
-    }
-    auto t = L.expect(TK_NUMBER);
-    std::string type_ident =
-        (t.text().find('.') == std::string::npos) ? "i" : "f";
-    if (L.cur().kind == TK_IDENT) {
-      Token type_ident_tok = L.expect(TK_IDENT);
-      type_ident = type_ident_tok.text();
-      if (type_ident != "LL" && type_ident != "f") {
-        throw ErrorReport(type_ident_tok)
-            << "expected 'f' or 'LL' "
-            << "as numeric type identifier but found '" << type_ident << "'";
-      }
-    }
-    return c(TK_CONST, t.range, {d(mult * t.doubleValue()), s(type_ident)});
-  }
-  TreeRef parseAttributeValue() {
-    int kind = L.cur().kind;
-    switch (kind) {
-      case '[':
-        return parseList('[', ',', ']', [&](int i) { return parseConst(); });
-      default:
-        return parseConst();
-    }
-  }
-  void parseOperatorArguments(TreeList& inputs, TreeList& attributes) {
-    L.expect('(');
-    if (L.cur().kind != ')') {
-      do {
-        if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
-          auto ident = parseIdent();
-          L.expect('=');
-          auto v = parseAttributeValue();
-          attributes.push_back(Attribute::create(ident->range(), ident, v));
-        } else {
-          inputs.push_back(parseExp());
-        }
-      } while (L.nextIf(','));
-    }
-    L.expect(')');
-  }
-
-  // OK: [a] (gather), [a:], [:a], [a:b], [:] (slice)
-  // Not OK: []
-  TreeRef parseSliceOrGather(TreeRef value) {
-    const auto range = L.cur().range;
-    L.expect('[');
-
-    // `first` will either be the gather indices, or the start of the slice.
-    TreeRef first, second;
-
-    // Here we can either have a colon (which starts a slice), or an expression.
-    // If an expression, we don't know yet if it will be a slice or a gather.
-    if (L.cur().kind != ':') {
-      first = parseExp();
-      if (L.nextIf(']')) {
-        return Gather::create(range, value, first);
-      } else {
-        first = c(TK_OPTION, range, {first});
-      }
-    } else {
-      first = c(TK_OPTION, range, {});
-    }
-    L.expect(':');
-    // Now we *may* have an expression.
-    if (L.cur().kind != ']') {
-      second = c(TK_OPTION, range, {parseExp()});
-    } else {
-      second = c(TK_OPTION, range, {});
-    }
-    L.expect(']');
-
-    return Slice::create(range, value, first, second);
-  }
-  TreeRef parseIdentList() {
-    return parseList('(', ',', ')', [&](int i) { return parseIdent(); });
-  }
-  TreeRef parseParam() {
-    auto typ = parseType();
-    if (L.cur().kind != TK_IDENT && typ->trees()[0]->kind() == TK_IDENT) {
-      // oops, it wasn't a type but just a param without any type specified
-      return Param::create(
-          typ->range(), typ->trees()[0], c(TK_INFERRED, typ->range(), {}));
-    }
-    auto ident = parseIdent();
-    return Param::create(typ->range(), ident, typ);
-  }
-  // TODO: these functions should be unnecessary, but we currently do not
-  // emit a TK_NEWLINE before a series of TK_DEDENT tokens
-  // so if we see a TK_DEDENT then we know a newline must have happened and
-  // ignore it. The real fix is to patch the lexer so TK_NEWLINE does get
-  // emited before a TK_INDENT
-  void expectEndOfLine() {
-    if (L.cur().kind != TK_DEDENT)
-      L.expect(TK_NEWLINE);
-  }
-  bool isEndOfLine() {
-    return L.cur().kind == TK_NEWLINE || L.cur().kind == TK_DEDENT;
-  }
-
-  // 'first' has already been parsed since expressions can exist
-  // alone on a line:
-  // first[,other,lhs] = rhs
-  TreeRef parseAssign(TreeRef first) {
-    TreeRef list = parseOneOrMoreExp(first);
-    auto red = parseOptionalReduction();
-    auto rhs = parseExp();
-    expectEndOfLine();
-    return Assign::create(list->range(), list, red, rhs);
-  }
-  TreeRef parseStmt() {
-    switch (L.cur().kind) {
-      case TK_IF:
-        return parseIf();
-      case TK_WHILE:
-        return parseWhile();
-      case TK_GLOBAL: {
-        auto range = L.next().range;
-        std::vector<TreeRef> idents;
-        do {
-          idents.push_back(parseIdent());
-        } while (L.nextIf(','));
-        expectEndOfLine();
-        return c(TK_GLOBAL, range, std::move(idents));
-      }
-      default: {
-        auto r = parseExp();
-        if (!isEndOfLine()) {
-          return parseAssign(r);
-        } else {
-          expectEndOfLine();
-          return r;
-        }
-      }
-    }
-  }
-  TreeRef parseScalarType() {
-    switch (L.cur().kind) {
-      case TK_INT:
-      case TK_FLOAT:
-      case TK_LONG:
-      case TK_DOUBLE: {
-        auto t = L.next();
-        return c(t.kind, t.range, {});
-      }
-      default:
-        return parseIdent();
-    }
-  }
-  TreeRef parseOptionalIdentList() {
-    TreeRef list = nullptr;
-    if (L.cur().kind == '(') {
-      list = parseIdentList();
-    } else {
-      list = c(TK_LIST, L.cur().range, {});
-    }
-    return list;
-  }
-  TreeRef parseType() {
-    auto st = parseScalarType();
-    auto list = parseOptionalIdentList();
-    return TensorType::create(st->range(), st, list);
-  }
-  // 'first' has already been parsed, add the rest
-  // if they exist
-  // first[, the, rest]
-  TreeRef parseOneOrMoreExp(TreeRef first) {
-    TreeList list{first};
-    while (L.nextIf(',')) {
-      list.push_back(parseExp());
-    }
-    return List(list.back()->range(), std::move(list));
-  }
-  TreeRef parseIf() {
-    auto r = L.cur().range;
-    L.expect(TK_IF);
-    auto cond = parseExp();
-    L.expect(':');
-    auto true_branch = parseStatements();
-    auto false_branch = List(L.cur().range, {});
-    if (L.nextIf(TK_ELSE)) {
-      L.expect(':');
-      false_branch = parseStatements();
-    }
-    return If::create(r, cond, true_branch, false_branch);
-  }
-  TreeRef parseWhile() {
-    auto r = L.cur().range;
-    L.expect(TK_WHILE);
-    auto cond = parseExp();
-    L.expect(':');
-    auto body = parseStatements();
-    return While::create(r, cond, body);
-  }
-  TreeRef parseStatements() {
-    auto r = L.cur().range;
-    L.expect(TK_INDENT);
-    TreeList stmts;
-    while (true) {
-      stmts.push_back(parseStmt());
-      if (L.nextIf(TK_DEDENT))
-        break;
-    }
-    return c(TK_LIST, r, std::move(stmts));
-  }
-  TreeRef parseFunction() {
-    L.expect(TK_DEF);
-    auto name = parseIdent();
-    auto paramlist =
-        parseList('(', ',', ')', [&](int i) { return parseParam(); });
-    L.expect(TK_ARROW);
-    auto retlist =
-        parseList('(', ',', ')', [&](int i) { return parseParam(); });
-    L.expect(':');
-    auto stmts_list = parseStatements();
-    return Def::create(name->range(), name, paramlist, retlist, stmts_list);
-  }
-  Lexer& lexer() {
-    return L;
-  }
-
- private:
-  // short helpers to create nodes
-  TreeRef d(double v) {
-    return Number::create(v);
-  }
-  TreeRef s(const std::string& s) {
-    return String::create(s);
-  }
-  TreeRef c(int kind, const SourceRange& range, TreeList&& trees) {
-    return Compound::create(kind, range, std::move(trees));
-  }
-  TreeRef List(const SourceRange& range, TreeList&& trees) {
-    return c(TK_LIST, range, std::move(trees));
-  }
-  Lexer L;
-  SharedParserData& shared;
-};
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/tree.h b/caffe2/contrib/script/tree.h
deleted file mode 100644 (file)
index c508308..0000000
+++ /dev/null
@@ -1,233 +0,0 @@
-#pragma once
-
-#include <memory>
-#include <vector>
-
-#include "caffe2/contrib/script/lexer.h"
-
-namespace caffe2 {
-namespace script {
-
-// Tree's are used to represent all forms of TC IR, pre- and post- typechecking.
-// Rather than have a full class hierarchy for all TC statements,
-// Trees are a slight variation of Lisp S-expressions.
-// for instance the expression a*b+1 is represented as:
-// (+ (* (ident a) (ident b)) (const 1))
-// Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which
-// define stringValue() and doubleValue().
-// Everything else is a Compound object, which has a 'kind' that is a token from
-// Lexer.h's TokenKind enum, and contains a list of subtrees.
-// Like TokenKind single-character operators like '+' are representing using the
-// character itself, so add.kind() == '+'.
-// Compound objects are also always associated with a SourceRange for
-// reporting error message.
-
-// Memory management of trees is done using shared_ptr.
-
-struct Tree;
-using TreeRef = std::shared_ptr<Tree>;
-using TreeList = std::vector<TreeRef>;
-
-static const TreeList empty_trees = {};
-
-struct Tree : std::enable_shared_from_this<Tree> {
-  Tree(int kind_) : kind_(kind_) {}
-  int kind() const {
-    return kind_;
-  }
-  virtual bool isAtom() const {
-    return true;
-  }
-  virtual const SourceRange& range() const {
-    throw std::runtime_error("is an Atom");
-  }
-  virtual double doubleValue() const {
-    throw std::runtime_error("not a TK_NUMBER");
-  }
-  virtual const std::string& stringValue() const {
-    throw std::runtime_error("not a TK_STRING");
-  }
-  virtual bool boolValue() const {
-    throw std::runtime_error("not a TK_BOOL");
-  }
-  virtual const TreeList& trees() const {
-    return empty_trees;
-  }
-  const TreeRef& tree(size_t i) const {
-    return trees().at(i);
-  }
-  virtual TreeRef map(std::function<TreeRef(TreeRef)> /*fn*/) {
-    return shared_from_this();
-  }
-  template <typename... Args>
-  void match(int k, Args&... args) {
-    matchD(k, "unknown", 0, args...);
-  }
-  template <typename... Args>
-  void matchD(int k, const char* filename, int lineno, Args&... args) {
-    if (kind() != k) {
-      std::stringstream ss;
-      ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
-         << "' but found '" << kind() << "'\n";
-      range().highlight(ss);
-      throw std::runtime_error(ss.str());
-    }
-    std::initializer_list<TreeRef*> vars = {&args...};
-    if (vars.size() > trees().size()) {
-      std::stringstream ss;
-      ss << filename << ":" << lineno << ": trying to match " << vars.size()
-         << " variables against " << trees().size() << " values in list.\n";
-      range().highlight(ss);
-      throw std::runtime_error(ss.str());
-    }
-    size_t i = 0;
-    for (TreeRef* v : vars) {
-      *v = trees()[i++];
-    }
-  }
-  virtual ~Tree() {}
-
- private:
-  int kind_;
-};
-
-struct String : public Tree {
-  String(const std::string& value_) : Tree(TK_STRING), value_(value_) {}
-  virtual const std::string& stringValue() const override {
-    return value_;
-  }
-  template <typename... Args>
-  static TreeRef create(Args&&... args) {
-    return std::make_shared<String>(std::forward<Args>(args)...);
-  }
-
- private:
-  std::string value_;
-};
-struct Number : public Tree {
-  Number(double value_) : Tree(TK_NUMBER), value_(value_) {}
-  virtual double doubleValue() const override {
-    return value_;
-  }
-  template <typename... Args>
-  static TreeRef create(Args&&... args) {
-    return std::make_shared<Number>(std::forward<Args>(args)...);
-  }
-
- private:
-  double value_;
-};
-struct Bool : public Tree {
-  Bool(bool value_) : Tree(TK_BOOL), value_(value_) {}
-  virtual double doubleValue() const override {
-    return value_;
-  }
-  template <typename... Args>
-  static TreeRef create(Args&&... args) {
-    return std::make_shared<Bool>(std::forward<Args>(args)...);
-  }
-
- private:
-  bool value_;
-};
-
-static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
-  for (auto t : others) {
-    if (t->isAtom())
-      continue;
-    size_t s = std::min(c.start(), t->range().start());
-    size_t e = std::max(c.end(), t->range().end());
-    c = SourceRange(c.file_ptr(), s, e);
-  }
-  return c;
-}
-
-struct Compound : public Tree {
-  Compound(int kind, const SourceRange& range_) : Tree(kind), range_(range_) {}
-  Compound(int kind, const SourceRange& range_, TreeList&& trees_)
-      : Tree(kind),
-        range_(mergeRanges(range_, trees_)),
-        trees_(std::move(trees_)) {}
-  virtual const TreeList& trees() const override {
-    return trees_;
-  }
-  static TreeRef
-  create(int kind, const SourceRange& range_, TreeList&& trees_) {
-    return std::make_shared<Compound>(kind, range_, std::move(trees_));
-  }
-  virtual bool isAtom() const override {
-    return false;
-  }
-  virtual TreeRef map(std::function<TreeRef(TreeRef)> fn) override {
-    TreeList trees_;
-    for (auto& t : trees()) {
-      trees_.push_back(fn(t));
-    }
-    return Compound::create(kind(), range(), std::move(trees_));
-  }
-  const SourceRange& range() const override {
-    return range_;
-  }
-
- private:
-  SourceRange range_;
-  TreeList trees_;
-};
-
-// tree pretty printer
-struct pretty_tree {
-  pretty_tree(const TreeRef& tree, size_t col = 40) : tree(tree), col(col) {}
-  const TreeRef& tree;
-  size_t col;
-  std::unordered_map<TreeRef, std::string> flat_strings;
-  const std::string& get_flat(const TreeRef& t) {
-    auto it = flat_strings.find(t);
-    if (it != flat_strings.end())
-      return it->second;
-
-    std::stringstream out;
-    switch (t->kind()) {
-      case TK_NUMBER:
-        out << t->doubleValue();
-        break;
-      case TK_STRING:
-        out << t->stringValue();
-        break;
-      default:
-        out << "(" << kindToString(t->kind());
-        for (auto e : t->trees()) {
-          out << " " << get_flat(e);
-        }
-        out << ")";
-        break;
-    }
-    auto it_ = flat_strings.emplace(t, out.str());
-    return it_.first->second;
-  }
-  void print(std::ostream& out, const TreeRef& t, int indent) {
-    const std::string& s = get_flat(t);
-    if (indent + s.size() < col || t->isAtom()) {
-      out << s;
-      return;
-    }
-    std::string k = kindToString(t->kind());
-    out << "(" << k;
-    for (auto e : t->trees()) {
-      out << "\n" << std::string(indent + 2, ' ');
-      print(out, e, indent + 2);
-    }
-    out << ")";
-  }
-};
-
-static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) {
-  t_.print(out, t_.tree, 0);
-  return out << std::endl;
-}
-
-static inline std::ostream& operator<<(std::ostream& out, TreeRef t) {
-  return out << pretty_tree(t);
-}
-
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/tree_views.h b/caffe2/contrib/script/tree_views.h
deleted file mode 100644 (file)
index 2089333..0000000
+++ /dev/null
@@ -1,442 +0,0 @@
-#pragma once
-#include "error_report.h"
-#include "tree.h"
-
-namespace caffe2 {
-namespace script {
-
-// TreeView provides a statically-typed way to access the members of a TreeRef
-// instead of using TK_MATCH
-
-struct TreeView {
-  explicit TreeView(const TreeRef& tree_) : tree_(tree_) {}
-  TreeRef tree() const {
-    return tree_;
-  }
-  const SourceRange& range() const {
-    return tree_->range();
-  }
-  operator TreeRef() const {
-    return tree_;
-  }
-
- protected:
-  TreeRef tree_;
-};
-
-template <typename T>
-struct ListViewIterator {
-  ListViewIterator(TreeList::const_iterator it) : it(it) {}
-  bool operator!=(const ListViewIterator& rhs) const {
-    return it != rhs.it;
-  }
-  T operator*() const {
-    return T(*it);
-  }
-  void operator++() {
-    ++it;
-  }
-  void operator--() {
-    --it;
-  }
-
- private:
-  TreeList::const_iterator it;
-};
-
-template <typename T>
-struct ListView : public TreeView {
-  ListView(const TreeRef& tree) : TreeView(tree) {
-    tree->match(TK_LIST);
-  }
-  typedef ListViewIterator<T> iterator;
-  typedef ListViewIterator<T> const_iterator;
-  iterator begin() const {
-    return iterator(tree_->trees().begin());
-  }
-  iterator end() const {
-    return iterator(tree_->trees().end());
-  }
-  T operator[](size_t i) const {
-    return T(tree_->trees().at(i));
-  }
-  TreeRef map(std::function<TreeRef(const T&)> fn) {
-    return tree_->map([&](TreeRef v) { return fn(T(v)); });
-  }
-  size_t size() const {
-    return tree_->trees().size();
-  }
-};
-
-template <typename T>
-struct OptionView : public TreeView {
-  explicit OptionView(const TreeRef& tree) : TreeView(tree) {
-    C2S_ASSERT(tree, tree->kind() == TK_OPTION);
-  }
-  bool present() const {
-    return tree_->trees().size() > 0;
-  }
-  T get() const {
-    C2S_ASSERT(tree_, present());
-    return T(tree_->trees()[0]);
-  }
-  TreeRef map(std::function<TreeRef(const T&)> fn) {
-    return tree_->map([&](TreeRef v) { return fn(T(v)); });
-  }
-};
-
-struct Ident : public TreeView {
-  // each subclass of TreeView provides:
-  // 1. a constructor that takes a TreeRef, and matches it to the right type.
-  explicit Ident(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_IDENT, name_);
-  }
-  // 2. accessors that get underlying information out of the object
-  // in this case, we return the name of the identifier, and handle the
-  // converstion to a string in the method
-  const std::string& name() const {
-    return name_->stringValue();
-  }
-
-  // 3. a static method 'create' that creates the underlying TreeRef object
-  // for every TreeRef kind that has a TreeView, the parser always uses
-  // (e.g.) Ident::create rather than Compound::Create, this means that
-  // changes to the structure of Ident are always made right here rather
-  // than both in the parser and in this code
-  static TreeRef create(const SourceRange& range, const std::string& name) {
-    return Compound::create(TK_IDENT, range, {String::create(name)});
-  }
-
- private:
-  TreeRef name_;
-};
-
-struct Attribute : public TreeView {
-  explicit Attribute(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_ATTRIBUTE, name_, value_);
-  }
-  Ident name() const {
-    return Ident(name_);
-  }
-  TreeRef value() const {
-    return value_;
-  }
-  static TreeRef create(const SourceRange& range, TreeRef name, TreeRef value) {
-    return Compound::create(TK_ATTRIBUTE, range, {name, value});
-  }
-
- private:
-  TreeRef name_;
-  TreeRef value_;
-};
-
-struct Apply : public TreeView {
-  explicit Apply(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_APPLY, name_, inputs_, attributes_);
-  }
-
-  Ident name() const {
-    return Ident(name_);
-  }
-  ListView<TreeRef> inputs() const {
-    return ListView<TreeRef>(inputs_);
-  }
-  ListView<Attribute> attributes() const {
-    return ListView<Attribute>(attributes_);
-  }
-
-  static TreeRef create(
-      const SourceRange& range,
-      TreeRef name,
-      TreeRef inputs,
-      TreeRef attributes) {
-    return Compound::create(TK_APPLY, range, {name, inputs, attributes});
-  }
-
- private:
-  TreeRef name_;
-  TreeRef inputs_;
-  TreeRef attributes_;
-};
-
-struct Slice : public TreeView {
-  explicit Slice(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_SLICE, value_, start_, end_);
-  }
-
-  TreeRef value() const {
-    return value_;
-  }
-
-  OptionView<TreeRef> start() const {
-    return OptionView<TreeRef>(start_);
-  }
-
-  OptionView<TreeRef> end() const {
-    return OptionView<TreeRef>(end_);
-  }
-
-  TreeRef startOr(int alternative) const {
-    const auto startOption = start();
-    return startOption.present() ? startOption.get() : createInt(alternative);
-  }
-
-  TreeRef endOr(int alternative) const {
-    const auto endOption = end();
-    return endOption.present() ? endOption.get() : createInt(alternative);
-  }
-
-  static TreeRef
-  create(const SourceRange& range, TreeRef value, TreeRef start, TreeRef end) {
-    return Compound::create(TK_SLICE, range, {value, start, end});
-  }
-
- private:
-  TreeRef createInt(int value) const {
-    return Compound::create(
-        TK_CONST, range(), {Number::create(value), String::create("i")});
-  }
-
-  TreeRef value_;
-  TreeRef start_;
-  TreeRef end_;
-};
-
-struct Gather : public TreeView {
-  explicit Gather(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_GATHER, value_, indices_);
-  }
-
-  TreeRef value() const {
-    return value_;
-  }
-
-  TreeRef indices() const {
-    return indices_;
-  }
-
-  static TreeRef
-  create(const SourceRange& range, TreeRef value, TreeRef indices) {
-    return Compound::create(TK_GATHER, range, {value, indices});
-  }
-
- private:
-  TreeRef value_;
-  TreeRef indices_;
-};
-
-struct Cast : public TreeView {
-  explicit Cast(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_CAST, type_, input_);
-  }
-
-  int type() const {
-    return type_->kind();
-  }
-  TreeRef input() const {
-    return input_;
-  }
-
-  static TreeRef create(const SourceRange& range, TreeRef type, TreeRef input) {
-    return Compound::create(TK_CAST, range, {type, input});
-  }
-
- private:
-  TreeRef type_;
-  TreeRef input_;
-};
-
-struct TensorType : public TreeView {
-  explicit TensorType(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_TENSOR_TYPE, scalar_type_, dims_);
-  }
-  static TreeRef
-  create(const SourceRange& range, TreeRef scalar_type_, TreeRef dims_) {
-    return Compound::create(TK_TENSOR_TYPE, range, {scalar_type_, dims_});
-  }
-  int scalarType() const {
-    if (scalar_type_->kind() == TK_IDENT)
-      throw ErrorReport(tree_)
-          << " TensorType has a symbolic ident " << Ident(scalar_type_).name()
-          << " rather than a concrete type";
-    return scalar_type_->kind();
-  }
-  ListView<Ident> dims() const {
-    return ListView<Ident>(dims_);
-  }
-
- private:
-  TreeRef scalar_type_;
-  TreeRef dims_;
-};
-
-struct Param : public TreeView {
-  explicit Param(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_PARAM, ident_, type_);
-  }
-  static TreeRef create(const SourceRange& range, TreeRef ident, TreeRef type) {
-    return Compound::create(TK_PARAM, range, {ident, type});
-  }
-  // when the type of a field is statically know the accessors return
-  // the wrapped type. for instance here we know ident_ is an identifier
-  // so the accessor returns an Ident
-  // this means that clients can do p.ident().name() to get the name of the
-  // parameter.
-  Ident ident() const {
-    return Ident(ident_);
-  }
-  // may be TensorType or TK_INFERRED
-  TreeRef type() const {
-    return type_;
-  }
-  bool typeIsInferred() const {
-    return type_->kind() == TK_INFERRED;
-  }
-  // helper for when you know the type is not inferred.
-  TensorType tensorType() const {
-    return TensorType(type_);
-  }
-
- private:
-  TreeRef ident_;
-  TreeRef type_;
-};
-
-struct Assign : public TreeView {
-  explicit Assign(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_ASSIGN, lhs_, reduction_, rhs_);
-  }
-  static TreeRef create(
-      const SourceRange& range,
-      TreeRef lhs,
-      TreeRef reduction,
-      TreeRef rhs) {
-    return Compound::create(TK_ASSIGN, range, {lhs, reduction, rhs});
-  }
-  // when the type of a field is statically know the accessors return
-  // the wrapped type. for instance here we know ident_ is an identifier
-  // so the accessor returns an Ident
-  // this means that clients can do p.ident().name() to get the name of the
-  // parameter.
-  ListView<TreeRef> lhs() const {
-    return ListView<TreeRef>(lhs_);
-  }
-  int reduction() const {
-    return reduction_->kind();
-  }
-  TreeRef rhs() const {
-    return rhs_;
-  }
-
- private:
-  TreeRef lhs_;
-  TreeRef reduction_;
-  TreeRef rhs_;
-};
-
-struct Def : public TreeView {
-  explicit Def(const TreeRef& tree) : TreeView(tree) {
-    tree->match(TK_DEF, name_, paramlist, retlist, stmts_list);
-  }
-  Ident name() const {
-    return Ident(name_);
-  }
-  // ListView helps turn TK_LISTs into vectors of TreeViews
-  // so that we can, e.g., return lists of parameters
-  ListView<Param> params() const {
-    return ListView<Param>(paramlist);
-  }
-  ListView<Param> returns() const {
-    return ListView<Param>(retlist);
-  }
-  ListView<TreeRef> statements() const {
-    return ListView<TreeRef>(stmts_list);
-  }
-  static TreeRef create(
-      const SourceRange& range,
-      TreeRef name,
-      TreeRef paramlist,
-      TreeRef retlist,
-      TreeRef stmts_list) {
-    return Compound::create(
-        TK_DEF, range, {name, paramlist, retlist, stmts_list});
-  }
-
- private:
-  TreeRef name_;
-  TreeRef paramlist;
-  TreeRef retlist;
-  TreeRef stmts_list;
-};
-
-struct Select : public TreeView {
-  explicit Select(const TreeRef& tree) : TreeView(tree) {
-    tree_->match('.', value_, selector_);
-  }
-  TreeRef value() const {
-    return value_;
-  }
-  Ident selector() const {
-    return Ident(selector_);
-  }
-  static TreeRef
-  create(const SourceRange& range, TreeRef value, TreeRef selector) {
-    return Compound::create('.', range, {value, selector});
-  }
-
- private:
-  TreeRef value_;
-  TreeRef selector_;
-};
-
-struct If : public TreeView {
-  explicit If(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_IF, cond_, true_branch_, false_branch_);
-  }
-  const TreeRef& cond() const {
-    return cond_;
-  }
-  ListView<TreeRef> trueBranch() const {
-    return ListView<TreeRef>(true_branch_);
-  }
-  ListView<TreeRef> falseBranch() const {
-    return ListView<TreeRef>(false_branch_);
-  }
-
-  static TreeRef create(
-      const SourceRange& range,
-      TreeRef cond_,
-      TreeRef true_branch_,
-      TreeRef false_branch_) {
-    return Compound::create(TK_IF, range, {cond_, true_branch_, false_branch_});
-  }
-
- private:
-  TreeRef cond_;
-  TreeRef true_branch_;
-  TreeRef false_branch_;
-};
-
-struct While : public TreeView {
-  explicit While(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_WHILE, cond_, body_);
-  }
-  const TreeRef& cond() const {
-    return cond_;
-  }
-  ListView<TreeRef> body() const {
-    return ListView<TreeRef>(body_);
-  }
-
-  static TreeRef
-  create(const SourceRange& range, TreeRef cond_, TreeRef body_) {
-    return Compound::create(TK_WHILE, range, {cond_, body_});
-  }
-
- private:
-  TreeRef cond_;
-  TreeRef body_;
-};
-
-} // namespace script
-} // namespace caffe2
index 709879a..a4a1509 100644 (file)
@@ -6,7 +6,6 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
-#include "caffe2/contrib/script/compiler.h"
 #include "caffe2/core/asan.h"
 #include "caffe2/core/blob_stats.h"
 #include "caffe2/core/db.h"
@@ -938,29 +937,6 @@ void addObjectMethods(py::module& m) {
             }
             return pyout;
           });
-
-  py::class_<script::CompilationUnit>(m, "CompilationUnit")
-      .def(py::init<>())
-      .def("define", &script::CompilationUnit::define)
-      .def("get_proto", &script::CompilationUnit::getProto)
-      .def(
-          "create_net",
-          [](script::CompilationUnit* self, const std::string& name) {
-            auto net = self->createNet(gWorkspace, name);
-            CAFFE_ENFORCE(net);
-            return net;
-          })
-      .def(
-          "extern",
-          [](script::CompilationUnit* self,
-             const std::string& name,
-             py::object py_proto) {
-            py::bytes bytes = py_proto.attr("SerializeToString")();
-            std::unique_ptr<caffe2::NetDef> proto(new NetDef());
-            CAFFE_ENFORCE(ParseProtoFromLargeString(
-                bytes.cast<std::string>(), proto.get()));
-            self->defineExtern(name, std::move(proto));
-          });
 }
 
 void addGlobalMethods(py::module& m) {