From: Jerry Zhang Date: Mon, 10 Dec 2018 22:17:43 +0000 (-0800) Subject: caffe2/caffe2/contrib/script (#15007) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2357 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a51fe386c8e5f3198c948fe0a08458892a7901c4;p=platform%2Fupstream%2Fpytorch.git caffe2/caffe2/contrib/script (#15007) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15007 Pull Request resolved: https://github.com/pytorch/pytorch/pull/14979 att Reviewed By: dzhulgakov Differential Revision: D13286191 fbshipit-source-id: b8a6bc7aea44487aea4dcf7f44c858fd30c6293c --- diff --git a/caffe2/contrib/CMakeLists.txt b/caffe2/contrib/CMakeLists.txt index be8c0bd..6034e4d 100644 --- a/caffe2/contrib/CMakeLists.txt +++ b/caffe2/contrib/CMakeLists.txt @@ -4,7 +4,6 @@ add_subdirectory(nccl) add_subdirectory(opencl) add_subdirectory(prof) add_subdirectory(shm_mutex) -add_subdirectory(script) if (USE_TENSORRT) add_subdirectory(tensorrt) endif() diff --git a/caffe2/contrib/script/CMakeLists.txt b/caffe2/contrib/script/CMakeLists.txt deleted file mode 100644 index fb38787..0000000 --- a/caffe2/contrib/script/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -# ---[ CPU files. -file(GLOB tmp *.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp}) -# exclude test files and gpu files -file(GLOB tmp *_test.cc) -exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${tmp}) -exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_GPU_SRCS}) - -# ---[ CPU test files -file(GLOB tmp *_test.cc) -set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp}) -exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_GPU_TEST_SRCS}) - -# ---[ Send the lists to the parent scope. -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) -set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) diff --git a/caffe2/contrib/script/caffe2_script_test.py b/caffe2/contrib/script/caffe2_script_test.py deleted file mode 100644 index d9f0b65..0000000 --- a/caffe2/contrib/script/caffe2_script_test.py +++ /dev/null @@ -1,520 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -from hypothesis import given - -from caffe2.python import core, workspace -from caffe2.proto import caffe2_pb2 -import caffe2.python.hypothesis_test_util as hu -import hypothesis.strategies as st - -import numpy as np - - -def feed_inputs(inputs): - for name, value in inputs.items(): - workspace.FeedBlob(name, value) - - -def assert_proto_equals(proto, expected): - proto_lines = proto.strip().split('\n') - expected_lines = expected.strip().split('\n') - assert len(proto_lines) == len(expected_lines), \ - '{} != {}'.format(proto, expected) - for left, right in zip(proto_lines, expected_lines): - assert left.strip() == right.strip(), \ - '{} != {}'.format(proto, expected) - - -class TestCaffe2Script(hu.HypothesisTestCase): - test_program = """ - def foo(a,b,X,W) -> (c): - t = a + b*b - c = FC(X,W,t) - def testIf(c0,c1,t,f) -> (r): - if c0 < c1: - r = t - else: - r = f - r = Add(r,3f,broadcast=1) - def testWhile(r) -> (r): - m = 0 - while m < 4: - # Plus operator automatically broadcasts, and we cannot - # do in-place B and C arguments when we broadcast, so use - # an explicit Add op. - r = Add(r, r) - m = m + 1 - """ - - @given(firstdim=st.integers(min_value=1, max_value=4096), - seconddim=st.integers(min_value=1, max_value=4096), - seed=st.integers(min_value=0, max_value=65536), - **hu.gcs) - def test_foo(self, firstdim, seconddim, seed, gc, dc): - np.random.seed(int(seed)) - inputs = {} - a = inputs['a'] = np.random.rand(seconddim).astype(np.float32) - b = inputs['b'] = np.random.rand(seconddim).astype(np.float32) - X = inputs['X'] = np.random.rand(firstdim, firstdim).astype(np.float32) - W = inputs['W'] = np.random.rand(seconddim, firstdim).astype(np.float32) - - feed_inputs(inputs) - - CU = core.C.CompilationUnit() - CU.define(self.test_program) - CU.create_net('foo').run() - - ref_t = a + b * b - ref_c = np.matmul(X, W.transpose()) + ref_t - actual_c = workspace.FetchBlob('c') - - np.testing.assert_allclose(actual_c, ref_c, rtol=1e-05) - - def test_trinary(self): - CU = core.C.CompilationUnit() - CU.define(""" - def foo(c) -> (d): - d = 1 + (2 if c else 4) - """) - workspace.FeedBlob('c', np.ones((1), dtype=bool)) - net = CU.create_net('foo') - net.run() - assert(3 == workspace.FetchBlob('d')) - workspace.FeedBlob('c', np.zeros((1), dtype=bool)) - net.run() - assert(5 == workspace.FetchBlob('d')) - - def test_bool_literal(self): - CU = core.C.CompilationUnit() - CU.define(""" - def foo() -> (a,b): - a = True - b = False - """) - net = CU.create_net('foo') - net.run() - assert(workspace.FetchBlob('a')) - assert(not workspace.FetchBlob('b')) - - def test_bool_operators(self): - CU = core.C.CompilationUnit() - CU.define(""" - def foo() -> (a, b, c, d, e): - a = True and False - b = True or False - c = not b - d = not False or True - e = not (1 if a else 0) == (1 if b else 0) - """) - net = CU.create_net('foo') - net.run() - assert(not workspace.FetchBlob('a')) - assert(workspace.FetchBlob('b')) - assert(not workspace.FetchBlob('c')) - assert(workspace.FetchBlob('d')) - assert(workspace.FetchBlob('e')) - - def expect_fail(self, fn, msg): - try: - fn() - except RuntimeError as r: - if msg not in str(r): - raise RuntimeError( - "Failed wrong: expected string '{}' ".format(msg) + - "in error message but found\n{}".format(str(r))) - - def test_fails(self): - def fail_inputs(): - CU = core.C.CompilationUnit() - CU.define(""" - def foo() -> (): - Print(1,4) - """) - self.expect_fail(fail_inputs, "expects 1 inputs but found 2") - - def fail_undef(): - CU = core.C.CompilationUnit() - CU.define(""" - def foo(a) -> (b): - a = what() - """) - self.expect_fail(fail_undef, "attempting to call unknown operation") - - def fail_schema(): - CU = core.C.CompilationUnit() - CU.define(""" - def foo(a) -> (b): - a = FC(a,a,a) - """) - self.expect_fail(fail_schema, "failed schema checking") - - def test_print(self): - CU = core.C.CompilationUnit() - CU.define(""" - def foo() -> (): - a = 1 - Print(a) - Print(a+1) - _ = 4 - Print(_) # verify in print this isn't _ but some temorary - Print(1) - Print(1.f) - Print(3.0) - """) - net = CU.create_net('foo') - net.run() - - def test_method(self): - CU = core.C.CompilationUnit() - CU.define(""" - def foo() -> (a): - a = (3+1).Add(4).Add(1) - """) - net = CU.create_net('foo') - net.run() - assert(9 == workspace.FetchBlob('a')) - - def test_plus_eq(self): - CU = core.C.CompilationUnit() - CU.define(""" - def foo() -> (a): - a = 4 - a += 1 - """) - net = CU.create_net('foo') - net.run() - assert(5 == workspace.FetchBlob('a')) - - def test_cast(self): - CU = core.C.CompilationUnit() - CU.define(""" - def foo() -> (a): - a = int(4.5f) - """) - net = CU.create_net('foo') - net.run() - assert(4 == workspace.FetchBlob('a')) - - def test_global(self): - CU = core.C.CompilationUnit() - CU.define(""" - def foo() -> (a): - global m - m.a = 4 - m.b = 5 - a = m.a + m.b - """) - net = CU.create_net('foo') - net.run() - assert(9 == workspace.FetchBlob('a')) - - def test_module_as_arg_ret(self): - CU = core.C.CompilationUnit() - CU.define(""" - def bar(a,c) -> (b): - b = Module() - temp = a.second - b.first = temp - b.second = a.first + c - def foo() -> (a,b): - x = Module() - x.first = 1 - x.second = 2 - x.y = bar(x,4) - a = x.y.first - b = x.y.second - """) - net = CU.create_net('foo') - net.run() - assert(2 == workspace.FetchBlob('a')) - assert(5 == workspace.FetchBlob('b')) - - def test_call_extern(self): - CU = core.C.CompilationUnit() - net = caffe2_pb2.NetDef() - net.op.extend([ - core.CreateOperator( - 'Mul', - ['i', 'i'], - ['o'], - ) - ]) - net.external_input.append('i') - net.external_output.append('o') - - CU.extern("myActualExtern", net) - CU.define(""" - def myExtern(x) -> (y): - t = x - if t > 1: - y = t * t - else: - y = 5 - def foo() -> (b): - a = 4 - a += 1 - b = 2 + myExtern(a) + myExtern(a, rename=False) + myActualExtern(a) - """) - net = CU.create_net('foo') - net.run() - assert(77 == workspace.FetchBlob('b')) - - @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs) - def test_if(self, seed, gc, dc): - np.random.seed(int(seed)) - inputs = {} - c0 = inputs['c0'] = np.random.rand(1).astype(np.float32) - c1 = inputs['c1'] = np.random.rand(1).astype(np.float32) - t = inputs['t'] = np.random.rand(3, 3).astype(np.float32) - f = inputs['f'] = np.random.rand(3, 3).astype(np.float32) - - feed_inputs(inputs) - - CU = core.C.CompilationUnit() - CU.define(self.test_program) - CU.create_net('testIf').run() - - if c0 < c1: - ref_r = t + 3 - else: - ref_r = f + 3 - actual_r = workspace.FetchBlob('r') - - np.testing.assert_allclose(actual_r, ref_r) - - @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs) - def test_while(self, seed, gc, dc): - np.random.seed(int(seed)) - inputs = {} - r = inputs['r'] = np.ones([3, 3]).astype(np.float32) - - feed_inputs(inputs) - - CU = core.C.CompilationUnit() - CU.define(self.test_program) - CU.create_net('testWhile').run() - - m = 0 - while m < 4: - r = r + r - m = m + 1 - - actual_r = workspace.FetchBlob('r') - - np.testing.assert_allclose(actual_r, r) - - @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs) - def test_gather(self, seed, gc, dc): - CU = core.C.CompilationUnit() - CU.define(""" - def easy(tensor, indices) -> (output): - output = tensor[indices] - def hard(tensor, i, j, k) -> (output): - output = tensor[i][j][k] - """) - - # First check that the generated proto is as expected. This tests that - # we desugar the gather syntax correctly and emit the right code. - proto = CU.get_proto('easy') - assert_proto_equals(proto, """ - name: "easy" - op { - input: "tensor" - input: "indices" - output: "output" - type: "Gather" - }""") - - proto = CU.get_proto('hard') - assert_proto_equals(proto, """ - name: "hard" - op { - input: "tensor" - input: "i" - output: "$t1" - type: "Gather" - } - op { - input: "$t1" - input: "j" - output: "$t0" - type: "Gather" - } - op { - input: "$t0" - input: "k" - output: "output" - type: "Gather" - }""") - - # Now just test that the effect of the generated code is as expected. - np.random.seed(int(seed)) - tensor = np.random.rand(5, 4, 3).astype(np.float32) - indices = np.random.randint(len(tensor), size=(5, 5)) - - feed_inputs(dict(tensor=tensor, indices=indices)) - - net = CU.create_net('easy') - net.run() - - output = workspace.FetchBlob('output') - expected_output = [tensor[sample] for sample in indices] - np.testing.assert_allclose(output, expected_output) - - @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs) - def test_slice(self, seed, gc, dc): - CU = core.C.CompilationUnit() - CU.define(""" - def slice_from_tensor(tensor, start, end) -> (output): - output = tensor[start:end] - def slice_from_vector(vector, start, end) -> (a, b, c, d): - a = vector[start:end] - b = vector[start:] - c = vector[:end] - d = vector[:] - """) - - # slice_from_tensor - proto = CU.get_proto('slice_from_tensor') - assert_proto_equals(proto, """ - name: "slice_from_tensor" - op { - input: "tensor" - input: "start" - input: "end" - output: "output" - type: "Slice" - }""") - - np.random.seed(int(seed)) - tensor = np.random.rand(5, 4, 3).astype(np.float32) - start = np.array([0, 1, 0], dtype=np.int32) - end = np.array([-1, 2, -1], dtype=np.int32) - - feed_inputs(dict(tensor=tensor, start=start, end=end)) - - net = CU.create_net('slice_from_tensor') - net.run() - - output = workspace.FetchBlob('output') - np.testing.assert_allclose(output, tensor[:, 1:2]) - - # slice_from_vector - proto = CU.get_proto('slice_from_vector') - assert_proto_equals(proto, """ - name: "slice_from_vector" - op { - input: "vector" - input: "start" - input: "end" - output: "a" - type: "Slice" - } - op { - output: "$t0" - type: "ConstantFill" - arg { - name: "dtype" - i: 2 - } - arg { - name: "value" - i: -1 - } - arg { - name: "shape" - ints: 1 - } - } - op { - input: "vector" - input: "start" - input: "$t0" - output: "b" - type: "Slice" - } - op { - output: "$t1" - type: "ConstantFill" - arg { - name: "dtype" - i: 2 - } - arg { - name: "value" - i: 0 - } - arg { - name: "shape" - ints: 1 - } - } - op { - input: "vector" - input: "$t1" - input: "end" - output: "c" - type: "Slice" - } - op { - output: "$t2" - type: "ConstantFill" - arg { - name: "dtype" - i: 2 - } - arg { - name: "value" - i: 0 - } - arg { - name: "shape" - ints: 1 - } - } - op { - output: "$t3" - type: "ConstantFill" - arg { - name: "dtype" - i: 2 - } - arg { - name: "value" - i: -1 - } - arg { - name: "shape" - ints: 1 - } - } - op { - input: "vector" - input: "$t2" - input: "$t3" - output: "d" - type: "Slice" - }""") - - vector = np.random.rand(10).astype(np.float32) - start = np.array([2], dtype=np.int32) - end = np.array([6], dtype=np.int32) - feed_inputs(dict(vector=vector, start=start, end=end)) - - net = CU.create_net('slice_from_vector') - net.run() - - output = workspace.FetchBlob('a') - np.testing.assert_allclose(output, vector[2:6]) - - output = workspace.FetchBlob('b') - np.testing.assert_allclose(output, vector[2:]) - - output = workspace.FetchBlob('c') - np.testing.assert_allclose(output, vector[:6]) - - output = workspace.FetchBlob('d') - np.testing.assert_allclose(output, vector) diff --git a/caffe2/contrib/script/compiler.cc b/caffe2/contrib/script/compiler.cc deleted file mode 100644 index 16a7657..0000000 --- a/caffe2/contrib/script/compiler.cc +++ /dev/null @@ -1,793 +0,0 @@ -#include "caffe2/core/net.h" -#include "caffe2/utils/proto_utils.h" - -#include "compiler.h" -#include "parser.h" - -namespace caffe2 { -namespace script { - -namespace { - -static std::unordered_set ops_containing_nets = { - "If", - "While", - "RecurrentNetwork", -}; -// record of defined function -// NetDef + metadata -struct FunctionDefinition { - explicit FunctionDefinition(Def tree) - : tree(new Def(tree)), net_def(new NetDef()) {} - - explicit FunctionDefinition(std::unique_ptr def) - : tree(nullptr), net_def(std::move(def)) { - // we coop extern_inputs/extern_outputs to be the inputs/outputs to - // this net as a function - // but we _dont_ set these when creating the net in the workspace - // because they require the net to have valid inputs/outputs - inputs.insert( - inputs.begin(), - net_def->external_input().begin(), - net_def->external_input().end()); - outputs.insert( - outputs.begin(), - net_def->external_output().begin(), - net_def->external_output().end()); - net_def->clear_external_output(); - net_def->clear_external_input(); - } - - bool isExtern() const { - return tree == nullptr; - } - std::unique_ptr tree; - std::unique_ptr net_def; - std::vector inputs; - std::vector outputs; -}; - -} // namespace - -using SymbolTable = std::unordered_map; - -struct DefCompiler { - DefCompiler(FunctionDefinition& def, SymbolTable& symbol_table) - : def(def), - net_def_stack({def.net_def.get()}), - symbol_table(symbol_table) {} - void run() { - auto& tree = *def.tree; - cur().set_name(tree.name().name()); - for (auto input : tree.params()) { - auto& name = input.ident().name(); - map(name, name); - def.inputs.push_back(name); - } - for (auto output : tree.returns()) { - auto& name = output.ident().name(); - map(name, name); - def.outputs.push_back(name); - } - emitStatements(tree.statements()); - } - void emitExpressionStatement(TreeRef stmt) { - // expression with no used outputs - emit(stmt, {}); - } - void emitStatements(const ListView& statements) { - for (auto stmt : statements) { - switch (stmt->kind()) { - case TK_IF: - emitIf(If(stmt)); - break; - case TK_WHILE: - emitWhile(While(stmt)); - break; - case TK_ASSIGN: - emitAssignment(Assign(stmt)); - break; - case TK_GLOBAL: - for (auto ident : stmt->trees()) { - auto name = Ident(ident).name(); - map(name, name); - } - break; - default: - emitExpressionStatement(stmt); - break; - } - } - } - void map(const std::string& name, const std::string& value) { - env[name] = value; - } - const std::string& lookup(const Ident& ident) { - if (env.count(ident.name()) == 0) - throw ErrorReport(ident) << "undefined value " << ident.name(); - return env[ident.name()]; - } - void emitAssignment(const Assign& stmt) { - std::vector outputs; - for (auto lhs : stmt.lhs()) { - std::string name = getLHS(lhs); - // use of "_" gets renamed in Caffe2 graphs so that two uses - // don't unintentionally interfere with each other - if (name == "_") { - name = fresh(); - } - outputs.push_back(name); - } - if (stmt.reduction() != '=') { - if (stmt.lhs().size() != 1) { - throw ErrorReport(stmt) - << "reductions are only allow when there is a single variable " - << "on the left-hand side."; - } - auto lhs = stmt.lhs()[0]; - auto expr = - Compound::create(stmt.reduction(), stmt.range(), {lhs, stmt.rhs()}); - emit(expr, outputs); - } else { - emit(stmt.rhs(), outputs); - } - int i = 0; - for (auto ident : stmt.lhs()) { - if (ident->kind() == TK_IDENT) - map(Ident(ident).name(), outputs.at(i)); - i++; - } - } - void emitIf(const If& stmt) { - auto cond = getValue(stmt.cond()); - auto op = cur().add_op(); - op->set_type("If"); - op->add_input(cond); - auto true_branch = op->add_arg(); - true_branch->set_name("then_net"); - auto nd = true_branch->mutable_n(); - net_def_stack.push_back(nd); - emitStatements(stmt.trueBranch()); - net_def_stack.pop_back(); - if (stmt.falseBranch().size() > 0) { - auto false_branch = op->add_arg(); - false_branch->set_name("else_net"); - auto nd = false_branch->mutable_n(); - net_def_stack.push_back(nd); - emitStatements(stmt.falseBranch()); - net_def_stack.pop_back(); - } - } - void emitWhile(const While& stmt) { - std::string loop_var = fresh(); - emitConst(0, loop_var, "i"); // it needs a definition before loop - auto op = cur().add_op(); - op->set_type("While"); - auto cond = op->add_arg(); - cond->set_name("cond_net"); - auto cond_net = cond->mutable_n(); - - net_def_stack.push_back(cond_net); - emit(stmt.cond(), {loop_var}); - net_def_stack.pop_back(); - - op->add_input(loop_var); - auto body = op->add_arg(); - body->set_name("loop_net"); - auto body_net = body->mutable_n(); - - net_def_stack.push_back(body_net); - emitStatements(stmt.body()); - net_def_stack.pop_back(); - } - std::string getLHS(const TreeRef& tree) { - switch (tree->kind()) { - case TK_IDENT: { - return Ident(tree).name(); - } break; - case '.': { - auto sel = Select(tree); - std::string lhs = getValue(sel.value()); - // TODO: check whether this subname exists in object lhs - return lhs + "/" + sel.selector().name(); - } break; - default: { - throw ErrorReport(tree) - << "This expression cannot appear on the left-hand size of an assignment"; - } break; - } - } - std::string getValue(const TreeRef& tree) { - switch (tree->kind()) { - case TK_IDENT: { - return lookup(Ident(tree)); - } break; - case '.': { - auto sel = Select(tree); - std::string lhs = getValue(sel.value()); - // TODO: check whether this subname exists in object lhs - return lhs + "/" + sel.selector().name(); - } break; - default: { - std::string name = fresh(); - emit(tree, {name}); - return name; - } break; - } - } - std::string fresh(std::string prefix = "$t") { - return std::string(prefix) + c10::to_string(next_fresh++); - } - const char* operatorName(int kind, int ninputs) { - switch (kind) { - case '+': - return "Add"; - case '-': - if (ninputs == 1) - return "Negative"; - else - return "Sub"; - case '*': - return "Mul"; - case '/': - return "Div"; - case TK_NE: - return "NE"; - case TK_EQ: - return "EQ"; - case '<': - return "LT"; - case '>': - return "GT"; - case TK_LE: - return "LE"; - case TK_GE: - return "GE"; - case TK_IF_EXPR: - return "Conditional"; - case TK_AND: - return "And"; - case TK_OR: - return "Or"; - case TK_NOT: - return "Not"; - default: - throw std::runtime_error("unknown kind " + c10::to_string(kind)); - } - } - void fillArg(Argument* arg, const Attribute& attr) { - std::string name = attr.name().name(); - arg->set_name(name); - auto value = attr.value(); - // TODO: handle non-float attributes - switch (value->kind()) { - case TK_CONST: { - auto v = value->tree(0)->doubleValue(); - auto f = value->tree(1)->stringValue(); - if (f == "f") - arg->set_f(v); - else - arg->set_i(v); - } break; - case TK_LIST: - for (auto t : value->trees()) { - auto v = t->tree(0)->doubleValue(); - auto f = t->tree(1)->stringValue(); - if (f == "f") - arg->add_floats(v); - else - arg->add_ints(v); - } - break; - } - } - template - std::vector getValues(const Trees& trees) { - std::vector result; - for (const auto& tree : trees) { - result.push_back(getValue(tree)); - } - return result; - } - - bool renameLookup( - std::unordered_map& rename_map, - const std::string& name, - std::string& rename) { - // first look for name in the map directly - auto it = rename_map.find(name); - if (it != rename_map.end()) { - rename = it->second; - return true; - } - // otherwise if we have a rename entry like a => b and a name "a/foo/bar" - // then replace it with "b/foo/bar" - auto p = name.find("/"); - if (p == std::string::npos) - return false; - it = rename_map.find(name.substr(0, p)); - if (it != rename_map.end()) { - rename = it->second + name.substr(p); - return true; - } - return false; - } - void renameOp( - std::unordered_map& rename_map, - const Apply& apply, - const std::string& prefix, - bool isExtern, - OperatorDef* new_op) { - for (size_t i = 0; i < new_op->input().size(); i++) { - auto& name = new_op->input(i); - std::string renamed; - bool defined = renameLookup(rename_map, name, renamed); - if (!isExtern && !defined) { - throw ErrorReport(apply) - << " unexpected undefined name '" << name - << "' while attempting to inline '" << apply.name().name() << "'"; - } else if (!defined) { - // extern function using a global name, assign it an identity mapping - rename_map[name] = name; - } - new_op->set_input(i, renamed); - } - for (size_t i = 0; i < new_op->output().size(); i++) { - auto& name = new_op->output(i); - std::string renamed; - if (!renameLookup(rename_map, name, renamed)) { - renamed = prefix + name; - rename_map[name] = renamed; - } - new_op->set_output(i, renamed); - } - // handle control flow inside the op as well - if (ops_containing_nets.count(new_op->type()) > 0) { - for (size_t i = 0; i < new_op->arg_size(); i++) { - auto* arg = new_op->mutable_arg(i); - if (arg->has_n()) { - auto* n = arg->mutable_n(); - for (size_t j = 0; j < n->op_size(); j++) { - renameOp(rename_map, apply, prefix, isExtern, n->mutable_op(j)); - } - } - } - } - } - - bool hasBypassRename(const Apply& apply) { - for (auto attr : apply.attributes()) { - if (attr.name().name() == "rename") { - if (attr.value()->kind() != TK_CONST) { - throw ErrorReport(attr.value()) << "expected a single constant"; - } - return attr.value()->tree(0)->doubleValue() == 0; - } - } - return false; - } - - // emit a function call by inlining the function's NetDef into our - // net def, renaming temporaries func_name/orig_name - // renaming only happens for values defined by the function - // that are not marked outputs - - // inputs/outputs are passed by reference - void emitFunctionCall(Apply& apply, const std::vector& outputs) { - std::string fname = apply.name().name(); - std::string prefix = fresh(fname) + "/"; - auto& fn = symbol_table.at(apply.name().name()); - bool isExtern = fn.isExtern(); - auto inputs = getValues(apply.inputs()); - std::unordered_map rename_map; - if (inputs.size() != fn.inputs.size()) { - throw ErrorReport(apply) << fname << " expected " << fn.inputs.size() - << " values but received " << inputs.size(); - } - for (size_t i = 0; i < inputs.size(); i++) { - rename_map[fn.inputs[i]] = inputs[i]; - } - if (outputs.size() != fn.outputs.size()) { - throw ErrorReport(apply) << fname << " expected " << fn.outputs.size() - << " values but received " << outputs.size(); - } - for (size_t i = 0; i < outputs.size(); i++) { - rename_map[fn.outputs[i]] = outputs[i]; - } - for (auto& op : fn.net_def->op()) { - auto new_op = cur().add_op(); - new_op->CopyFrom(op); - if (hasBypassRename(apply)) { - prefix = ""; - } - renameOp(rename_map, apply, prefix, isExtern, new_op); - } - } - void expectOutputs( - const TreeRef& tree, - const std::vector& outputs, - size_t size) { - if (outputs.size() != size) { - throw ErrorReport(tree) - << "expected operator to produce " << outputs.size() - << " outputs but it produced " << size; - } - } - void appendOutputs( - const TreeRef& tree, - OperatorDef* op, - const std::vector& outputs, - size_t size) { - expectOutputs(tree, outputs, size); - for (size_t i = 0; i < size; i++) { - op->add_output(outputs[i]); - } - } - void emitOperator( - const Apply& apply, - const OpSchema* schema, - const std::vector& outputs) { - // must be before add_op - auto values = getValues(apply.inputs()); - if (values.size() < schema->min_input() || - values.size() > schema->max_input()) { - if (schema->min_input() == schema->max_input()) { - throw ErrorReport(apply) << "operator expects " << schema->min_input() - << " inputs but found " << values.size(); - } else { - throw ErrorReport(apply) - << "operator takes between " << schema->min_input() << " and " - << schema->max_input() << " inputs but found " << values.size() - << "."; - } - } - auto numActualOutputs = schema->CalculateOutput(values.size()); - if (numActualOutputs != kCannotComputeNumOutputs && - outputs.size() != numActualOutputs) { - throw ErrorReport(apply) - << "operator produces " << numActualOutputs - << " outputs but matched to " << outputs.size() << " outputs"; - } - auto op = cur().add_op(); - op->set_type(apply.name().name()); - for (auto& v : values) { - op->add_input(v); - } - // assume 1 output unless matched to more - appendOutputs(apply, op, outputs, outputs.size()); - for (auto attribute : apply.attributes()) { - fillArg(op->add_arg(), attribute); - } - // Ok, we checked the stuff where we can easily give a friendly error - // message, now verify against the schema and report the error at the line - if (!schema->Verify(*op)) { - throw ErrorReport(apply) << "failed schema checking"; - } - } - - // Emit an operation, writing results into 'outputs'. - // This will _always_ compute something, unlike 'getValue' which simply - // returns an already computed reference if possible. - // So if 'tree' is an identifier or nested identifier (foo.bar) - // this will cause it to be _copied_ into outputs. - void emit(const TreeRef& tree, const std::vector& outputs) { - switch (tree->kind()) { - case TK_IDENT: - case '.': { - auto op = cur().add_op(); - op->set_type("Copy"); - op->add_input(getValue(tree)); - appendOutputs(tree, op, outputs, 1); - } break; - case TK_NE: - case TK_EQ: - case '<': - case '>': - case TK_LE: - case TK_GE: - case '-': - case '*': - case '/': - case '+': - case TK_AND: - case TK_OR: - case TK_NOT: - case TK_IF_EXPR: { - // must be before add_op - auto values = getValues(tree->trees()); - auto op = cur().add_op(); - op->set_type(operatorName(tree->kind(), tree->trees().size())); - for (auto& v : values) { - op->add_input(v); - } - appendOutputs(tree, op, outputs, 1); - auto broadcast = op->add_arg(); - broadcast->set_name("broadcast"); - broadcast->set_i(1); - } break; - case TK_APPLY: { - auto apply = Apply(tree); - // Handle built-ins like zeros, ones, etc - if (builtins.count(apply.name().name()) > 0) { - builtins[apply.name().name()](this, apply, outputs); - break; - } - if (symbol_table.count(apply.name().name()) > 0) { - emitFunctionCall(apply, outputs); - break; - } - auto schema = OpSchemaRegistry::Schema(apply.name().name()); - if (schema) { - emitOperator(apply, schema, outputs); - break; - } - throw ErrorReport(apply) - << "attempting to call unknown operation or function '" - << apply.name().name() << "'"; - } break; - case TK_CAST: { - auto cast = Cast(tree); - auto c2type = getType(cast.type()); - auto input = getValue(cast.input()); - auto op = cur().add_op(); - op->set_type("Cast"); - op->add_input(input); - appendOutputs(tree, op, outputs, 1); - auto arg = op->add_arg(); - arg->set_name("to"); - arg->set_i(c2type); - } break; - case TK_CONST: { - expectOutputs(tree, outputs, 1); - emitConst( - tree->tree(0)->doubleValue(), - outputs[0], - tree->tree(1)->stringValue()); - } break; - case TK_GATHER: { - const auto gather = Gather(tree); - desugarAndEmitOperator( - "Gather", - gather.range(), - {gather.value(), gather.indices()}, - outputs); - break; - } - case TK_SLICE: { - const auto slice = Slice(tree); - desugarAndEmitOperator( - "Slice", - slice.range(), - {slice.value(), slice.startOr(0), slice.endOr(-1)}, - outputs); - break; - } - default: - throw ErrorReport(tree) << "NYI: " << tree; - break; - } - } - - // Desugars constructs that are syntactic sugar and emits the corresponding - // operator invocation, e.g. tensor[indices] -> tensor.Gather(indices). - void desugarAndEmitOperator( - const std::string& operatorName, - const SourceRange& range, - TreeList&& inputs, - const std::vector& outputs) { - const auto applyName = Ident::create(range, operatorName); - const auto applyInputs = - Compound::create(TK_LIST, range, std::move(inputs)); - const auto applyAttributes = Compound::create(TK_LIST, range, {}); - const auto apply = - Apply::create(range, applyName, applyInputs, applyAttributes); - const auto schema = OpSchemaRegistry::Schema(operatorName); - assert(schema != nullptr); - emitOperator(Apply(apply), schema, outputs); - } - - TensorProto_DataType getType(int type) { - switch (type) { - case TK_INT: - return TensorProto_DataType_INT32; - case TK_FLOAT: - return TensorProto_DataType_FLOAT; - case TK_LONG: - return TensorProto_DataType_INT64; - case TK_BOOL: - return TensorProto_DataType_BOOL; - default: - throw std::runtime_error( - "expected type token: " + c10::to_string(type)); - } - } - - OperatorDef* emitConst( - double v, - const std::string& output, - const std::string& type_ident) { - auto op = cur().add_op(); - op->set_type("ConstantFill"); - auto dtype = op->add_arg(); - dtype->set_name("dtype"); - auto value = op->add_arg(); - value->set_name("value"); - if (type_ident == "f") { - dtype->set_i(TensorProto_DataType_FLOAT); - value->set_f(v); - } else if (type_ident == "LL") { - dtype->set_i(TensorProto_DataType_INT64); - value->set_i(v); - } else if (type_ident == "b") { - dtype->set_i(TensorProto_DataType_BOOL); - value->set_i(v != 0); - } else if (type_ident == "i") { - dtype->set_i(TensorProto_DataType_INT32); - value->set_i(v); - } else { - throw std::runtime_error("unknown type_ident " + type_ident); - } - auto shape = op->add_arg(); - shape->set_name("shape"); - shape->add_ints(1); - op->add_output(output); - return op; - } - NetDef& cur() { - return *net_def_stack.back(); - } - FunctionDefinition& def; // the def being constructed - std::unordered_map - env; // map from name in Def to name in NetDef - std::vector net_def_stack; - SymbolTable& symbol_table; - int next_fresh = 0; - - private: - void emitFillOp(const Apply& apply, const std::vector& outputs) { - auto builtin_type = apply.name().name(); - auto values = getValues(apply.inputs()); - if (values.size() > 1) { - throw ErrorReport(apply) - << "Built-in " << builtin_type << " accepts 0 or 1 inputs."; - } - bool has_shape = false; - for (const auto& attribute : apply.attributes()) { - if (attribute.name().name() == "shape") { - has_shape = true; - } else { - throw ErrorReport(apply) - << "Unrecognized attribute " << attribute.name().name() - << " for built-in " << builtin_type; - } - } - if (builtin_type == "zeros" || builtin_type == "ones") { - if ((values.size() != 1) && !has_shape) { - throw ErrorReport(apply) - << "Built-in " << builtin_type - << " requires either 1 input or 1 shape attribute"; - } - } else { - // zeros_like or ones_like - if (values.size() != 1) { - throw ErrorReport(apply) - << "Built-in " << builtin_type << " requires 1 input"; - } - } - - auto op = cur().add_op(); - op->set_type("ConstantFill"); - if (values.size()) { - op->add_input(values[0]); - auto* input_as_shape = op->add_arg(); - input_as_shape->set_name("input_as_shape"); - if (builtin_type.find("_like") != std::string::npos) { - // zeros_like, ones_like take the shape of the input as constant - // tensor shape - input_as_shape->set_i(0); - } else { - // zeros, ones take the values in the tensor as constant tensor - // shape - input_as_shape->set_i(1); - } - } else { - fillArg(op->add_arg(), apply.attributes()[0]); - } - - auto value = op->add_arg(); - value->set_name("value"); - if (builtin_type.find("ones") != std::string::npos) { - value->set_f(1.0f); - } else { - value->set_f(0.0f); - } - appendOutputs(apply, op, outputs, 1); - } - // emitModule doesn't actually do anything except for allow - // statements like a = Module() to register 'a' as a valid identifier - // so that a.b = ... will work - void emitModule(const Apply& apply, const std::vector& outputs) { - expectOutputs(apply, outputs, 1); - } - std::unordered_map< - std::string, - std::function& outputs)>> - builtins{{"zeros", &DefCompiler::emitFillOp}, - {"zeros_like", &DefCompiler::emitFillOp}, - {"ones", &DefCompiler::emitFillOp}, - {"ones_like", &DefCompiler::emitFillOp}, - {"Module", &DefCompiler::emitModule}}; -}; - -struct CompilationUnitImpl { - void defineFunction(const Def& def) { - if (functions.count(def.name().name()) > 0) { - throw ErrorReport(def) << def.name().name() << " already defined."; - } - DefCompiler c( - functions.emplace(def.name().name(), FunctionDefinition(def)) - .first->second, - functions); - c.run(); - } - - void define(const std::string& str) { - Parser p(str); - while (p.lexer().cur().kind != TK_EOF) { - defineFunction(Def(p.parseFunction())); - } - } - - std::unique_ptr createNet(Workspace* ws, const std::string& str) { - if (functions.count(str) == 0) - throw ErrorReport() << "undefined function: " << str << "\n"; - auto& def = functions.at(str); - return caffe2::CreateNet(*def.net_def, ws); - } - - void defineExtern(const std::string& name, std::unique_ptr net_def) { - // TODO: unify extern and function namespaces - if (functions.count(name) > 0) { - throw ErrorReport() << "function '" << name << "' already defined."; - } - functions.emplace(name, FunctionDefinition(std::move(net_def))); - } - - std::string getProto(const std::string& functionName) { - return functions.at(functionName).net_def->DebugString(); - } - - private: - friend struct DefCompiler; - SymbolTable functions; -}; - -CompilationUnit::CompilationUnit() : pImpl(new CompilationUnitImpl()) {} - -void CompilationUnit::define(const std::string& str) { - return pImpl->define(str); -} - -void CompilationUnit::defineExtern( - const std::string& name, - std::unique_ptr nd) { - pImpl->defineExtern(name, std::move(nd)); -} - -std::unique_ptr CompilationUnit::createNet( - Workspace* ws, - const std::string& str) { - return pImpl->createNet(ws, str); -} - -std::string CompilationUnit::getProto(const std::string& functionName) const { - return pImpl->getProto(functionName); -} - -CompilationUnit::~CompilationUnit() {} - -} // namespace script -} // namespace caffe2 diff --git a/caffe2/contrib/script/compiler.h b/caffe2/contrib/script/compiler.h deleted file mode 100644 index 0a15c33..0000000 --- a/caffe2/contrib/script/compiler.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once -#include -#include -#include "caffe2/core/net.h" - -namespace caffe2 { -namespace script { - -struct CompilationUnitImpl; - -struct CAFFE2_API CompilationUnit { - CompilationUnit(); - void define(const std::string& str); - void defineExtern(const std::string& str, std::unique_ptr netdef); - std::unique_ptr createNet(Workspace* ws, const std::string& name); - std::string getProto(const std::string& functionName) const; - ~CompilationUnit(); - - private: - std::unique_ptr pImpl; -}; - -} // namespace script -}; // namespace caffe2 diff --git a/caffe2/contrib/script/error_report.h b/caffe2/contrib/script/error_report.h deleted file mode 100644 index cecc0f3..0000000 --- a/caffe2/contrib/script/error_report.h +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include "caffe2/contrib/script/tree.h" - -namespace caffe2 { -namespace script { - -struct ErrorReport : public std::exception { - ErrorReport(const ErrorReport& e) - : ss(e.ss.str()), context(e.context), the_message(e.the_message) {} - - ErrorReport() : context(nullptr) {} - explicit ErrorReport(const SourceRange& r) - : context(std::make_shared(r)) {} - explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {} - explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {} - virtual const char* what() const noexcept override { - std::stringstream msg; - msg << "\n" << ss.str(); - if (context != nullptr) { - msg << ":\n"; - context->highlight(msg); - } else { - msg << ".\n"; - } - the_message = msg.str(); - return the_message.c_str(); - } - - private: - template - friend const ErrorReport& operator<<(const ErrorReport& e, const T& t); - - mutable std::stringstream ss; - std::shared_ptr context; - mutable std::string the_message; -}; - -template -const ErrorReport& operator<<(const ErrorReport& e, const T& t) { - e.ss << t; - return e; -} - -#define C2S_ASSERT(ctx, cond) \ - if (!(cond)) { \ - throw ::caffe2::script::ErrorReport(ctx) \ - << __FILE__ << ":" << __LINE__ << ": assertion failed: " << #cond; \ - } -} // namespace script -} // namespace caffe2 diff --git a/caffe2/contrib/script/examples/example_beam_search.c2s b/caffe2/contrib/script/examples/example_beam_search.c2s deleted file mode 100644 index 2e081ee..0000000 --- a/caffe2/contrib/script/examples/example_beam_search.c2s +++ /dev/null @@ -1,76 +0,0 @@ -[["log_probs", [6, 1, 44463], "float32"], ["attentions", [6, 1, 21], "float32"], ["inputs", [21], "float32"]] -beam_search -["scores_t"] - -def beam_search(inputs, log_probs, attentions) -> (): - beam_size = 6LL - length = 20LL - beam_output_shape, _ = Concat(length + 1LL, beam_size, axis=0) - output_token_beam_list = int(zeros(beam_output_shape)) - output_prev_index_beam_list = int(zeros(beam_output_shape)) - output_score_beam_list = zeros(beam_output_shape) - - input_length = inputs.Size().ExpandDims(dims=[0]) - - attention_beam_output_shape, _ = Concat( - input_length, beam_output_shape, axis=0) - output_attention_weights_beam_list = zeros(attention_beam_output_shape) - - attention_step_output_shape, _ = Concat(beam_size, input_length, axis=0) - attention_t = zeros(attention_step_output_shape) - - scores_t = zeros(shape=[1, 6]) - hypo_t = int(zeros(shape=[6])) - tokens_t = int(ones(shape=[6])) * 99 - - output_token_beam_list = output_token_beam_list.ScatterAssign(0, tokens_t) - output_token_beam_list = output_token_beam_list.ExpandDims(dims=[2]) - output_prev_index_beam_list = output_prev_index_beam_list.ScatterAssign( - 0, hypo_t) - output_prev_index_beam_list = output_prev_index_beam_list.ExpandDims(dims=[2]) - output_score_beam_list = output_score_beam_list.ScatterAssign(0, scores_t) - output_score_beam_list = output_score_beam_list.ExpandDims(dims=[2]) - output_attention_weights_beam_list = output_attention_weights_beam_list\ - .ScatterAssign(0, attention_t) - - length_32 = int(length) - - timestep = 0 - not_finished = True - while not_finished: - # TODO: once we have a metaprogramming facility we need to insert the - # body of the post_eos_penalty here programmatically - - best_scores_per_hypo, best_tokens_per_hypo = log_probs.TopK(k=6) - - # Add the best score in each hypothesis to the cumulative score so far - output_scores = best_scores_per_hypo + scores_t.Squeeze(dims=[0]) - - # Flatten scores so we can find the best overall out of all hypotheses - output_scores_flattened_slice, _ = output_scores.FlattenToVec()\ - .Slice(0, 6 if timestep == 0 else -1).Reshape(shape=[1, -1]) - - # Find top K out of all - scores_t, best_indices = output_scores_flattened_slice.TopK(k=6) - - # Integer floor divide on indices finds the association back to original - # hypotheses. Use this to reorder states - hypo_t_int64 = best_indices / 6LL - - # Reorder attentions - attention_t, _ = attentions.Gather(hypo_t_int64)\ - .Reshape(shape=[1, 6, -1]) - tokens_t_int64 = best_tokens_per_hypo.FlattenToVec()\ - .Gather(best_indices).Cast(to=2) - - timestep += 1 - not_finished = timestep < length_32 - - output_token_beam_list = output_token_beam_list\ - .ScatterAssign(timestep, tokens_t) - output_prev_index_beam_list = output_prev_index_beam_list\ - .ScatterAssign(timestep, hypo_t) - output_score_beam_list = output_score_beam_list\ - .ScatterAssign(timestep, scores_t) - output_attention_weights_beam_list = output_attention_weights_beam_list\ - .ScatterAssign(timestep, attention_t) diff --git a/caffe2/contrib/script/examples/example_post_eos_penalty.c2s b/caffe2/contrib/script/examples/example_post_eos_penalty.c2s deleted file mode 100644 index 9988913..0000000 --- a/caffe2/contrib/script/examples/example_post_eos_penalty.c2s +++ /dev/null @@ -1,13 +0,0 @@ -[["tokens_t", [1, 6], "int32"], ["hypo_t", [1, 6], "int32"], ["log_probs", [6, 1, 44463], "float32"], ["on_initial_step", [1], "bool_"]] -post_eos_penalty -["log_probs"] - -def post_eos_penalty(tokens_t, hypo_t, log_probs, on_initial_step) \ - -> (log_probs): - eos_token = 1 - finished_penalty = 0f if on_initial_step else 0.5f - predecessor_tokens = tokens_t.FlattenToVec().Gather(hypo_t.FlattenToVec()) - predecessor_is_eos = float(predecessor_tokens == eos_token) - log_probs = log_probs.Add( - predecessor_is_eos * finished_penalty, broadcast=1, axis=0 - ) diff --git a/caffe2/contrib/script/examples/run_examples.py b/caffe2/contrib/script/examples/run_examples.py deleted file mode 100644 index 26f2db0..0000000 --- a/caffe2/contrib/script/examples/run_examples.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals -from caffe2.python import core, workspace -import glob -import json -import numpy as np - -example_files = glob.glob('example_*.c2s') - -for ex in example_files: - print('Running example file', ex) - with open(ex, 'r') as f: - inits = json.loads(f.readline()) - net_name = f.readline().strip() - outputs = json.loads(f.readline()) - - CU = core.C.CompilationUnit() - CU.define(f.read()) - - # Initialize workspace with required inputs - for name, shape, dt in inits: - workspace.FeedBlob(name, np.random.rand(*shape).astype(np.dtype(dt))) - - net = CU.create_net(net_name) - net.run() - - print('Success! Interesting outputs:') - for output in outputs: - print(output, workspace.FetchBlob(output)) diff --git a/caffe2/contrib/script/lexer.cc b/caffe2/contrib/script/lexer.cc deleted file mode 100644 index 9dafea9..0000000 --- a/caffe2/contrib/script/lexer.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "caffe2/contrib/script/lexer.h" -#include "caffe2/core/common.h" - -namespace caffe2 { -namespace script { - -std::string kindToString(int kind) { - if (kind < 256) - return std::string(1, kind); - switch (kind) { -#define DEFINE_CASE(tok, str, _) \ - case tok: \ - return str; - TC_FORALL_TOKEN_KINDS(DEFINE_CASE) -#undef DEFINE_CASE - default: - throw std::runtime_error("unknown kind: " + c10::to_string(kind)); - } -} - -SharedParserData& sharedParserData() { - static SharedParserData data; // safely handles multi-threaded init - return data; -} -} // namespace script -} // namespace caffe2 diff --git a/caffe2/contrib/script/lexer.h b/caffe2/contrib/script/lexer.h deleted file mode 100644 index b298809..0000000 --- a/caffe2/contrib/script/lexer.h +++ /dev/null @@ -1,527 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include - -#include "caffe2/core/common.h" - -namespace caffe2 { -namespace script { - -// single character tokens are just the character itself '+' -// multi-character tokens need an entry here -// if the third entry is not the empty string, it is used -// in the lexer to match this token. - -// These kinds are also used in Tree.h as the kind of the AST node. -// Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the -// lexer. - -#define TC_FORALL_TOKEN_KINDS(_) \ - _(TK_EOF, "eof", "") \ - _(TK_WHITESPACE, "whitespace", "") \ - _(TK_NUMBER, "number", "") \ - _(TK_NEWLINE, "newline", "") \ - _(TK_INDENT, "indent", "") \ - _(TK_DEDENT, "dedent", "") \ - _(TK_WHERE, "where", "where") \ - _(TK_FLOAT, "float", "float") \ - _(TK_DOUBLE, "double", "double") \ - _(TK_LONG, "long", "long") \ - _(TK_INT, "int", "int") \ - _(TK_DEF, "def", "def") \ - _(TK_ARROW, "arrow", "->") \ - _(TK_EQUIVALENT, "equivalent", "<=>") \ - _(TK_IDENT, "ident", "") \ - _(TK_STRING, "string", "") \ - _(TK_CONST, "const", "") \ - _(TK_LIST, "list", "") \ - _(TK_OPTION, "option", "") \ - _(TK_APPLY, "apply", "") \ - _(TK_COMPREHENSION, "comprehension", "") \ - _(TK_TENSOR_TYPE, "tensor_type", "") \ - _(TK_RANGE_CONSTRAINT, "range_constraint", "") \ - _(TK_PARAM, "param", "") \ - _(TK_INFERRED, "inferred", "") \ - _(TK_BOOL, "bool", "") \ - _(TK_ACCESS, "access", "") \ - _(TK_ASSIGN, "assign", "") \ - _(TK_ATTRIBUTE, "attribute", "") \ - _(TK_IF, "if", "if") \ - _(TK_ELSE, "else", "else") \ - _(TK_ELIF, "elif", "elif") \ - _(TK_WHILE, "while", "while") \ - _(TK_NE, "ne", "!=") \ - _(TK_EQ, "eq", "==") \ - _(TK_LE, "le", "<=") \ - _(TK_GE, "ge", ">=") \ - _(TK_IF_EXPR, "if", "") \ - _(TK_TRUE, "True", "True") \ - _(TK_FALSE, "False", "False") \ - _(TK_AND, "and", "and") \ - _(TK_OR, "or", "or") \ - _(TK_NOT, "not", "not") \ - _(TK_CAST, "cast", "") \ - _(TK_PLUS_EQ, "+=", "+=") \ - _(TK_MINUS_EQ, "-=", "-=") \ - _(TK_TIMES_EQ, "*=", "*=") \ - _(TK_DIV_EQ, "/=", "/=") \ - _(TK_GLOBAL, "global", "global") \ - _(TK_BUILT_IN, "built-in", "") \ - _(TK_SLICE, "slice", "") \ - _(TK_GATHER, "gather", "") -static const char* valid_single_char_tokens = "+-*/()[]:,={}><."; - -enum TokenKind { - // we use characters to represent themselves so skip all valid characters - // before - // assigning enum values to multi-char tokens. - TK_DUMMY_START = 256, -#define DEFINE_TOKEN(tok, _, _2) tok, - TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN) -#undef DEFINE_TOKEN -}; - -std::string kindToString(int kind); - -// nested hash tables that indicate char-by-char what is a valid token. -struct TokenTrie; -using TokenTrieRef = std::unique_ptr; -struct TokenTrie { - TokenTrie() : kind(0) {} - void insert(const char* str, int tok) { - if (*str == '\0') { - assert(kind == 0); - kind = tok; - return; - } - auto& entry = children[*str]; - if (entry == nullptr) { - entry.reset(new TokenTrie()); - } - entry->insert(str + 1, tok); - } - int kind; // 0 == invalid token - std::unordered_map children; -}; - -// stuff that is shared against all TC lexers/parsers and is initialized only -// once. -struct SharedParserData { - SharedParserData() : head(new TokenTrie()) { - // listed in increasing order of precedence - std::vector> binary_ops = { - {TK_IF}, - {TK_AND, TK_OR}, - {}, // reserve a level for unary not - {'<', '>', TK_EQ, TK_LE, TK_GE, TK_NE}, - {'+', '-'}, - {'*', '/'}, - }; - std::vector> unary_ops = { - {'-'}, - }; - - std::stringstream ss; - for (const char* c = valid_single_char_tokens; *c; c++) { - const char str[] = {*c, '\0'}; - head->insert(str, *c); - } - -#define ADD_CASE(tok, _, tokstring) \ - if (*tokstring != '\0') { \ - head->insert(tokstring, tok); \ - } - TC_FORALL_TOKEN_KINDS(ADD_CASE) -#undef ADD_CASE - - // precedence starts at 1 so that there is always a 0 precedence - // less than any other precedence - int prec = 1; - for (auto& group : binary_ops) { - for (auto& element : group) { - binary_prec[element] = prec; - } - prec++; - } - // unary ops - for (auto& group : unary_ops) { - for (auto& element : group) { - unary_prec[element] = prec; - } - prec++; - } - // add unary not separately because it slots into the precedence of - // binary operators - unary_prec[TK_NOT] = binary_prec[TK_AND] + 1; - } - // 1. skip whitespace - // 2. handle comment or newline - // - bool isNumber(const std::string& str, size_t start, size_t* len) { - char first = str[start]; - // strtod allows numbers to start with + or - or nan or inf - // http://en.cppreference.com/w/cpp/string/byte/strtof - // but we want only the number part, otherwise 1+3 will turn into two - // adjacent numbers in the lexer - if (first == '-' || first == '+' || isalpha(first)) - return false; - const char* startptr = str.c_str() + start; - char* endptr; - std::strtod(startptr, &endptr); - *len = endptr - startptr; - return *len > 0; - } - bool isblank(int n) { - return isspace(n) && n != '\n'; - } - // find the longest match of str.substring(pos) against a token, return true - // if successful - // filling in kind, start,and len - bool match( - const std::string& str, - size_t pos, - bool continuation, // are we inside a scope where newlines don't count - // (e.g. inside parens) - bool whitespace_token, // should we treat whitespace as a token - int* kind, - size_t* start, - size_t* len) { - *start = pos; - // skip whitespace - while (pos < str.size() && isblank(str[pos])) - pos++; - - // special handling - if (pos < str.size()) { - if (str[pos] == '#') { - // skip comments - while (pos < str.size() && str[pos] != '\n') - pos++; - // tail call, handle whitespace and more comments - return match( - str, pos, continuation, whitespace_token, kind, start, len); - } - if (str[pos] == '\\' && pos + 1 < str.size() && str[pos + 1] == '\n' && - !whitespace_token) { - return match(str, pos + 2, continuation, false, kind, start, len); - } - if (str[pos] == '\n') { - return match( - str, pos + 1, continuation, !continuation, kind, start, len); - } - } - if (pos == str.size()) { - *kind = TK_EOF; - *start = pos; - *len = 0; - return true; - } - // invariant: the next token is not whitespace or newline - if (whitespace_token) { - *kind = TK_WHITESPACE; - *len = pos - *start; - return true; - } - *start = pos; - // check for a valid number - if (isNumber(str, pos, len)) { - *kind = TK_NUMBER; - return true; - } - // check for either an ident or a token - // ident tracks whether what we have scanned so far could be an identifier - // matched indicates if we have found any match. - bool matched = false; - bool ident = true; - TokenTrie* cur = head.get(); - for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); i++) { - ident = ident && validIdent(i, str[pos + i]); - if (ident) { - matched = true; - *len = i + 1; - *kind = TK_IDENT; - } - // check for token second, so that e.g. 'max' matches the token TK_MAX - // rather the - // identifier 'max' - if (cur) { - auto it = cur->children.find(str[pos + i]); - cur = (it == cur->children.end()) ? nullptr : it->second.get(); - if (cur && cur->kind != 0) { - matched = true; - *len = i + 1; - *kind = cur->kind; - } - } - } - return matched; - } - bool isUnary(int kind, int* prec) { - auto it = unary_prec.find(kind); - if (it != unary_prec.end()) { - *prec = it->second; - return true; - } - return false; - } - bool isBinary(int kind, int* prec) { - auto it = binary_prec.find(kind); - if (it != binary_prec.end()) { - *prec = it->second; - return true; - } - return false; - } - bool isRightAssociative(int kind) { - switch (kind) { - case '?': - return true; - default: - return false; - } - } - - private: - bool validIdent(size_t i, char n) { - return isalpha(n) || n == '_' || (i > 0 && isdigit(n)); - } - TokenTrieRef head; - std::unordered_map - unary_prec; // map from token to its unary precedence - std::unordered_map - binary_prec; // map from token to its binary precedence -}; - -SharedParserData& sharedParserData(); - -// a range of a shared string 'file_' with functions to help debug by highlight -// that -// range. -struct SourceRange { - SourceRange( - const std::shared_ptr& file_, - size_t start_, - size_t end_) - : file_(file_), start_(start_), end_(end_) {} - const std::string text() const { - return file().substr(start(), end() - start()); - } - size_t size() const { - return end() - start(); - } - void highlight(std::ostream& out) const { - const std::string& str = file(); - size_t begin = start(); - size_t end = start(); - while (begin > 0 && str[begin - 1] != '\n') - --begin; - while (end < str.size() && str[end] != '\n') - ++end; - out << str.substr(0, end) << "\n"; - out << std::string(start() - begin, ' '); - size_t len = std::min(size(), end - start()); - out << std::string(len, '~') - << (len < size() ? "... <--- HERE" : " <--- HERE"); - out << str.substr(end); - if (str.size() > 0 && str.back() != '\n') - out << "\n"; - } - const std::string& file() const { - return *file_; - } - const std::shared_ptr& file_ptr() const { - return file_; - } - size_t start() const { - return start_; - } - size_t end() const { - return end_; - } - - private: - std::shared_ptr file_; - size_t start_; - size_t end_; -}; - -struct Token { - int kind; - SourceRange range; - Token(int kind, const SourceRange& range) : kind(kind), range(range) {} - double doubleValue() { - assert(TK_NUMBER == kind); - size_t idx; - double r = ::c10::stod(text(), &idx); - assert(idx == range.size()); - return r; - } - std::string text() { - return range.text(); - } - std::string kindString() const { - return kindToString(kind); - } -}; - -struct Lookahead { - Lookahead(const Token& t) : t(t) {} - Token t; - bool valid = false; - size_t repeat = 0; -}; - -struct Lexer { - std::shared_ptr file; - explicit Lexer(const std::string& str) - : file(std::make_shared(str)), - pos(0), - cur_(TK_EOF, SourceRange(file, 0, 0)), - lookahead_(cur_), - repeat(0), - nesting(0), - shared(sharedParserData()) { - auto first_indent = lexRaw(true); - indent_stack.push_back(first_indent.range.size()); - next(); - } - Token next() { - Token r = cur_; - if (repeat > 0) { - repeat--; - } else if (lookahead_.valid) { - lookahead_.valid = false; - repeat = lookahead_.repeat; - cur_ = lookahead_.t; - } else { - std::tie(cur_, repeat) = lex(); - } - return r; - } - bool nextIf(int kind) { - if (cur_.kind != kind) - return false; - next(); - return true; - } - - [[noreturn]] void reportError(const std::string& what) { - reportError(what, cur_); - } - [[noreturn]] void reportError(const std::string& what, const Token& t) { - std::stringstream ss; - ss << what << ":\n"; - t.range.highlight(ss); - throw std::runtime_error(ss.str()); - } - [[noreturn]] void expected(const std::string& what, const Token& t) { - std::stringstream ss; - ss << "expected " << what << " but found '" << t.kindString() - << "' here:\n"; - t.range.highlight(ss); - throw std::runtime_error(ss.str()); - } - [[noreturn]] void expected(const std::string& what) { - expected(what, cur_); - } - Token expect(int kind) { - if (cur_.kind != kind) { - expected(kindToString(kind)); - } - return next(); - } - Token& lookahead() { - if (!lookahead_.valid) { - lookahead_.valid = true; - std::tie(lookahead_.t, lookahead_.repeat) = lex(); - } - return lookahead_.t; - } - Token& cur() { - return cur_; - } - - private: - // token, number of times to repeat it - std::pair lex() { - auto r = lexRaw(); - int repeat = 0; - switch (r.kind) { - case '(': - case '[': - case '{': - nesting++; - break; - case ')': - case ']': - case '}': - nesting--; - break; - case TK_WHITESPACE: { - size_t depth = r.range.size(); - if (depth > indent_stack.back()) { - indent_stack.push_back(depth); - r.kind = TK_INDENT; - } else if (depth == indent_stack.back()) { - r.kind = TK_NEWLINE; - } else { - while (indent_stack.back() != depth) { - indent_stack.pop_back(); - repeat++; - if (indent_stack.size() == 0) { - reportError("invalid ident level", r); - } - } - repeat--; // first repeat is this return - r.kind = TK_DEDENT; - } - } break; - case TK_EOF: - if (indent_stack.size() > 1) { - r.kind = TK_DEDENT; - indent_stack.pop_back(); - } - break; - default: - break; - } - return std::make_pair(r, repeat); - } - Token lexRaw(bool whitespace_token = false) { - int kind; - size_t start; - size_t length; - assert(file); - if (!shared.match( - *file, - pos, - nesting > 0, - whitespace_token, - &kind, - &start, - &length)) { - expected( - "a valid token", - Token((*file)[start], SourceRange(file, start, start + 1))); - } - auto t = Token(kind, SourceRange(file, start, start + length)); - pos = start + length; - return t; - } - size_t pos; - Token cur_; - Lookahead lookahead_; - size_t repeat; // how many times to repeat the current token until we continue - - size_t nesting; // depth of ( [ { nesting... - std::vector indent_stack; // stack of identation level of blocks - SharedParserData& shared; -}; -} // namespace script -} // namespace caffe2 diff --git a/caffe2/contrib/script/parser.h b/caffe2/contrib/script/parser.h deleted file mode 100644 index 4b68b8d..0000000 --- a/caffe2/contrib/script/parser.h +++ /dev/null @@ -1,418 +0,0 @@ -#pragma once -#include "lexer.h" -#include "tree.h" -#include "tree_views.h" - -namespace caffe2 { -namespace script { - -struct Parser { - explicit Parser(const std::string& str) - : L(str), shared(sharedParserData()) {} - - TreeRef parseIdent() { - auto t = L.expect(TK_IDENT); - // whenever we parse something that has a TreeView type we always - // use its create method so that the accessors and the constructor - // of the Compound tree are in the same place. - return Ident::create(t.range, t.text()); - } - TreeRef createApply(TreeRef ident, TreeList& inputs) { - TreeList attributes; - auto range = L.cur().range; - parseOperatorArguments(inputs, attributes); - return Apply::create( - range, - ident, - List(range, std::move(inputs)), - List(range, std::move(attributes))); - } - // things like a 1.0 or a(4) that are not unary/binary expressions - // and have higher precedence than all of them - TreeRef parseBaseExp() { - TreeRef prefix; - switch (L.cur().kind) { - case TK_NUMBER: - case TK_TRUE: - case TK_FALSE: { - prefix = parseConst(); - } break; - case '(': { - L.next(); - prefix = parseExp(); - L.expect(')'); - } break; - case TK_FLOAT: - case TK_INT: - case TK_LONG: { - auto r = L.cur().range; - auto type = c(L.next().kind, r, {}); - L.expect('('); - auto exp = parseExp(); - L.expect(')'); - prefix = Cast::create(r, type, exp); - } break; - default: { - prefix = parseIdent(); - if (L.cur().kind == '(') { - TreeList inputs; - prefix = createApply(prefix, inputs); - } - } break; - } - while (true) { - if (L.nextIf('.')) { - const auto name = parseIdent(); - if (L.cur().kind == '(') { - TreeList inputs = {prefix}; - prefix = createApply(name, inputs); - } else { - prefix = Select::create(name->range(), prefix, name); - } - } else if (L.cur().kind == '[') { - prefix = parseSliceOrGather(prefix); - } else { - break; - } - } - return prefix; - } - TreeRef parseOptionalReduction() { - auto r = L.cur().range; - switch (L.cur().kind) { - case TK_PLUS_EQ: - case TK_MINUS_EQ: - case TK_TIMES_EQ: - case TK_DIV_EQ: { - int modifier = L.next().text()[0]; - return c(modifier, r, {}); - } break; - default: { - L.expect('='); - return c('=', r, {}); // no reduction - } break; - } - } - TreeRef - parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) { - auto cond = parseExp(); - L.expect(TK_ELSE); - auto false_branch = parseExp(binary_prec); - return c(TK_IF_EXPR, range, {cond, true_branch, false_branch}); - } - // parse the longest expression whose binary operators have - // precedence strictly greater than 'precedence' - // precedence == 0 will parse _all_ expressions - // this is the core loop of 'top-down precedence parsing' - TreeRef parseExp(int precedence = 0) { - TreeRef prefix = nullptr; - int unary_prec; - if (shared.isUnary(L.cur().kind, &unary_prec)) { - auto kind = L.cur().kind; - auto pos = L.cur().range; - L.next(); - prefix = c(kind, pos, {parseExp(unary_prec)}); - } else { - prefix = parseBaseExp(); - } - int binary_prec; - while (shared.isBinary(L.cur().kind, &binary_prec)) { - if (binary_prec <= precedence) // not allowed to parse something which is - // not greater than 'precedenc' - break; - - int kind = L.cur().kind; - auto pos = L.cur().range; - L.next(); - if (shared.isRightAssociative(kind)) - binary_prec--; - - // special case for trinary operator - if (kind == TK_IF) { - prefix = parseTrinary(prefix, pos, binary_prec); - continue; - } - - prefix = c(kind, pos, {prefix, parseExp(binary_prec)}); - } - return prefix; - } - TreeRef - parseList(int begin, int sep, int end, std::function parse) { - auto r = L.cur().range; - L.expect(begin); - TreeList elements; - if (L.cur().kind != end) { - int i = 0; - do { - elements.push_back(parse(i++)); - } while (L.nextIf(sep)); - } - L.expect(end); - return c(TK_LIST, r, std::move(elements)); - } - TreeRef parseNonEmptyList(int sep, std::function parse) { - TreeList elements; - int i = 0; - do { - elements.push_back(parse(i++)); - } while (L.nextIf(sep)); - return c(TK_LIST, elements[0]->range(), std::move(elements)); - } - TreeRef parseExpList() { - return parseList('(', ',', ')', [&](int i) { return parseExp(); }); - } - TreeRef parseConst() { - // 'b' - boolean - // 'LL' 64-bit integer - // 'f' single-precision float - // 'i' 32-bit integer - // 'f' is default if '.' appears in the number - auto range = L.cur().range; - if (L.nextIf(TK_TRUE)) { - return c(TK_CONST, range, {d(1), s("b")}); - } else if (L.nextIf(TK_FALSE)) { - return c(TK_CONST, range, {d(0), s("b")}); - } - float mult = 1.0f; - while (L.nextIf('-')) { - mult *= -1.0f; - } - auto t = L.expect(TK_NUMBER); - std::string type_ident = - (t.text().find('.') == std::string::npos) ? "i" : "f"; - if (L.cur().kind == TK_IDENT) { - Token type_ident_tok = L.expect(TK_IDENT); - type_ident = type_ident_tok.text(); - if (type_ident != "LL" && type_ident != "f") { - throw ErrorReport(type_ident_tok) - << "expected 'f' or 'LL' " - << "as numeric type identifier but found '" << type_ident << "'"; - } - } - return c(TK_CONST, t.range, {d(mult * t.doubleValue()), s(type_ident)}); - } - TreeRef parseAttributeValue() { - int kind = L.cur().kind; - switch (kind) { - case '[': - return parseList('[', ',', ']', [&](int i) { return parseConst(); }); - default: - return parseConst(); - } - } - void parseOperatorArguments(TreeList& inputs, TreeList& attributes) { - L.expect('('); - if (L.cur().kind != ')') { - do { - if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') { - auto ident = parseIdent(); - L.expect('='); - auto v = parseAttributeValue(); - attributes.push_back(Attribute::create(ident->range(), ident, v)); - } else { - inputs.push_back(parseExp()); - } - } while (L.nextIf(',')); - } - L.expect(')'); - } - - // OK: [a] (gather), [a:], [:a], [a:b], [:] (slice) - // Not OK: [] - TreeRef parseSliceOrGather(TreeRef value) { - const auto range = L.cur().range; - L.expect('['); - - // `first` will either be the gather indices, or the start of the slice. - TreeRef first, second; - - // Here we can either have a colon (which starts a slice), or an expression. - // If an expression, we don't know yet if it will be a slice or a gather. - if (L.cur().kind != ':') { - first = parseExp(); - if (L.nextIf(']')) { - return Gather::create(range, value, first); - } else { - first = c(TK_OPTION, range, {first}); - } - } else { - first = c(TK_OPTION, range, {}); - } - L.expect(':'); - // Now we *may* have an expression. - if (L.cur().kind != ']') { - second = c(TK_OPTION, range, {parseExp()}); - } else { - second = c(TK_OPTION, range, {}); - } - L.expect(']'); - - return Slice::create(range, value, first, second); - } - TreeRef parseIdentList() { - return parseList('(', ',', ')', [&](int i) { return parseIdent(); }); - } - TreeRef parseParam() { - auto typ = parseType(); - if (L.cur().kind != TK_IDENT && typ->trees()[0]->kind() == TK_IDENT) { - // oops, it wasn't a type but just a param without any type specified - return Param::create( - typ->range(), typ->trees()[0], c(TK_INFERRED, typ->range(), {})); - } - auto ident = parseIdent(); - return Param::create(typ->range(), ident, typ); - } - // TODO: these functions should be unnecessary, but we currently do not - // emit a TK_NEWLINE before a series of TK_DEDENT tokens - // so if we see a TK_DEDENT then we know a newline must have happened and - // ignore it. The real fix is to patch the lexer so TK_NEWLINE does get - // emited before a TK_INDENT - void expectEndOfLine() { - if (L.cur().kind != TK_DEDENT) - L.expect(TK_NEWLINE); - } - bool isEndOfLine() { - return L.cur().kind == TK_NEWLINE || L.cur().kind == TK_DEDENT; - } - - // 'first' has already been parsed since expressions can exist - // alone on a line: - // first[,other,lhs] = rhs - TreeRef parseAssign(TreeRef first) { - TreeRef list = parseOneOrMoreExp(first); - auto red = parseOptionalReduction(); - auto rhs = parseExp(); - expectEndOfLine(); - return Assign::create(list->range(), list, red, rhs); - } - TreeRef parseStmt() { - switch (L.cur().kind) { - case TK_IF: - return parseIf(); - case TK_WHILE: - return parseWhile(); - case TK_GLOBAL: { - auto range = L.next().range; - std::vector idents; - do { - idents.push_back(parseIdent()); - } while (L.nextIf(',')); - expectEndOfLine(); - return c(TK_GLOBAL, range, std::move(idents)); - } - default: { - auto r = parseExp(); - if (!isEndOfLine()) { - return parseAssign(r); - } else { - expectEndOfLine(); - return r; - } - } - } - } - TreeRef parseScalarType() { - switch (L.cur().kind) { - case TK_INT: - case TK_FLOAT: - case TK_LONG: - case TK_DOUBLE: { - auto t = L.next(); - return c(t.kind, t.range, {}); - } - default: - return parseIdent(); - } - } - TreeRef parseOptionalIdentList() { - TreeRef list = nullptr; - if (L.cur().kind == '(') { - list = parseIdentList(); - } else { - list = c(TK_LIST, L.cur().range, {}); - } - return list; - } - TreeRef parseType() { - auto st = parseScalarType(); - auto list = parseOptionalIdentList(); - return TensorType::create(st->range(), st, list); - } - // 'first' has already been parsed, add the rest - // if they exist - // first[, the, rest] - TreeRef parseOneOrMoreExp(TreeRef first) { - TreeList list{first}; - while (L.nextIf(',')) { - list.push_back(parseExp()); - } - return List(list.back()->range(), std::move(list)); - } - TreeRef parseIf() { - auto r = L.cur().range; - L.expect(TK_IF); - auto cond = parseExp(); - L.expect(':'); - auto true_branch = parseStatements(); - auto false_branch = List(L.cur().range, {}); - if (L.nextIf(TK_ELSE)) { - L.expect(':'); - false_branch = parseStatements(); - } - return If::create(r, cond, true_branch, false_branch); - } - TreeRef parseWhile() { - auto r = L.cur().range; - L.expect(TK_WHILE); - auto cond = parseExp(); - L.expect(':'); - auto body = parseStatements(); - return While::create(r, cond, body); - } - TreeRef parseStatements() { - auto r = L.cur().range; - L.expect(TK_INDENT); - TreeList stmts; - while (true) { - stmts.push_back(parseStmt()); - if (L.nextIf(TK_DEDENT)) - break; - } - return c(TK_LIST, r, std::move(stmts)); - } - TreeRef parseFunction() { - L.expect(TK_DEF); - auto name = parseIdent(); - auto paramlist = - parseList('(', ',', ')', [&](int i) { return parseParam(); }); - L.expect(TK_ARROW); - auto retlist = - parseList('(', ',', ')', [&](int i) { return parseParam(); }); - L.expect(':'); - auto stmts_list = parseStatements(); - return Def::create(name->range(), name, paramlist, retlist, stmts_list); - } - Lexer& lexer() { - return L; - } - - private: - // short helpers to create nodes - TreeRef d(double v) { - return Number::create(v); - } - TreeRef s(const std::string& s) { - return String::create(s); - } - TreeRef c(int kind, const SourceRange& range, TreeList&& trees) { - return Compound::create(kind, range, std::move(trees)); - } - TreeRef List(const SourceRange& range, TreeList&& trees) { - return c(TK_LIST, range, std::move(trees)); - } - Lexer L; - SharedParserData& shared; -}; -} // namespace script -} // namespace caffe2 diff --git a/caffe2/contrib/script/tree.h b/caffe2/contrib/script/tree.h deleted file mode 100644 index c508308..0000000 --- a/caffe2/contrib/script/tree.h +++ /dev/null @@ -1,233 +0,0 @@ -#pragma once - -#include -#include - -#include "caffe2/contrib/script/lexer.h" - -namespace caffe2 { -namespace script { - -// Tree's are used to represent all forms of TC IR, pre- and post- typechecking. -// Rather than have a full class hierarchy for all TC statements, -// Trees are a slight variation of Lisp S-expressions. -// for instance the expression a*b+1 is represented as: -// (+ (* (ident a) (ident b)) (const 1)) -// Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which -// define stringValue() and doubleValue(). -// Everything else is a Compound object, which has a 'kind' that is a token from -// Lexer.h's TokenKind enum, and contains a list of subtrees. -// Like TokenKind single-character operators like '+' are representing using the -// character itself, so add.kind() == '+'. -// Compound objects are also always associated with a SourceRange for -// reporting error message. - -// Memory management of trees is done using shared_ptr. - -struct Tree; -using TreeRef = std::shared_ptr; -using TreeList = std::vector; - -static const TreeList empty_trees = {}; - -struct Tree : std::enable_shared_from_this { - Tree(int kind_) : kind_(kind_) {} - int kind() const { - return kind_; - } - virtual bool isAtom() const { - return true; - } - virtual const SourceRange& range() const { - throw std::runtime_error("is an Atom"); - } - virtual double doubleValue() const { - throw std::runtime_error("not a TK_NUMBER"); - } - virtual const std::string& stringValue() const { - throw std::runtime_error("not a TK_STRING"); - } - virtual bool boolValue() const { - throw std::runtime_error("not a TK_BOOL"); - } - virtual const TreeList& trees() const { - return empty_trees; - } - const TreeRef& tree(size_t i) const { - return trees().at(i); - } - virtual TreeRef map(std::function /*fn*/) { - return shared_from_this(); - } - template - void match(int k, Args&... args) { - matchD(k, "unknown", 0, args...); - } - template - void matchD(int k, const char* filename, int lineno, Args&... args) { - if (kind() != k) { - std::stringstream ss; - ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k) - << "' but found '" << kind() << "'\n"; - range().highlight(ss); - throw std::runtime_error(ss.str()); - } - std::initializer_list vars = {&args...}; - if (vars.size() > trees().size()) { - std::stringstream ss; - ss << filename << ":" << lineno << ": trying to match " << vars.size() - << " variables against " << trees().size() << " values in list.\n"; - range().highlight(ss); - throw std::runtime_error(ss.str()); - } - size_t i = 0; - for (TreeRef* v : vars) { - *v = trees()[i++]; - } - } - virtual ~Tree() {} - - private: - int kind_; -}; - -struct String : public Tree { - String(const std::string& value_) : Tree(TK_STRING), value_(value_) {} - virtual const std::string& stringValue() const override { - return value_; - } - template - static TreeRef create(Args&&... args) { - return std::make_shared(std::forward(args)...); - } - - private: - std::string value_; -}; -struct Number : public Tree { - Number(double value_) : Tree(TK_NUMBER), value_(value_) {} - virtual double doubleValue() const override { - return value_; - } - template - static TreeRef create(Args&&... args) { - return std::make_shared(std::forward(args)...); - } - - private: - double value_; -}; -struct Bool : public Tree { - Bool(bool value_) : Tree(TK_BOOL), value_(value_) {} - virtual double doubleValue() const override { - return value_; - } - template - static TreeRef create(Args&&... args) { - return std::make_shared(std::forward(args)...); - } - - private: - bool value_; -}; - -static SourceRange mergeRanges(SourceRange c, const TreeList& others) { - for (auto t : others) { - if (t->isAtom()) - continue; - size_t s = std::min(c.start(), t->range().start()); - size_t e = std::max(c.end(), t->range().end()); - c = SourceRange(c.file_ptr(), s, e); - } - return c; -} - -struct Compound : public Tree { - Compound(int kind, const SourceRange& range_) : Tree(kind), range_(range_) {} - Compound(int kind, const SourceRange& range_, TreeList&& trees_) - : Tree(kind), - range_(mergeRanges(range_, trees_)), - trees_(std::move(trees_)) {} - virtual const TreeList& trees() const override { - return trees_; - } - static TreeRef - create(int kind, const SourceRange& range_, TreeList&& trees_) { - return std::make_shared(kind, range_, std::move(trees_)); - } - virtual bool isAtom() const override { - return false; - } - virtual TreeRef map(std::function fn) override { - TreeList trees_; - for (auto& t : trees()) { - trees_.push_back(fn(t)); - } - return Compound::create(kind(), range(), std::move(trees_)); - } - const SourceRange& range() const override { - return range_; - } - - private: - SourceRange range_; - TreeList trees_; -}; - -// tree pretty printer -struct pretty_tree { - pretty_tree(const TreeRef& tree, size_t col = 40) : tree(tree), col(col) {} - const TreeRef& tree; - size_t col; - std::unordered_map flat_strings; - const std::string& get_flat(const TreeRef& t) { - auto it = flat_strings.find(t); - if (it != flat_strings.end()) - return it->second; - - std::stringstream out; - switch (t->kind()) { - case TK_NUMBER: - out << t->doubleValue(); - break; - case TK_STRING: - out << t->stringValue(); - break; - default: - out << "(" << kindToString(t->kind()); - for (auto e : t->trees()) { - out << " " << get_flat(e); - } - out << ")"; - break; - } - auto it_ = flat_strings.emplace(t, out.str()); - return it_.first->second; - } - void print(std::ostream& out, const TreeRef& t, int indent) { - const std::string& s = get_flat(t); - if (indent + s.size() < col || t->isAtom()) { - out << s; - return; - } - std::string k = kindToString(t->kind()); - out << "(" << k; - for (auto e : t->trees()) { - out << "\n" << std::string(indent + 2, ' '); - print(out, e, indent + 2); - } - out << ")"; - } -}; - -static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) { - t_.print(out, t_.tree, 0); - return out << std::endl; -} - -static inline std::ostream& operator<<(std::ostream& out, TreeRef t) { - return out << pretty_tree(t); -} - -} // namespace script -} // namespace caffe2 diff --git a/caffe2/contrib/script/tree_views.h b/caffe2/contrib/script/tree_views.h deleted file mode 100644 index 2089333..0000000 --- a/caffe2/contrib/script/tree_views.h +++ /dev/null @@ -1,442 +0,0 @@ -#pragma once -#include "error_report.h" -#include "tree.h" - -namespace caffe2 { -namespace script { - -// TreeView provides a statically-typed way to access the members of a TreeRef -// instead of using TK_MATCH - -struct TreeView { - explicit TreeView(const TreeRef& tree_) : tree_(tree_) {} - TreeRef tree() const { - return tree_; - } - const SourceRange& range() const { - return tree_->range(); - } - operator TreeRef() const { - return tree_; - } - - protected: - TreeRef tree_; -}; - -template -struct ListViewIterator { - ListViewIterator(TreeList::const_iterator it) : it(it) {} - bool operator!=(const ListViewIterator& rhs) const { - return it != rhs.it; - } - T operator*() const { - return T(*it); - } - void operator++() { - ++it; - } - void operator--() { - --it; - } - - private: - TreeList::const_iterator it; -}; - -template -struct ListView : public TreeView { - ListView(const TreeRef& tree) : TreeView(tree) { - tree->match(TK_LIST); - } - typedef ListViewIterator iterator; - typedef ListViewIterator const_iterator; - iterator begin() const { - return iterator(tree_->trees().begin()); - } - iterator end() const { - return iterator(tree_->trees().end()); - } - T operator[](size_t i) const { - return T(tree_->trees().at(i)); - } - TreeRef map(std::function fn) { - return tree_->map([&](TreeRef v) { return fn(T(v)); }); - } - size_t size() const { - return tree_->trees().size(); - } -}; - -template -struct OptionView : public TreeView { - explicit OptionView(const TreeRef& tree) : TreeView(tree) { - C2S_ASSERT(tree, tree->kind() == TK_OPTION); - } - bool present() const { - return tree_->trees().size() > 0; - } - T get() const { - C2S_ASSERT(tree_, present()); - return T(tree_->trees()[0]); - } - TreeRef map(std::function fn) { - return tree_->map([&](TreeRef v) { return fn(T(v)); }); - } -}; - -struct Ident : public TreeView { - // each subclass of TreeView provides: - // 1. a constructor that takes a TreeRef, and matches it to the right type. - explicit Ident(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_IDENT, name_); - } - // 2. accessors that get underlying information out of the object - // in this case, we return the name of the identifier, and handle the - // converstion to a string in the method - const std::string& name() const { - return name_->stringValue(); - } - - // 3. a static method 'create' that creates the underlying TreeRef object - // for every TreeRef kind that has a TreeView, the parser always uses - // (e.g.) Ident::create rather than Compound::Create, this means that - // changes to the structure of Ident are always made right here rather - // than both in the parser and in this code - static TreeRef create(const SourceRange& range, const std::string& name) { - return Compound::create(TK_IDENT, range, {String::create(name)}); - } - - private: - TreeRef name_; -}; - -struct Attribute : public TreeView { - explicit Attribute(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_ATTRIBUTE, name_, value_); - } - Ident name() const { - return Ident(name_); - } - TreeRef value() const { - return value_; - } - static TreeRef create(const SourceRange& range, TreeRef name, TreeRef value) { - return Compound::create(TK_ATTRIBUTE, range, {name, value}); - } - - private: - TreeRef name_; - TreeRef value_; -}; - -struct Apply : public TreeView { - explicit Apply(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_APPLY, name_, inputs_, attributes_); - } - - Ident name() const { - return Ident(name_); - } - ListView inputs() const { - return ListView(inputs_); - } - ListView attributes() const { - return ListView(attributes_); - } - - static TreeRef create( - const SourceRange& range, - TreeRef name, - TreeRef inputs, - TreeRef attributes) { - return Compound::create(TK_APPLY, range, {name, inputs, attributes}); - } - - private: - TreeRef name_; - TreeRef inputs_; - TreeRef attributes_; -}; - -struct Slice : public TreeView { - explicit Slice(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_SLICE, value_, start_, end_); - } - - TreeRef value() const { - return value_; - } - - OptionView start() const { - return OptionView(start_); - } - - OptionView end() const { - return OptionView(end_); - } - - TreeRef startOr(int alternative) const { - const auto startOption = start(); - return startOption.present() ? startOption.get() : createInt(alternative); - } - - TreeRef endOr(int alternative) const { - const auto endOption = end(); - return endOption.present() ? endOption.get() : createInt(alternative); - } - - static TreeRef - create(const SourceRange& range, TreeRef value, TreeRef start, TreeRef end) { - return Compound::create(TK_SLICE, range, {value, start, end}); - } - - private: - TreeRef createInt(int value) const { - return Compound::create( - TK_CONST, range(), {Number::create(value), String::create("i")}); - } - - TreeRef value_; - TreeRef start_; - TreeRef end_; -}; - -struct Gather : public TreeView { - explicit Gather(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_GATHER, value_, indices_); - } - - TreeRef value() const { - return value_; - } - - TreeRef indices() const { - return indices_; - } - - static TreeRef - create(const SourceRange& range, TreeRef value, TreeRef indices) { - return Compound::create(TK_GATHER, range, {value, indices}); - } - - private: - TreeRef value_; - TreeRef indices_; -}; - -struct Cast : public TreeView { - explicit Cast(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_CAST, type_, input_); - } - - int type() const { - return type_->kind(); - } - TreeRef input() const { - return input_; - } - - static TreeRef create(const SourceRange& range, TreeRef type, TreeRef input) { - return Compound::create(TK_CAST, range, {type, input}); - } - - private: - TreeRef type_; - TreeRef input_; -}; - -struct TensorType : public TreeView { - explicit TensorType(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_TENSOR_TYPE, scalar_type_, dims_); - } - static TreeRef - create(const SourceRange& range, TreeRef scalar_type_, TreeRef dims_) { - return Compound::create(TK_TENSOR_TYPE, range, {scalar_type_, dims_}); - } - int scalarType() const { - if (scalar_type_->kind() == TK_IDENT) - throw ErrorReport(tree_) - << " TensorType has a symbolic ident " << Ident(scalar_type_).name() - << " rather than a concrete type"; - return scalar_type_->kind(); - } - ListView dims() const { - return ListView(dims_); - } - - private: - TreeRef scalar_type_; - TreeRef dims_; -}; - -struct Param : public TreeView { - explicit Param(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_PARAM, ident_, type_); - } - static TreeRef create(const SourceRange& range, TreeRef ident, TreeRef type) { - return Compound::create(TK_PARAM, range, {ident, type}); - } - // when the type of a field is statically know the accessors return - // the wrapped type. for instance here we know ident_ is an identifier - // so the accessor returns an Ident - // this means that clients can do p.ident().name() to get the name of the - // parameter. - Ident ident() const { - return Ident(ident_); - } - // may be TensorType or TK_INFERRED - TreeRef type() const { - return type_; - } - bool typeIsInferred() const { - return type_->kind() == TK_INFERRED; - } - // helper for when you know the type is not inferred. - TensorType tensorType() const { - return TensorType(type_); - } - - private: - TreeRef ident_; - TreeRef type_; -}; - -struct Assign : public TreeView { - explicit Assign(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_ASSIGN, lhs_, reduction_, rhs_); - } - static TreeRef create( - const SourceRange& range, - TreeRef lhs, - TreeRef reduction, - TreeRef rhs) { - return Compound::create(TK_ASSIGN, range, {lhs, reduction, rhs}); - } - // when the type of a field is statically know the accessors return - // the wrapped type. for instance here we know ident_ is an identifier - // so the accessor returns an Ident - // this means that clients can do p.ident().name() to get the name of the - // parameter. - ListView lhs() const { - return ListView(lhs_); - } - int reduction() const { - return reduction_->kind(); - } - TreeRef rhs() const { - return rhs_; - } - - private: - TreeRef lhs_; - TreeRef reduction_; - TreeRef rhs_; -}; - -struct Def : public TreeView { - explicit Def(const TreeRef& tree) : TreeView(tree) { - tree->match(TK_DEF, name_, paramlist, retlist, stmts_list); - } - Ident name() const { - return Ident(name_); - } - // ListView helps turn TK_LISTs into vectors of TreeViews - // so that we can, e.g., return lists of parameters - ListView params() const { - return ListView(paramlist); - } - ListView returns() const { - return ListView(retlist); - } - ListView statements() const { - return ListView(stmts_list); - } - static TreeRef create( - const SourceRange& range, - TreeRef name, - TreeRef paramlist, - TreeRef retlist, - TreeRef stmts_list) { - return Compound::create( - TK_DEF, range, {name, paramlist, retlist, stmts_list}); - } - - private: - TreeRef name_; - TreeRef paramlist; - TreeRef retlist; - TreeRef stmts_list; -}; - -struct Select : public TreeView { - explicit Select(const TreeRef& tree) : TreeView(tree) { - tree_->match('.', value_, selector_); - } - TreeRef value() const { - return value_; - } - Ident selector() const { - return Ident(selector_); - } - static TreeRef - create(const SourceRange& range, TreeRef value, TreeRef selector) { - return Compound::create('.', range, {value, selector}); - } - - private: - TreeRef value_; - TreeRef selector_; -}; - -struct If : public TreeView { - explicit If(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_IF, cond_, true_branch_, false_branch_); - } - const TreeRef& cond() const { - return cond_; - } - ListView trueBranch() const { - return ListView(true_branch_); - } - ListView falseBranch() const { - return ListView(false_branch_); - } - - static TreeRef create( - const SourceRange& range, - TreeRef cond_, - TreeRef true_branch_, - TreeRef false_branch_) { - return Compound::create(TK_IF, range, {cond_, true_branch_, false_branch_}); - } - - private: - TreeRef cond_; - TreeRef true_branch_; - TreeRef false_branch_; -}; - -struct While : public TreeView { - explicit While(const TreeRef& tree) : TreeView(tree) { - tree_->match(TK_WHILE, cond_, body_); - } - const TreeRef& cond() const { - return cond_; - } - ListView body() const { - return ListView(body_); - } - - static TreeRef - create(const SourceRange& range, TreeRef cond_, TreeRef body_) { - return Compound::create(TK_WHILE, range, {cond_, body_}); - } - - private: - TreeRef cond_; - TreeRef body_; -}; - -} // namespace script -} // namespace caffe2 diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 709879a..a4a1509 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -6,7 +6,6 @@ #include #include -#include "caffe2/contrib/script/compiler.h" #include "caffe2/core/asan.h" #include "caffe2/core/blob_stats.h" #include "caffe2/core/db.h" @@ -938,29 +937,6 @@ void addObjectMethods(py::module& m) { } return pyout; }); - - py::class_(m, "CompilationUnit") - .def(py::init<>()) - .def("define", &script::CompilationUnit::define) - .def("get_proto", &script::CompilationUnit::getProto) - .def( - "create_net", - [](script::CompilationUnit* self, const std::string& name) { - auto net = self->createNet(gWorkspace, name); - CAFFE_ENFORCE(net); - return net; - }) - .def( - "extern", - [](script::CompilationUnit* self, - const std::string& name, - py::object py_proto) { - py::bytes bytes = py_proto.attr("SerializeToString")(); - std::unique_ptr proto(new NetDef()); - CAFFE_ENFORCE(ParseProtoFromLargeString( - bytes.cast(), proto.get())); - self->defineExtern(name, std::move(proto)); - }); } void addGlobalMethods(py::module& m) {