# nix files
.envrc
*.nix
-
-# antlr files
-*.tokens
-*.interp
include(cmake/util/FindVulkan.cmake)
include(cmake/util/FindLLVM.cmake)
include(cmake/util/FindROCM.cmake)
-include(cmake/util/FindANTLR.cmake)
if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake)
include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake)
tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
-tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_CPP_RPC "Build CPP RPC" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none)
include(cmake/modules/ROCM.cmake)
include(cmake/modules/LLVM.cmake)
include(cmake/modules/Micro.cmake)
-include(cmake/modules/ANTLR.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-if(USE_ANTLR)
- find_antlr(${USE_ANTLR})
- if(ANTLR4)
-
- set(RELAY_PARSER_DIR
- ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
-
- set(RELAY_PARSER
- ${RELAY_PARSER_DIR}/py3/RelayVisitor.py
- ${RELAY_PARSER_DIR}/py3/RelayParser.py
- ${RELAY_PARSER_DIR}/py3/RelayLexer.py)
-
-
- # Generate ANTLR grammar for parsing.
- add_custom_command(OUTPUT ${RELAY_PARSER}
- COMMAND ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
- DEPENDS ${RELAY_PARSER_DIR}/Relay.g4
- WORKING_DIRECTORY ${RELAY_PARSER_DIR})
-
- add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
- else()
- message(FATAL_ERROR "Can't find ANTLR4")
- endif()
-endif(USE_ANTLR)
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-#######################################################
-# Enhanced version of find ANTLR.
-#
-# Usage:
-# find_antlr(${USE_ANTLR})
-#
-# - When USE_ANTLR=ON, use auto search by first trying to find antlr4 program,
-# then trying to find antlr-*-complete.jar
-# - When USE_ANTLR=/path/to/antlr-*-complete.jar, use provided jar
-#
-# Provide variables:
-# - ANTLR4
-#
-macro(find_antlr use_antlr)
- set(JAVA_HOME $ENV{JAVA_HOME})
- if (NOT DEFINED JAVA_HOME)
- # Hack to get system to search for Java itself.
- message(STATUS "JAVA_HOME is not defined. Set it to ensure proper use")
- set(JAVA_HOME "/usr")
- endif()
- if(MSVC)
- set(JAVA_PROGRAM ${JAVA_HOME}/java.exe)
- else()
- set(JAVA_PROGRAM ${JAVA_HOME}/bin/java)
- endif()
- message(STATUS "Using Java at " ${JAVA_PROGRAM})
-
- if (${use_antlr} STREQUAL "ON")
- find_program(ANTLR4 antlr4)
- if (NOT ANTLR4)
- file(GLOB_RECURSE ANTLR4JAR
- /usr/local/lib/antlr-*-complete.jar
- /usr/local/Cellar/*antlr-*-complete.jar)
-
- # Get the first element of the list of antlr jars.
- # Sort and reverse the list so the item selected is the highest
- # version in lib or else in Cellar if no lib installation exists.
- list(SORT ANTLR4JAR)
- list(REVERSE ANTLR4JAR)
- list(GET ANTLR4JAR 0 ANTLR4JAR)
-
- set(ANTLR4 ${JAVA_PROGRAM} -jar ${ANTLR4JAR})
- endif()
- elseif(NOT ${use_antlr} STREQUAL "OFF")
- set(ANTLR4 ${JAVA_PROGRAM} -jar ${use_antlr})
- endif()
- message(STATUS "ANTLR4=${ANTLR4}")
-endmacro(find_antlr)
COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh
RUN bash /install/ubuntu_install_java.sh
-COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh
-RUN bash /install/ubuntu_install_antlr.sh
-
# Chisel deps for TSIM
COPY install/ubuntu_install_chisel.sh /install/ubuntu_install_chisel.sh
RUN bash /install/ubuntu_install_chisel.sh
COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh
RUN bash /install/ubuntu_install_redis.sh
-COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh
-RUN bash /install/ubuntu_install_antlr.sh
-
# NNPACK deps
COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh
RUN bash /install/ubuntu_install_nnpack.sh
COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh
RUN bash /install/ubuntu_install_java.sh
-COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh
-RUN bash /install/ubuntu_install_antlr.sh
-
COPY install/ubuntu_install_nodejs.sh /install/ubuntu_install_nodejs.sh
RUN bash /install/ubuntu_install_nodejs.sh
+++ /dev/null
-#!/bin/bash
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-set -e
-set -u
-set -o pipefail
-
-cd /usr/local/lib
-wget -q https://www.antlr.org/download/antlr-4.7.1-complete.jar
-cd -
set -o pipefail
# install libraries for python package on ubuntu
-pip3 install pylint==1.9.4 six numpy pytest cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs requests Pillow packaging
+pip3 install pylint==1.9.4 six numpy pytest cython decorator scipy tornado typed_ast pytest mypy orderedset attrs requests Pillow packaging
This script skips the tutorial executions and is useful for quickly check the content.
```bash
-./tests/scrpts/task_sphinx_precheck.sh
+./tests/scripts/task_sphinx_precheck.sh
```
The following script runs the full build which includes tutorial executions.
You will need a gpu CI environment.
```bash
-./tests/scrpts/task_python_docs.sh
+./tests/scripts/task_python_docs.sh
```
pip3 install --user tornado psutil xgboost
- * If you want to build tvm to compile a model, you must use Python 3 and run the following
-
- .. code:: bash
-
- sudo apt install antlr4
- pip3 install --user mypy orderedset antlr4-python3-runtime
-
Install Contrib Libraries
-------------------------
~AttrInitEntry() DMLC_THROW_EXCEPTION {
if (value_missing_) {
std::ostringstream os;
- os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization";
+ os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization."
+ << "If the key is defined check that its type matches the declared type.";
throw AttrError(os.str());
}
}
class SpanNode : public Object {
public:
/*! \brief The source name. */
- SourceName source;
+ SourceName source_name;
/*! \brief The line number. */
int line;
/*! \brief The column offset. */
int column;
+ /*! \brief The end line number. */
+ int end_line;
+ /*! \brief The end column number. */
+ int end_column;
// override attr visitor
void VisitAttrs(AttrVisitor* v) {
- v->Visit("source", &source);
+ v->Visit("source_name", &source_name);
v->Visit("line", &line);
v->Visit("column", &column);
+ v->Visit("end_line", &end_line);
+ v->Visit("end_column", &end_column);
}
bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
- return equal(source, other->source) && equal(line, other->line) && equal(column, other->column);
+ return equal(source_name, other->source_name) && equal(line, other->line) &&
+ equal(column, other->column) && equal(end_line, other->end_line) &&
+ equal(end_column, other->end_column);
}
static constexpr const char* _type_key = "Span";
class Span : public ObjectRef {
public:
- TVM_DLL Span(SourceName source, int lineno, int col_offset);
+ TVM_DLL Span(SourceName source_name, int line, int end_line, int column, int end_column);
+
+ /*! \brief Merge two spans into one which captures the combined regions. */
+ TVM_DLL Span Merge(const Span& other);
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
namespace tvm {
namespace parser {
-IRModule Parse(std::string file_name, std::string file_content);
+IRModule ParseModule(std::string file_name, std::string file_content);
} // namespace parser
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file source_map.h
+ * \brief A map from source names to source code.
+ */
+#ifndef TVM_PARSER_SOURCE_MAP_H_
+#define TVM_PARSER_SOURCE_MAP_H_
+
+#include <tvm/ir/span.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace parser {
+
+/*! \brief A program source in any language.
+ *
+ * Could represent the source from an ML framework or the internal
+ * source of a TVM program.
+ */
+struct Source {
+ /*! \brief The source name. */
+ SourceName source_name;
+
+ /*! \brief The raw source. */
+ std::string source;
+ /*! \brief A mapping of line breaks into the raw source. */
+ std::vector<std::pair<int, int>> line_map;
+
+ /*! \brief An empty source. */
+ Source() : source_name(), source(), line_map() {}
+
+ /*! \brief Construct a source from a string. */
+ TVM_DLL explicit Source(const SourceName& src_name, const std::string& source);
+
+ TVM_DLL Source(const Source& source)
+ : source_name(source.source_name), source(source.source), line_map(source.line_map) {}
+
+ /*! \brief Generate an error message at a specific line and column with the
+ * annotated message.
+ *
+ * The error is written directly to the `out` std::ostream.
+ *
+ * \param out The output ostream.
+ * \param span The span to report the error at.
+ * \param msg The message to attach.
+ *
+ */
+ // TODO(@jroesch): replace the ostream with an interface for rendering errors.
+ TVM_DLL void ReportAt(std::ostream& out, const Span& span, const std::string& msg) const;
+};
+
+/*!
+ * \brief A mapping from a unique source name to source fragment.
+ */
+class SourceMap;
+/*!
+ * \brief Stores locations in frontend source that generated a node.
+ */
+class SourceMapNode : public Object {
+ public:
+ /*! \brief The source mapping. */
+ Map<SourceName, tvm::String> source_map;
+
+ // override attr visitor
+ void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); }
+
+ bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const {
+ return equal(source_map, other->source_map);
+ }
+
+ static constexpr const char* _type_key = "SourceMap";
+ TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object);
+};
+
+class SourceMap : public ObjectRef {
+ public:
+ TVM_DLL SourceMap(Map<SourceName, tvm::String> source_map);
+
+ TVM_DLL static SourceMap* Get();
+
+ TVM_DEFINE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode);
+};
+
+} // namespace parser
+} // namespace tvm
+
+#endif // TVM_PARSER_SOURCE_MAP_H_
* \param data the input being deconstructed.
* \param clauses The clauses for matching.
* \param complete Indicate if this match is complete.
+ * \param span The span of the expression.
*/
- TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true);
+ TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
};
/*!
* \brief The constructor
* \param data The data of the constant tensor.
+ * \param span The source span of the expression.
*/
- TVM_DLL explicit Constant(runtime::NDArray data);
+ TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};
/*!
* \brief The constructor
* \param fields The fields of a tuple.
+ * \param span The source span of the expression.
*/
- TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields);
+ TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
};
hash_reduce.FreeVarHashImpl(this);
}
- TVM_DLL static Var make(String name_hint, Type type_annotation);
-
- TVM_DLL static Var make(Id vid, Type type_annotation);
-
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};
* \brief The constructor
* \param name_hint The name hint of a variable.
* \param type_annotation The type annotation of a variable.
+ * \param span The source span of the expression.
*/
- TVM_DLL Var(String name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {}
+ TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span())
+ : Var(Id(name_hint), type_annotation, span) {}
/*!
* \brief The constructor
* \param vid The unique id of a variable.
* \param type_annotation The type annotation of a variable.
+ * \param span The source span of the expression.
*/
- TVM_DLL Var(Id vid, Type type_annotation);
+ TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
* \param args The arguments of the call.
* \param attrs The attributes of the call node.
* \param type_args The type arguments passed to a polymorphic function.
+ * \param span The source span of the expression.
*/
TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
- Array<Type> type_args = Array<Type>());
+ Array<Type> type_args = Array<Type>(), Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
};
* \param var The variable that is bound to.
* \param value The value used to bind to the variable.
* \param body The body of the let binding.
+ * \param span The source span of the expression.
*/
- TVM_DLL Let(Var var, Expr value, Expr body);
+ TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode);
};
* \param cond The condition of a if node.
* \param true_branch The fall through branch
* \param false_branch The branch for execution when condition is false.
+ * \param span The source span of the expression.
*/
- TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch);
+ TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode);
};
* \brief The constructor
* \param tuple The tuple to get an element from.
* \param index The index for extracting a value in the tuple.
+ * \param span The source span of the expression.
*/
- TVM_DLL TupleGetItem(Expr tuple, int index);
+ TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode);
};
/*!
* \brief The constructor
* \param value The initial value of the reference.
+ * \param span The source span of the expression.
*/
- TVM_DLL explicit RefCreate(Expr value);
+ TVM_DLL explicit RefCreate(Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode);
};
/*!
* \brief The constructor
* \param ref The reference where to read data.
+ * \param span The source span of the expression.
*/
- TVM_DLL explicit RefRead(Expr ref);
+ TVM_DLL explicit RefRead(Expr ref, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode);
};
hash_reduce(value);
}
- TVM_DLL static RefWrite make(Expr ref, Expr value);
-
static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode);
};
* \brief The constructor
* \param ref The reference where data is write to.
* \param value The value to write.
+ * \param span The source span of the expression.
*/
- TVM_DLL RefWrite(Expr ref, Expr value);
+ TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode);
};
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);
+ virtual void VisitSpan(const Span& span);
protected:
// Internal visiting counter
* \param ret_type The return type of the function.
* \param ty_params The type parameters.
* \param attrs Additional function attributes.
+ * \param span The span of the function.
*/
TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params,
- tvm::DictAttrs attrs = NullValue<DictAttrs>());
+ tvm::DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
'psutil',
'xgboost>=1.1.0',
'mypy',
- 'orderedset',
- 'antlr4-python3-runtime']},
+ 'orderedset']},
packages=find_packages(),
package_dir={'tvm': 'tvm'},
"Attribute {} is not supported in operator {}".format(
attr_name, op_name))
"""
+
+@register_error
+class DiagnosticError(TVMError):
+ """Error diagnostics were reported during the execution of a pass.
+
+ See the configured diagnostic renderer for detailed error information.
+ """
col_offset : int
The column offset of the location.
"""
- def __init__(self, source, lineno, col_offset):
+ def __init__(self, source_name, line, end_line, column, end_column):
self.__init_handle_by_constructor__(
- _ffi_api.Span, source, lineno, col_offset)
+ _ffi_api.Span, source_name, line, end_line, column, end_column)
@tvm._ffi.register_object
return _ffi_api.ParseExpr("string", source)
def fromtext(source, source_name="from_string"):
- return parse(str(source), str(source_name))
+ return parse(source, source_name)
from . import prelude
from . import loops
from . import scope_builder
-from . import parser
from . import transform
from . import analysis
# Prelude
Prelude = prelude.Prelude
-# Scope builder
+# Scope Builder
ScopeBuilder = scope_builder.ScopeBuilder
-# Parser
-fromtext = parser.fromtext
-
# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=invalid-name, unused-argument
-"""A parser for Relay's text format."""
-from __future__ import absolute_import
-
-import sys
-from ast import literal_eval
-from collections import deque
-
-try:
- # no typing.Deque in Python 3.5
- # https://bugs.python.org/issue29011
- from typing import Any, Dict, List, Optional, TypeVar, Tuple, Union, MutableSequence, T, Deque
-except ImportError:
- class Deque(deque, MutableSequence[T], extra=deque):
-
- def __new__(cls, *args, **kwds):
- if _geqv(cls, Deque):
- raise TypeError("Type Deque cannot be instantiated; "
- "use deque() instead")
- return deque.__new__(cls, *args, **kwds)
-
-import tvm
-import tvm.ir._ffi_api
-from tvm.ir import IRModule
-
-from .base import Span, SourceName
-from . import adt
-from . import expr
-from . import function
-from . import ty
-from . import op
-
-PYTHON_VERSION = sys.version_info.major
-try:
- from antlr4 import InputStream, CommonTokenStream
- from antlr4.error.ErrorListener import ErrorListener
-except ImportError:
- raise Exception("Couldn't find ANTLR runtime." +
- "Try running `pip{version} install antlr4-python{version}-runtime`."
- .format(version=PYTHON_VERSION))
-
-try:
- from .grammar.py3.RelayVisitor import RelayVisitor
- from .grammar.py3.RelayParser import RelayParser
- from .grammar.py3.RelayLexer import RelayLexer
-except ImportError:
- raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")
-
-
-sys.setrecursionlimit(10000)
-
-class ParseError(Exception):
- """Exception type for parse errors."""
-
- def __init__(self, message: str) -> None:
- super(ParseError, self).__init__()
- self.message = message
-
- def __repr__(self):
- return "ParseError({})".format(self.message)
-
- def __str__(self):
- return repr(self)
-
-class OpWrapper:
- """Overload the __call__ for op."""
-
-
-class ExprOp(OpWrapper):
- """Call an expr. The default, but does not handle attrs well."""
- def __init__(self, operator):
- self.operator = operator
-
- def __call__(self, args, attrs, type_args):
- try:
- return expr.Call(self.operator, args, attrs, type_args)
- except Exception:
- raise Exception("Operator {} is not registered. It's attributes are {}"
- .format(self.operator, attrs))
-
-class FuncOp(OpWrapper):
- """Convert the attrs, call the python function with the attrs passed in as keyword arguments.
- Tvm should provide this in the future, as this is pretty similar to what op.get is providing.
- """
- def __init__(self, operator):
- self.operator = operator
-
- def convert(self, v):
- if isinstance(v, tuple):
- return tuple([self.convert(x) for x in v])
- if isinstance(v, expr.Constant):
- return v.data.asnumpy().item()
- if isinstance(v, str):
- return v
- raise Exception(v)
-
- def __call__(self, args, attrs, type_args):
- if attrs is None:
- attrs = {}
- if self.operator in (op.strided_slice,):
- x = self.operator(*args)
- else:
- x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
- if isinstance(x, expr.TupleWrapper):
- x = x.astuple()
- return x
-
-BINARY_OPS = {
- RelayParser.MUL: op.multiply,
- RelayParser.DIV: op.divide,
- RelayParser.ADD: op.add,
- RelayParser.SUB: op.subtract,
- RelayParser.LT: op.less,
- RelayParser.GT: op.greater,
- RelayParser.LE: op.less_equal,
- RelayParser.GE: op.greater_equal,
- RelayParser.EQ: op.equal,
- RelayParser.NE: op.not_equal,
-}
-
-FUNC_OPS = {
- "nn.conv2d": op.nn.conv2d,
- "nn.batch_norm": op.nn.batch_norm,
- "nn.dense": op.nn.dense,
- "nn.bias_add": op.nn.bias_add,
- "nn.max_pool2d": op.nn.max_pool2d,
- "nn.max_pool3d": op.nn.max_pool3d,
- "nn.global_max_pool2d": op.nn.global_max_pool2d,
- "nn.avg_pool2d": op.nn.avg_pool2d,
- "nn.avg_pool3d": op.nn.avg_pool3d,
- "nn.global_avg_pool2d": op.nn.global_avg_pool2d,
- "nn.softmax": op.nn.softmax,
- "reshape": op.reshape,
- "nn.conv2d_transpose": op.nn.conv2d_transpose,
- "nn.conv1d_transpose": op.nn.conv1d_transpose,
- "concatenate": op.concatenate,
- "nn.dropout": op.nn.dropout_raw,
- "zeros": op.zeros,
- "split": op.split,
- "cast": op.cast,
- "clip": op.clip,
- "right_shift": op.right_shift,
-}
-
-TYPE_PREFIXES = [
- "int",
- "uint",
- "float",
- "bool",
-]
-
-T = TypeVar("T")
-Scope = Deque[Tuple[str, T]]
-Scopes = Deque[Scope[T]]
-
-def lookup(scopes: Scopes[T], name: str) -> Optional[T]:
- """Look up `name` in `scopes`."""
-
- for scope in scopes:
- for key, val in scope:
- if key == name:
- return val
- return None
-
-def spanify(f):
- """A decorator which attaches span information
- to the value returned by calling `f`.
-
- Intended for use with the below AST visiting
- methods. The idea is that after we do the work
- of constructing the AST we attach Span information.
- """
-
- def _wrapper(*args, **kwargs):
- # Assumes 0th arg is self and gets source_name from object.
- sn = args[0].source_name
- # Assumes 1st arg is an ANTLR parser context.
- ctx = args[1]
- ast = f(*args, **kwargs)
- line, col = ctx.getSourceInterval()
- sp = Span(sn, line, col)
- if isinstance(ast, tvm.relay.expr.TupleWrapper):
- ast = ast.astuple()
- tvm.ir._ffi_api.NodeSetSpan(ast, sp)
- return ast
- return _wrapper
-
-# TODO(@jmp): Use https://stackoverflow.com/q/13889941
-# to figure out how to get ANTLR4 to be more unhappy about syntax errors
-class ParseTreeToRelayIR(RelayVisitor):
- """Parse Relay text format into Relay IR."""
-
- def __init__(self, source_name: str) -> None:
- self.source_name = source_name
- self.module = IRModule({}) # type: IRModule
-
- # Adding an empty scope allows naked lets without pain.
- self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
- self.global_vars = {} # type: Scope[expr.GlobalVar]
- self.type_var_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
- self.global_type_vars = {} # type: Scope[expr.GlobalVar]
- self.graph_expr = [] # type: List[expr.Expr]
-
- super(ParseTreeToRelayIR, self).__init__()
-
-
- def enter_var_scope(self) -> None:
- """Enter a new Var scope so it can be popped off later."""
- self.var_scopes.appendleft(deque())
-
- def exit_var_scope(self) -> Scope[expr.Var]:
- """Pop off the current Var scope and return it."""
- return self.var_scopes.popleft()
-
- def mk_var(self, name: str, typ: ty.Type = None):
- """Create a new Var and add it to the Var scope."""
- var = expr.Var(name, typ)
- self.var_scopes[0].appendleft((name, var))
- return var
-
- def mk_global_var(self, name: str) -> expr.GlobalVar:
- """Create a new GlobalVar and add it to the GlobalVar scope."""
- if name in self.global_vars:
- raise ParseError("duplicate global var \"{0}\"".format(name))
- var = expr.GlobalVar(name)
- self.global_vars[name] = var
- return var
-
- def enter_type_param_scope(self) -> None:
- """Enter a new TypeVar scope so it can be popped off later."""
- self.type_var_scopes.appendleft(deque())
-
- def exit_type_param_scope(self) -> Scope[ty.TypeVar]:
- """Pop off the current TypeVar scope and return it."""
- return self.type_var_scopes.popleft()
-
- def mk_typ(self, name: str, kind: ty.TypeKind) -> ty.TypeVar:
- """Create a new TypeVar and add it to the TypeVar scope."""
- typ = ty.TypeVar(name, kind)
- self.type_var_scopes[0].append((name, typ))
- return typ
-
- def mk_global_typ_var(self, name, kind):
- # (str, ty.Kind) -> ty.GlobalTypeVar
- """Create a new TypeVar and add it to the TypeVar scope."""
- typ = ty.GlobalTypeVar(name, kind)
- self._check_existing_typ_expr(name, typ)
- self.global_type_vars[name] = typ
- return typ
-
- # TODO(weberlo): rethink whether we should have type constructors mixed with type vars.
- def mk_global_typ_cons(self, name, cons):
- self._check_existing_typ_expr(name, cons)
- self.global_type_vars[name] = cons
-
- def _check_existing_typ_expr(self, name, new_expr):
- if name in self.global_type_vars:
- new_typ_name = self._type_expr_name(new_expr)
- existing_typ_name = self._type_expr_name(self.global_type_vars[name])
- raise ParseError(
- "{0} `{1}` conflicts with existing {2}".format(new_typ_name,\
- name, existing_typ_name))
-
- def _type_expr_name(self, e):
- if isinstance(e, adt.Constructor):
- return "`{0}` ADT constructor".format(e.belong_to.name_hint)
- if isinstance(e, ty.GlobalTypeVar):
- if e.kind == ty.TypeKind.AdtHandle:
- return "ADT definition"
- return "function definition"
-
- def visitProjection(self, ctx):
- return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT()))
-
- def visitTerminal(self, node) -> Union[expr.Expr, int, float]:
- """Visit lexer tokens that aren't ignored or visited by other functions."""
- node_type = node.getSymbol().type
- node_text = node.getText()
-
- if node_type == RelayLexer.NAT:
- return int(node_text)
- if node_type == RelayLexer.FLOAT:
- return float(node_text[:-1])
- if node_type == RelayLexer.BOOL_LIT:
- if node_text == "True":
- return True
- if node_text == "False":
- return False
- raise ParseError("unrecognized BOOL_LIT: `{}`".format(node_text))
- if node_type == RelayLexer.QUOTED_STRING:
- return literal_eval(node_text)
- raise ParseError("unhandled terminal \"{0}\" of type `{1}`".format(node_text, node_type))
-
- def visitGeneralIdent(self, ctx):
- name = ctx.getText()
- # Look through all type prefixes for a match.
- for type_prefix in TYPE_PREFIXES:
- if name.startswith(type_prefix):
- return ty.scalar_type(name)
- # Next, look it up in the local then global type params.
- type_expr = lookup(self.type_var_scopes, name)
- if type_expr is None:
- type_expr = self.global_type_vars.get(name, None)
- if type_expr is not None:
- # Zero-arity constructor calls fall into the general ident case, so in that case,
- # we construct a constructor call with no args.
- if isinstance(type_expr, adt.Constructor) and not type_expr.inputs:
- type_expr = expr.Call(type_expr, [])
- return type_expr
- # Check if it's an operator.
- op_name = ".".join([name.getText() for name in ctx.CNAME()])
- if op_name in FUNC_OPS:
- return FuncOp(FUNC_OPS[op_name])
- return ExprOp(op.get(op_name))
-
- def visitGlobalVar(self, ctx):
- var_name = ctx.CNAME().getText()
- global_var = self.global_vars.get(var_name, None)
- if global_var is None:
- raise ParseError("unbound global var `{0}`".format(var_name))
- return global_var
-
- def visitLocalVar(self, ctx):
- var_name = ctx.CNAME().getText()
- local_var = lookup(self.var_scopes, var_name)
- if local_var is None:
- raise ParseError("unbound local var `{0}`".format(var_name))
- return local_var
-
- def visitGraphVar(self, ctx):
- return self.graph_expr[int(ctx.NAT().getText())]
-
- def visit_list(self, ctx_list) -> List[Any]:
- """"Visit a list of contexts."""
- assert isinstance(ctx_list, list)
-
- return [self.visit(ctx) for ctx in ctx_list]
-
- def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty.Type]:
- """Return a (possibly None) Relay type."""
- if ctx is None:
- return None
-
- return self.visit(ctx)
-
- def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]:
- self.meta = None
- if ctx.METADATA():
- header, data = str(ctx.METADATA()).split("\n", 1)
- assert header == "METADATA:"
- self.meta = tvm.ir.load_json(data)
- if ctx.defn():
- self.visit_list(ctx.defn())
- return self.module
-
- if ctx.expr():
- return self.visit(ctx.expr())
-
- return self.module
-
- # Exprs
- def visitOpIdent(self, ctx) -> tvm.ir.Op:
- op_name = ".".join([name.getText() for name in ctx.CNAME()])
- if op_name in FUNC_OPS:
- return FuncOp(FUNC_OPS[op_name])
- return ExprOp(op.get(op_name))
-
- # pass through
- def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr:
- return self.visit(ctx.expr())
-
- # pass through
- def visitTypeParen(self, ctx: RelayParser.TypeParenContext) -> expr.Expr:
- return self.visit(ctx.typeExpr())
-
- # pass through
- def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr:
- return self.visit(ctx.expr())
-
- def visitScalarFloat(self, ctx: RelayParser.ScalarFloatContext) -> expr.Constant:
- return expr.const(self.visit(ctx.FLOAT()))
-
- def visitScalarInt(self, ctx: RelayParser.ScalarIntContext) -> expr.Constant:
- return expr.const(self.visit(ctx.NAT()))
-
- def visitScalarBool(self, ctx: RelayParser.ScalarBoolContext) -> expr.Constant:
- return expr.const(self.visit(ctx.BOOL_LIT()))
-
- def visitNeg(self, ctx: RelayParser.NegContext) -> Union[expr.Constant, expr.Call]:
- val = self.visit(ctx.expr())
- if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0:
- # fold Neg in for scalars
- return expr.const(-val.data.asnumpy().item())
-
- return op.negative(val)
-
- def visitTuple(self, ctx: RelayParser.TupleContext) -> expr.Tuple:
- tup = self.visit_list(ctx.expr())
- return expr.Tuple(tup)
-
- def visitLet(self, ctx: RelayParser.LetContext) -> expr.Let:
- """Desugar various sequence constructs to Relay Let nodes."""
-
- if ctx.var() is None:
- # anonymous identity
- ident = "_"
- typ = None
- var = self.mk_var(ident, typ)
- else:
- var = self.visitVar(ctx.var())
-
- self.enter_var_scope()
- value = self.visit(ctx.expr(0))
- self.exit_var_scope()
-
- body = self.visit(ctx.expr(1))
-
- return expr.Let(var, value, body)
-
- def visitBinOp(self, ctx: RelayParser.BinOpContext) -> expr.Call:
- """Desugar binary operators."""
- arg0, arg1 = self.visit_list(ctx.expr())
- relay_op = BINARY_OPS.get(ctx.op.type)
-
- if relay_op is None:
- raise ParseError("unimplemented binary op.")
-
- return relay_op(arg0, arg1)
-
- @spanify
- def visitVar(self, ctx: RelayParser.VarContext) -> expr.Var:
- """Visit a single variable."""
- ident = ctx.localVar()
-
- if ident is None:
- raise ParseError("only local ids may be used in vars.")
-
- typeExpr = self.getTypeExpr(ctx.typeExpr())
-
- return self.mk_var(ident.getText()[1:], typeExpr)
-
- def visitVarList(self, ctx: RelayParser.VarListContext) -> List[expr.Var]:
- return self.visit_list(ctx.var())
-
- # TODO: support a larger class of values than just Relay exprs
- def visitAttr(self, ctx: RelayParser.AttrContext) -> Tuple[str, expr.Expr]:
- return (ctx.CNAME().getText(), self.visit(ctx.expr()))
-
- def visitArgNoAttr(self, ctx: RelayParser.ArgNoAttrContext):
- return (self.visit_list(ctx.varList().var()), None)
-
- def visitAttrSeq(self, ctx: RelayParser.AttrSeqContext) -> Dict[str, expr.Expr]:
- return dict(self.visit_list(ctx.attr()))
-
- def visitArgWithAttr(self, ctx: RelayParser.AttrSeqContext) \
- -> Tuple[List[expr.Var], Dict[str, expr.Expr]]:
- return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq()))
-
- def visitArgList(self, ctx: RelayParser.ArgListContext) \
- -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]:
- var_list = self.visit(ctx.varList()) if ctx.varList() else None
- attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None
- return (var_list, attr_list)
-
- def visitMeta(self, ctx: RelayParser.MetaContext):
- type_key = str(ctx.CNAME())
- index = int(self.visit(ctx.NAT()))
- return self.meta[type_key][index]
-
- def mk_func(
- self,
- ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
- -> function.Function:
- """Construct a function from either a Func or Defn."""
- # Enter var scope early to put params in scope.
- self.enter_var_scope()
- # Capture type params in params.
- self.enter_type_param_scope()
- type_params = ctx.typeParamList()
-
- if type_params is not None:
- type_params = type_params.typeExpr()
- assert type_params
- for ty_param in type_params:
- name = ty_param.getText()
- self.mk_typ(name, ty.TypeKind.Type)
-
- var_list, attr_list = self.visit(ctx.argList())
- if var_list is None:
- var_list = []
- ret_type = self.getTypeExpr(ctx.typeExpr())
-
- body = self.visit(ctx.body())
- # NB(@jroesch): you must stay in the type parameter scope until
- # after you exit the body, you can reference the type parameters
- # of your parent scopes.
- type_params = list(self.exit_type_param_scope())
- if type_params:
- _, type_params = zip(*type_params)
- self.exit_var_scope()
-
- attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
- return function.Function(var_list, body, ret_type, type_params, attrs)
-
- @spanify
- def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function:
- return self.mk_func(ctx)
-
- # TODO: how to set spans for definitions?
- # @spanify
- def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None:
- ident_name = ctx.globalVar().getText()[1:]
- ident = self.mk_global_var(ident_name)
- func = self.mk_func(ctx)
- self.module[ident] = func
-
- def handle_adt_header(
- self,
- ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]):
- """Handles parsing of the name and type params of an ADT definition."""
- adt_name = ctx.generalIdent().getText()
- adt_var = self.mk_global_typ_var(adt_name, ty.TypeKind.AdtHandle)
- # parse type params
- type_params = ctx.typeParamList()
- if type_params is None:
- type_params = []
- else:
- type_params = [self.mk_typ(type_ident.getText(), ty.TypeKind.Type)
- for type_ident in type_params.typeExpr()]
- return adt_var, type_params
-
- def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext):
- # TODO(weberlo): update this handler once extern is implemented
- self.enter_type_param_scope()
- adt_var, type_params = self.handle_adt_header(ctx)
- # update module being built
- self.module[adt_var] = adt.TypeData(adt_var, type_params, [])
- self.exit_type_param_scope()
-
- def visitAdtDefn(self, ctx: RelayParser.AdtDefnContext):
- self.enter_type_param_scope()
- adt_var, type_params = self.handle_adt_header(ctx)
- # parse constructors
- adt_cons_defns = ctx.adtConsDefnList()
- if adt_cons_defns is None:
- adt_cons_defns = []
- else:
- adt_cons_defns = adt_cons_defns.adtConsDefn()
- parsed_constructors = []
- for cons_defn in adt_cons_defns:
- inputs = [self.visit(inp) for inp in cons_defn.typeExpr()]
- cons_defn_name = cons_defn.constructorName().getText()
- cons_defn = adt.Constructor(cons_defn_name, inputs, adt_var)
- self.mk_global_typ_cons(cons_defn_name, cons_defn)
- parsed_constructors.append(cons_defn)
- # update module being built
- self.module[adt_var] = adt.TypeData(adt_var, type_params, parsed_constructors)
- self.exit_type_param_scope()
-
- def visitMatch(self, ctx: RelayParser.MatchContext):
- match_type = ctx.matchType().getText()
- if match_type == "match":
- complete_match = True
- elif match_type == "match?":
- complete_match = False
- else:
- raise RuntimeError("unknown match type {0}".format(match_type))
-
- match_data = self.visit(ctx.expr())
- match_clauses = ctx.matchClauseList()
- if match_clauses is None:
- match_clauses = []
- else:
- match_clauses = match_clauses.matchClause()
- parsed_clauses = []
- for clause in match_clauses:
- self.enter_var_scope()
- pattern = self.visit(clause.pattern())
- clause_body = self.visit(clause.expr())
- self.exit_var_scope()
- parsed_clauses.append(adt.Clause(pattern, clause_body))
- return adt.Match(match_data, parsed_clauses, complete=complete_match)
-
- def visitWildcardPattern(self, ctx: RelayParser.WildcardPatternContext):
- return adt.PatternWildcard()
-
- def visitVarPattern(self, ctx: RelayParser.VarPatternContext):
- text = ctx.localVar().getText()
- typ = ctx.typeExpr()
- if typ is not None:
- typ = self.visit(typ)
- var = self.mk_var(text[1:], typ=typ)
- return adt.PatternVar(var)
-
- def visitConstructorPattern(self, ctx: RelayParser.ConstructorPatternContext):
- constructor_name = ctx.constructorName().getText()
- constructor = self.global_type_vars[constructor_name]
- pattern_list = ctx.patternList()
- if pattern_list is None:
- patterns = []
- else:
- patterns = [self.visit(pattern) for pattern in pattern_list.pattern()]
- return adt.PatternConstructor(constructor, patterns)
-
- def visitTuplePattern(self, ctx: RelayParser.TuplePatternContext):
- return adt.PatternTuple([self.visit(pattern) for pattern in ctx.patternList().pattern()])
-
- def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext):
- return (self.visit_list(ctx.exprList().expr()), None)
-
- def visitCallWithAttr(self, ctx: RelayParser.CallWithAttrContext):
- return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq()))
-
- def call(self, func, args, attrs, type_args):
- if isinstance(func, OpWrapper):
- return func(args, attrs, type_args)
- if isinstance(func, adt.Constructor):
- return func(*args)
- return expr.Call(func, args, attrs, type_args)
-
- @spanify
- def visitCall(self, ctx: RelayParser.CallContext) -> expr.Call:
- func = self.visit(ctx.expr())
- args, attrs = self.visit(ctx.callList())
- res = self.call(func, args, attrs, [])
- return res
-
- @spanify
- def visitIfElse(self, ctx: RelayParser.IfElseContext) -> expr.If:
- """Construct a Relay If node. Creates a new scope for each branch."""
- cond = self.visit(ctx.expr())
-
- self.enter_var_scope()
- true_branch = self.visit(ctx.body(0))
- self.exit_var_scope()
-
- self.enter_var_scope()
- false_branch = self.visit(ctx.body(1))
- self.exit_var_scope()
-
- return expr.If(cond, true_branch, false_branch)
-
- @spanify
- def visitGraph(self, ctx: RelayParser.GraphContext) -> expr.Expr:
- """Visit a graph variable assignment."""
- graph_nid = int(ctx.graphVar().getText()[1:])
-
- self.enter_var_scope()
- value = self.visit(ctx.expr(0))
- self.exit_var_scope()
-
- if graph_nid != len(self.graph_expr):
- raise ParseError(
- "expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \
- "but got `%{}`".format(graph_nid))
- self.graph_expr.append(value)
-
- kont = self.visit(ctx.expr(1))
- return kont
-
- # Types
-
- # pylint: disable=unused-argument
- def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext) -> None:
- return None
-
- def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext):
- func = self.visit(ctx.generalIdent())
- args = [self.visit(arg) for arg in ctx.typeParamList().typeExpr()]
- return ty.TypeCall(func, args)
-
- def visitParensShape(self, ctx: RelayParser.ParensShapeContext) -> int:
- return self.visit(ctx.shape())
-
- def visitShapeList(self, ctx: RelayParser.ShapeListContext) -> List[int]:
- return self.visit_list(ctx.shape())
-
- def visitTensor(self, ctx: RelayParser.TensorContext):
- return tuple(self.visit_list(ctx.expr()))
-
- def visitTensorType(self, ctx: RelayParser.TensorTypeContext) -> ty.TensorType:
- """Create a simple tensor type. No generics."""
-
- shape = self.visit(ctx.shapeList())
- dtype = self.visit(ctx.typeExpr())
-
- if not isinstance(dtype, ty.TensorType):
- raise ParseError("expected dtype to be a Relay base type.")
-
- dtype = dtype.dtype
-
- return ty.TensorType(shape, dtype)
-
- def visitTupleType(self, ctx: RelayParser.TupleTypeContext) -> ty.TupleType:
- return ty.TupleType(self.visit_list(ctx.typeExpr()))
-
- def visitFuncType(self, ctx: RelayParser.FuncTypeContext) -> ty.FuncType:
- types = self.visit_list(ctx.typeExpr())
-
- arg_types = types[:-1]
- ret_type = types[-1]
-
- return ty.FuncType(arg_types, ret_type, [], None)
-
-def make_parser(data: str) -> RelayParser:
- """Construct a RelayParser a given data stream."""
- input_stream = InputStream(data)
- lexer = RelayLexer(input_stream)
- lexer.addErrorListener(StrictErrorListener(data))
- token_stream = CommonTokenStream(lexer)
- p = RelayParser(token_stream)
- p.addErrorListener(StrictErrorListener(data))
- return p
-
-__source_name_counter__ = 0
-
-class StrictErrorListener(ErrorListener):
- """This ErrorListener fail eagerly on all error, and report the program."""
- def __init__(self, text):
- self.text = text
-
- def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
- raise Exception("Syntax Error in:\n" + self.text)
-
- def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
- raise Exception("Ambiguity Error in:\n" + self.text)
-
- def reportAttemptingFullContext(self,
- recognizer,
- dfa,
- startIndex,
- stopIndex,
- conflictingAlts,
- configs):
- raise Exception("Attempting Full Context in:\n" + self.text)
-
- def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
- raise Exception("Context Sensitivity in:\n" + self.text)
-
-def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, IRModule]:
- """Parse a Relay program."""
- if data == "":
- raise ParseError("cannot parse the empty string.")
-
- global __source_name_counter__
-
- if source_name is None:
- source_name = "source_file{0}".format(__source_name_counter__)
-
- if isinstance(source_name, str):
- source_name = SourceName(source_name)
-
- tree = make_parser(data).prog()
- return ParseTreeToRelayIR(source_name).visit(tree)
expr : tvm.relay.Expr
The input expression.
- binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]
+ binds : Map[tvm.relay.Var, tvm.relay.Expr]
The specific bindings.
Returns
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*
- * NOTE: The `USE_ANTLR` option in `config.cmake` must be enabled in order for
- * changes in this file to be reflected by the parser.
- * NOTE: All upper-case rules are *lexer* rules and all camel-case rules are *parser* rules.
- */
-
-grammar Relay;
-
-SEMVER: 'v0.0.4' ;
-
-// Lexing
-// comments
-COMMENT : '/*' (COMMENT|.)*? '*/' -> skip;
-WS : [ \t\n\r]+ -> skip;
-LINE_COMMENT : '//' .*? '\n' -> skip;
-
-fragment ESCAPED_QUOTE : '\\"';
-QUOTED_STRING : '"' ( ESCAPED_QUOTE | ~('\n'|'\r') )*? '"';
-
-// operators
-MUL: '*' ;
-DIV: '/' ;
-ADD: '+' ;
-SUB: '-' ;
-LT: '<' ;
-GT: '>' ;
-LE: '<=' ;
-GE: '>=' ;
-EQ: '==' ;
-NE: '!=' ;
-
-BOOL_LIT
- : 'True'
- | 'False'
- ;
-
-CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;
-
-// non-negative floats
-fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
-
-FLOAT : PREFLOAT 'f';
-
-// non-negative ints
-NAT: DIGIT+ ;
-fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...]
-
-fragment LETTER: [a-zA-Z];
-fragment DIGIT: [0-9];
-
-METADATA: 'METADATA:' .*;
-// Parsing
-
-// A Relay program is a list of global definitions or an expression.
-prog: SEMVER (defn* | expr) METADATA? EOF ;
-
-// Covers both operator and type idents
-generalIdent: CNAME ('.' CNAME)*;
-globalVar: '@' CNAME ;
-localVar: '%' ('_' | CNAME) ;
-graphVar: '%' NAT ;
-
-exprList: (expr (',' expr)*)?;
-callList
- : exprList # callNoAttr
- | (expr ',')* attrSeq # callWithAttr
- ;
-
-expr
- // operators
- : '(' expr ')' # paren
- // function application
- | expr '(' callList ')' # call
- | '-' expr # neg
- | expr op=('*'|'/') expr # binOp
- | expr op=('+'|'-') expr # binOp
- | expr op=('<'|'>'|'<='|'>=') expr # binOp
- | expr op=('=='|'!=') expr # binOp
- // function definition
- | func # funcExpr
- // tuples and tensors
- | '(' ')' # tuple
- | '(' expr ',' ')' # tuple
- | '(' expr (',' expr)+ ')' # tuple
- | '[' (expr (',' expr)*)? ']' # tensor
- | 'if' '(' expr ')' body 'else' body # ifElse
- | matchType expr '{' matchClauseList? '}' # match
- | expr '.' NAT # projection
- // sequencing
- | 'let' var '=' expr ';' expr # let
- // sugar for let %_ = expr; expr
- | expr ';;' expr # let
- | graphVar '=' expr ';' expr # graph
- | ident # identExpr
- | scalar # scalarExpr
- | meta # metaExpr
- | QUOTED_STRING # stringExpr
- ;
-
-func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ;
-defn
- : 'def' globalVar typeParamList? '(' argList ')' ('->' typeExpr)? body # funcDefn
- | 'extern' 'type' generalIdent typeParamList? # externAdtDefn
- | 'type' generalIdent typeParamList? '{' adtConsDefnList? '}' # adtDefn
- ;
-
-constructorName: CNAME ;
-
-adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ;
-adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ;
-matchClauseList: matchClause (',' matchClause)* ','? ;
-matchClause: pattern '=>' ('{' expr '}' | expr) ;
-// complete or incomplete match, respectively
-matchType : 'match' | 'match?' ;
-
-patternList: '(' pattern (',' pattern)* ')';
-pattern
- : '_' # wildcardPattern
- | localVar (':' typeExpr)? # varPattern
- | constructorName patternList? # constructorPattern
- | patternList # tuplePattern
- ;
-
-adtCons: constructorName adtConsParamList? ;
-adtConsParamList: '(' adtConsParam (',' adtConsParam)* ')' ;
-adtConsParam: localVar | constructorName ;
-
-argList
- : varList # argNoAttr
- | (var ',')* attrSeq # argWithAttr
- ;
-
-varList: (var (',' var)*)? ;
-var: localVar (':' typeExpr)? ;
-
-attrSeq: attr (',' attr)* ;
-attr: CNAME '=' expr ;
-
-typeExpr
- : '(' ')' # tupleType
- | '(' typeExpr ')' # typeParen
- | '(' typeExpr ',' ')' # tupleType
- | '(' typeExpr (',' typeExpr)+ ')' # tupleType
- | generalIdent typeParamList # typeCallType
- | generalIdent # typeIdentType
- | 'Tensor' '[' shapeList ',' typeExpr ']' # tensorType
- | 'fn' typeParamList? '(' (typeExpr (',' typeExpr)*)? ')' '->' typeExpr # funcType
- | '_' # incompleteType
- ;
-
-typeParamList: '[' typeExpr (',' typeExpr)* ']' ;
-
-shapeList
- : '(' ')'
- | '(' shape (',' shape)+ ')'
- | shape
- ;
-
-meta : 'meta' '[' CNAME ']' '[' NAT ']';
-
-shape
- : meta # metaShape
- | '(' shape ')' # parensShape
- | NAT # intShape
- ;
-
-body: '{' expr '}' ;
-
-scalar
- : FLOAT # scalarFloat
- | NAT # scalarInt
- | BOOL_LIT # scalarBool
- ;
-
-ident
- : generalIdent
- | globalVar
- | localVar
- | graphVar
- ;
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
+++ /dev/null
-Relay* binary
-Relay* linguist-generated=true
-Relay* linguist-detectable=false
+++ /dev/null
-# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
-from antlr4 import *
-from io import StringIO
-from typing.io import TextIO
-import sys
-
-
-
-def serializedATN():
- with StringIO() as buf:
- buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2\62")
- buf.write("\u0161\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7")
- buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r")
- buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23")
- buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30")
- buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36")
- buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%")
- buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\4.")
- buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\4\64")
- buf.write("\t\64\4\65\t\65\4\66\t\66\3\2\3\2\3\3\3\3\3\4\3\4\3\5")
- buf.write("\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3\t\3\t\3\n\3\n\3\13\3\13")
- buf.write("\3\13\3\f\3\f\3\f\3\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17")
- buf.write("\3\17\3\17\3\20\3\20\3\21\3\21\3\22\3\22\3\22\3\23\3\23")
- buf.write("\3\23\3\24\3\24\3\24\3\25\3\25\3\25\3\25\3\26\3\26\3\26")
- buf.write("\3\26\3\26\3\26\3\26\3\27\3\27\3\27\3\27\3\27\3\30\3\30")
- buf.write("\3\30\3\31\3\31\3\31\3\31\3\31\3\31\3\32\3\32\3\32\3\32")
- buf.write("\3\32\3\32\3\32\3\33\3\33\3\34\3\34\3\34\3\34\3\34\3\34")
- buf.write("\3\34\3\35\3\35\3\35\3\35\3\35\3\36\3\36\3\36\3\36\3\36")
- buf.write("\3\36\3\36\3\37\3\37\3\37\3\37\3\37\7\37\u00d7\n\37\f")
- buf.write("\37\16\37\u00da\13\37\3\37\3\37\3\37\3\37\3\37\3 \6 \u00e2")
- buf.write("\n \r \16 \u00e3\3 \3 \3!\3!\3!\3!\7!\u00ec\n!\f!\16!")
- buf.write("\u00ef\13!\3!\3!\3!\3!\3\"\3\"\3\"\3#\3#\3#\7#\u00fb\n")
- buf.write("#\f#\16#\u00fe\13#\3#\3#\3$\3$\3%\3%\3&\3&\3\'\3\'\3(")
- buf.write("\3(\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3,\3-\3-\3-\3.\3.\3")
- buf.write(".\3.\3.\3.\3.\3.\3.\5.\u0123\n.\3/\3/\5/\u0127\n/\3/\3")
- buf.write("/\3/\7/\u012c\n/\f/\16/\u012f\13/\3/\3/\7/\u0133\n/\f")
- buf.write("/\16/\u0136\13/\3\60\3\60\3\60\5\60\u013b\n\60\3\60\5")
- buf.write("\60\u013e\n\60\3\61\3\61\3\61\3\62\6\62\u0144\n\62\r\62")
- buf.write("\16\62\u0145\3\63\3\63\5\63\u014a\n\63\3\63\3\63\3\64")
- buf.write("\3\64\3\65\3\65\3\66\3\66\3\66\3\66\3\66\3\66\3\66\3\66")
- buf.write("\3\66\3\66\3\66\7\66\u015d\n\66\f\66\16\66\u0160\13\66")
- buf.write("\5\u00d8\u00ed\u00fc\2\67\3\3\5\4\7\5\t\6\13\7\r\b\17")
- buf.write("\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21!\22#\23")
- buf.write("%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67\359\36")
- buf.write(";\37= ?!A\"C\2E#G$I%K&M\'O(Q)S*U+W,Y-[.]/_\2a\60c\61e")
- buf.write("\2g\2i\2k\62\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4")
- buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u016c\2\3\3\2\2\2\2")
- buf.write("\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3")
- buf.write("\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2")
- buf.write("\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2")
- buf.write("\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2\2\'\3")
- buf.write("\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2\2\2\2\61")
- buf.write("\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67\3\2\2\2\29\3\2")
- buf.write("\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2E\3")
- buf.write("\2\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O")
- buf.write("\3\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2")
- buf.write("Y\3\2\2\2\2[\3\2\2\2\2]\3\2\2\2\2a\3\2\2\2\2c\3\2\2\2")
- buf.write("\2k\3\2\2\2\3m\3\2\2\2\5o\3\2\2\2\7q\3\2\2\2\ts\3\2\2")
- buf.write("\2\13u\3\2\2\2\rw\3\2\2\2\17y\3\2\2\2\21{\3\2\2\2\23}")
- buf.write("\3\2\2\2\25\177\3\2\2\2\27\u0082\3\2\2\2\31\u0087\3\2")
- buf.write("\2\2\33\u0089\3\2\2\2\35\u008b\3\2\2\2\37\u008f\3\2\2")
- buf.write("\2!\u0091\3\2\2\2#\u0093\3\2\2\2%\u0096\3\2\2\2\'\u0099")
- buf.write("\3\2\2\2)\u009c\3\2\2\2+\u00a0\3\2\2\2-\u00a7\3\2\2\2")
- buf.write("/\u00ac\3\2\2\2\61\u00af\3\2\2\2\63\u00b5\3\2\2\2\65\u00bc")
- buf.write("\3\2\2\2\67\u00be\3\2\2\29\u00c5\3\2\2\2;\u00ca\3\2\2")
- buf.write("\2=\u00d1\3\2\2\2?\u00e1\3\2\2\2A\u00e7\3\2\2\2C\u00f4")
- buf.write("\3\2\2\2E\u00f7\3\2\2\2G\u0101\3\2\2\2I\u0103\3\2\2\2")
- buf.write("K\u0105\3\2\2\2M\u0107\3\2\2\2O\u0109\3\2\2\2Q\u010b\3")
- buf.write("\2\2\2S\u010d\3\2\2\2U\u0110\3\2\2\2W\u0113\3\2\2\2Y\u0116")
- buf.write("\3\2\2\2[\u0122\3\2\2\2]\u0126\3\2\2\2_\u0137\3\2\2\2")
- buf.write("a\u013f\3\2\2\2c\u0143\3\2\2\2e\u0147\3\2\2\2g\u014d\3")
- buf.write("\2\2\2i\u014f\3\2\2\2k\u0151\3\2\2\2mn\7\60\2\2n\4\3\2")
- buf.write("\2\2op\7B\2\2p\6\3\2\2\2qr\7\'\2\2r\b\3\2\2\2st\7a\2\2")
- buf.write("t\n\3\2\2\2uv\7.\2\2v\f\3\2\2\2wx\7*\2\2x\16\3\2\2\2y")
- buf.write("z\7+\2\2z\20\3\2\2\2{|\7]\2\2|\22\3\2\2\2}~\7_\2\2~\24")
- buf.write("\3\2\2\2\177\u0080\7k\2\2\u0080\u0081\7h\2\2\u0081\26")
- buf.write("\3\2\2\2\u0082\u0083\7g\2\2\u0083\u0084\7n\2\2\u0084\u0085")
- buf.write("\7u\2\2\u0085\u0086\7g\2\2\u0086\30\3\2\2\2\u0087\u0088")
- buf.write("\7}\2\2\u0088\32\3\2\2\2\u0089\u008a\7\177\2\2\u008a\34")
- buf.write("\3\2\2\2\u008b\u008c\7n\2\2\u008c\u008d\7g\2\2\u008d\u008e")
- buf.write("\7v\2\2\u008e\36\3\2\2\2\u008f\u0090\7?\2\2\u0090 \3\2")
- buf.write("\2\2\u0091\u0092\7=\2\2\u0092\"\3\2\2\2\u0093\u0094\7")
- buf.write("=\2\2\u0094\u0095\7=\2\2\u0095$\3\2\2\2\u0096\u0097\7")
- buf.write("h\2\2\u0097\u0098\7p\2\2\u0098&\3\2\2\2\u0099\u009a\7")
- buf.write("/\2\2\u009a\u009b\7@\2\2\u009b(\3\2\2\2\u009c\u009d\7")
- buf.write("f\2\2\u009d\u009e\7g\2\2\u009e\u009f\7h\2\2\u009f*\3\2")
- buf.write("\2\2\u00a0\u00a1\7g\2\2\u00a1\u00a2\7z\2\2\u00a2\u00a3")
- buf.write("\7v\2\2\u00a3\u00a4\7g\2\2\u00a4\u00a5\7t\2\2\u00a5\u00a6")
- buf.write("\7p\2\2\u00a6,\3\2\2\2\u00a7\u00a8\7v\2\2\u00a8\u00a9")
- buf.write("\7{\2\2\u00a9\u00aa\7r\2\2\u00aa\u00ab\7g\2\2\u00ab.\3")
- buf.write("\2\2\2\u00ac\u00ad\7?\2\2\u00ad\u00ae\7@\2\2\u00ae\60")
- buf.write("\3\2\2\2\u00af\u00b0\7o\2\2\u00b0\u00b1\7c\2\2\u00b1\u00b2")
- buf.write("\7v\2\2\u00b2\u00b3\7e\2\2\u00b3\u00b4\7j\2\2\u00b4\62")
- buf.write("\3\2\2\2\u00b5\u00b6\7o\2\2\u00b6\u00b7\7c\2\2\u00b7\u00b8")
- buf.write("\7v\2\2\u00b8\u00b9\7e\2\2\u00b9\u00ba\7j\2\2\u00ba\u00bb")
- buf.write("\7A\2\2\u00bb\64\3\2\2\2\u00bc\u00bd\7<\2\2\u00bd\66\3")
- buf.write("\2\2\2\u00be\u00bf\7V\2\2\u00bf\u00c0\7g\2\2\u00c0\u00c1")
- buf.write("\7p\2\2\u00c1\u00c2\7u\2\2\u00c2\u00c3\7q\2\2\u00c3\u00c4")
- buf.write("\7t\2\2\u00c48\3\2\2\2\u00c5\u00c6\7o\2\2\u00c6\u00c7")
- buf.write("\7g\2\2\u00c7\u00c8\7v\2\2\u00c8\u00c9\7c\2\2\u00c9:\3")
- buf.write("\2\2\2\u00ca\u00cb\7x\2\2\u00cb\u00cc\7\62\2\2\u00cc\u00cd")
- buf.write("\7\60\2\2\u00cd\u00ce\7\62\2\2\u00ce\u00cf\7\60\2\2\u00cf")
- buf.write("\u00d0\7\66\2\2\u00d0<\3\2\2\2\u00d1\u00d2\7\61\2\2\u00d2")
- buf.write("\u00d3\7,\2\2\u00d3\u00d8\3\2\2\2\u00d4\u00d7\5=\37\2")
- buf.write("\u00d5\u00d7\13\2\2\2\u00d6\u00d4\3\2\2\2\u00d6\u00d5")
- buf.write("\3\2\2\2\u00d7\u00da\3\2\2\2\u00d8\u00d9\3\2\2\2\u00d8")
- buf.write("\u00d6\3\2\2\2\u00d9\u00db\3\2\2\2\u00da\u00d8\3\2\2\2")
- buf.write("\u00db\u00dc\7,\2\2\u00dc\u00dd\7\61\2\2\u00dd\u00de\3")
- buf.write("\2\2\2\u00de\u00df\b\37\2\2\u00df>\3\2\2\2\u00e0\u00e2")
- buf.write("\t\2\2\2\u00e1\u00e0\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3")
- buf.write("\u00e1\3\2\2\2\u00e3\u00e4\3\2\2\2\u00e4\u00e5\3\2\2\2")
- buf.write("\u00e5\u00e6\b \2\2\u00e6@\3\2\2\2\u00e7\u00e8\7\61\2")
- buf.write("\2\u00e8\u00e9\7\61\2\2\u00e9\u00ed\3\2\2\2\u00ea\u00ec")
- buf.write("\13\2\2\2\u00eb\u00ea\3\2\2\2\u00ec\u00ef\3\2\2\2\u00ed")
- buf.write("\u00ee\3\2\2\2\u00ed\u00eb\3\2\2\2\u00ee\u00f0\3\2\2\2")
- buf.write("\u00ef\u00ed\3\2\2\2\u00f0\u00f1\7\f\2\2\u00f1\u00f2\3")
- buf.write("\2\2\2\u00f2\u00f3\b!\2\2\u00f3B\3\2\2\2\u00f4\u00f5\7")
- buf.write("^\2\2\u00f5\u00f6\7$\2\2\u00f6D\3\2\2\2\u00f7\u00fc\7")
- buf.write("$\2\2\u00f8\u00fb\5C\"\2\u00f9\u00fb\n\3\2\2\u00fa\u00f8")
- buf.write("\3\2\2\2\u00fa\u00f9\3\2\2\2\u00fb\u00fe\3\2\2\2\u00fc")
- buf.write("\u00fd\3\2\2\2\u00fc\u00fa\3\2\2\2\u00fd\u00ff\3\2\2\2")
- buf.write("\u00fe\u00fc\3\2\2\2\u00ff\u0100\7$\2\2\u0100F\3\2\2\2")
- buf.write("\u0101\u0102\7,\2\2\u0102H\3\2\2\2\u0103\u0104\7\61\2")
- buf.write("\2\u0104J\3\2\2\2\u0105\u0106\7-\2\2\u0106L\3\2\2\2\u0107")
- buf.write("\u0108\7/\2\2\u0108N\3\2\2\2\u0109\u010a\7>\2\2\u010a")
- buf.write("P\3\2\2\2\u010b\u010c\7@\2\2\u010cR\3\2\2\2\u010d\u010e")
- buf.write("\7>\2\2\u010e\u010f\7?\2\2\u010fT\3\2\2\2\u0110\u0111")
- buf.write("\7@\2\2\u0111\u0112\7?\2\2\u0112V\3\2\2\2\u0113\u0114")
- buf.write("\7?\2\2\u0114\u0115\7?\2\2\u0115X\3\2\2\2\u0116\u0117")
- buf.write("\7#\2\2\u0117\u0118\7?\2\2\u0118Z\3\2\2\2\u0119\u011a")
- buf.write("\7V\2\2\u011a\u011b\7t\2\2\u011b\u011c\7w\2\2\u011c\u0123")
- buf.write("\7g\2\2\u011d\u011e\7H\2\2\u011e\u011f\7c\2\2\u011f\u0120")
- buf.write("\7n\2\2\u0120\u0121\7u\2\2\u0121\u0123\7g\2\2\u0122\u0119")
- buf.write("\3\2\2\2\u0122\u011d\3\2\2\2\u0123\\\3\2\2\2\u0124\u0127")
- buf.write("\7a\2\2\u0125\u0127\5g\64\2\u0126\u0124\3\2\2\2\u0126")
- buf.write("\u0125\3\2\2\2\u0127\u012d\3\2\2\2\u0128\u012c\7a\2\2")
- buf.write("\u0129\u012c\5g\64\2\u012a\u012c\5i\65\2\u012b\u0128\3")
- buf.write("\2\2\2\u012b\u0129\3\2\2\2\u012b\u012a\3\2\2\2\u012c\u012f")
- buf.write("\3\2\2\2\u012d\u012b\3\2\2\2\u012d\u012e\3\2\2\2\u012e")
- buf.write("\u0134\3\2\2\2\u012f\u012d\3\2\2\2\u0130\u0131\7\60\2")
- buf.write("\2\u0131\u0133\5]/\2\u0132\u0130\3\2\2\2\u0133\u0136\3")
- buf.write("\2\2\2\u0134\u0132\3\2\2\2\u0134\u0135\3\2\2\2\u0135^")
- buf.write("\3\2\2\2\u0136\u0134\3\2\2\2\u0137\u013a\5c\62\2\u0138")
- buf.write("\u0139\7\60\2\2\u0139\u013b\5c\62\2\u013a\u0138\3\2\2")
- buf.write("\2\u013a\u013b\3\2\2\2\u013b\u013d\3\2\2\2\u013c\u013e")
- buf.write("\5e\63\2\u013d\u013c\3\2\2\2\u013d\u013e\3\2\2\2\u013e")
- buf.write("`\3\2\2\2\u013f\u0140\5_\60\2\u0140\u0141\7h\2\2\u0141")
- buf.write("b\3\2\2\2\u0142\u0144\5i\65\2\u0143\u0142\3\2\2\2\u0144")
- buf.write("\u0145\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2")
- buf.write("\u0146d\3\2\2\2\u0147\u0149\t\4\2\2\u0148\u014a\t\5\2")
- buf.write("\2\u0149\u0148\3\2\2\2\u0149\u014a\3\2\2\2\u014a\u014b")
- buf.write("\3\2\2\2\u014b\u014c\5c\62\2\u014cf\3\2\2\2\u014d\u014e")
- buf.write("\t\6\2\2\u014eh\3\2\2\2\u014f\u0150\t\7\2\2\u0150j\3\2")
- buf.write("\2\2\u0151\u0152\7O\2\2\u0152\u0153\7G\2\2\u0153\u0154")
- buf.write("\7V\2\2\u0154\u0155\7C\2\2\u0155\u0156\7F\2\2\u0156\u0157")
- buf.write("\7C\2\2\u0157\u0158\7V\2\2\u0158\u0159\7C\2\2\u0159\u015a")
- buf.write("\7<\2\2\u015a\u015e\3\2\2\2\u015b\u015d\13\2\2\2\u015c")
- buf.write("\u015b\3\2\2\2\u015d\u0160\3\2\2\2\u015e\u015c\3\2\2\2")
- buf.write("\u015e\u015f\3\2\2\2\u015fl\3\2\2\2\u0160\u015e\3\2\2")
- buf.write("\2\23\2\u00d6\u00d8\u00e3\u00ed\u00fa\u00fc\u0122\u0126")
- buf.write("\u012b\u012d\u0134\u013a\u013d\u0145\u0149\u015e\3\b\2")
- buf.write("\2")
- return buf.getvalue()
-
-
-class RelayLexer(Lexer):
-
- atn = ATNDeserializer().deserialize(serializedATN())
-
- decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ]
-
- T__0 = 1
- T__1 = 2
- T__2 = 3
- T__3 = 4
- T__4 = 5
- T__5 = 6
- T__6 = 7
- T__7 = 8
- T__8 = 9
- T__9 = 10
- T__10 = 11
- T__11 = 12
- T__12 = 13
- T__13 = 14
- T__14 = 15
- T__15 = 16
- T__16 = 17
- T__17 = 18
- T__18 = 19
- T__19 = 20
- T__20 = 21
- T__21 = 22
- T__22 = 23
- T__23 = 24
- T__24 = 25
- T__25 = 26
- T__26 = 27
- T__27 = 28
- SEMVER = 29
- COMMENT = 30
- WS = 31
- LINE_COMMENT = 32
- QUOTED_STRING = 33
- MUL = 34
- DIV = 35
- ADD = 36
- SUB = 37
- LT = 38
- GT = 39
- LE = 40
- GE = 41
- EQ = 42
- NE = 43
- BOOL_LIT = 44
- CNAME = 45
- FLOAT = 46
- NAT = 47
- METADATA = 48
-
- channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ]
-
- modeNames = [ "DEFAULT_MODE" ]
-
- literalNames = [ "<INVALID>",
- "'.'", "'@'", "'%'", "'_'", "','", "'('", "')'", "'['", "']'",
- "'if'", "'else'", "'{'", "'}'", "'let'", "'='", "';'", "';;'",
- "'fn'", "'->'", "'def'", "'extern'", "'type'", "'=>'", "'match'",
- "'match?'", "':'", "'Tensor'", "'meta'", "'v0.0.4'", "'*'",
- "'/'", "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='" ]
-
- symbolicNames = [ "<INVALID>",
- "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "QUOTED_STRING",
- "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE",
- "BOOL_LIT", "CNAME", "FLOAT", "NAT", "METADATA" ]
-
- ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6",
- "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13",
- "T__14", "T__15", "T__16", "T__17", "T__18", "T__19",
- "T__20", "T__21", "T__22", "T__23", "T__24", "T__25",
- "T__26", "T__27", "SEMVER", "COMMENT", "WS", "LINE_COMMENT",
- "ESCAPED_QUOTE", "QUOTED_STRING", "MUL", "DIV", "ADD",
- "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT",
- "CNAME", "PREFLOAT", "FLOAT", "NAT", "EXP", "LETTER",
- "DIGIT", "METADATA" ]
-
- grammarFileName = "Relay.g4"
-
- def __init__(self, input=None, output:TextIO = sys.stdout):
- super().__init__(input, output)
- self.checkVersion("4.7.2")
- self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
- self._actions = None
- self._predicates = None
-
-
+++ /dev/null
-# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
-# encoding: utf-8
-from antlr4 import *
-from io import StringIO
-from typing.io import TextIO
-import sys
-
-
-def serializedATN():
- with StringIO() as buf:
- buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3\62")
- buf.write("\u0200\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7")
- buf.write("\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r\4\16")
- buf.write("\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23\t\23")
- buf.write("\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30\4\31")
- buf.write("\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36\t\36")
- buf.write("\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\3\2\3\2\7\2I\n\2")
- buf.write("\f\2\16\2L\13\2\3\2\5\2O\n\2\3\2\5\2R\n\2\3\2\3\2\3\3")
- buf.write("\3\3\3\3\7\3Y\n\3\f\3\16\3\\\13\3\3\4\3\4\3\4\3\5\3\5")
- buf.write("\3\5\3\6\3\6\3\6\3\7\3\7\3\7\7\7j\n\7\f\7\16\7m\13\7\5")
- buf.write("\7o\n\7\3\b\3\b\3\b\3\b\7\bu\n\b\f\b\16\bx\13\b\3\b\5")
- buf.write("\b{\n\b\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3")
- buf.write("\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\6\t\u0090\n\t\r\t\16\t")
- buf.write("\u0091\3\t\3\t\3\t\3\t\3\t\3\t\7\t\u009a\n\t\f\t\16\t")
- buf.write("\u009d\13\t\5\t\u009f\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t")
- buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\5\t\u00ae\n\t\3\t\3\t\3\t\3\t")
- buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3")
- buf.write("\t\3\t\5\t\u00c3\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3")
- buf.write("\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t")
- buf.write("\3\t\7\t\u00dc\n\t\f\t\16\t\u00df\13\t\3\n\3\n\5\n\u00e3")
- buf.write("\n\n\3\n\3\n\3\n\3\n\3\n\5\n\u00ea\n\n\3\n\3\n\3\13\3")
- buf.write("\13\3\13\5\13\u00f1\n\13\3\13\3\13\3\13\3\13\3\13\5\13")
- buf.write("\u00f8\n\13\3\13\3\13\3\13\3\13\3\13\3\13\5\13\u0100\n")
- buf.write("\13\3\13\3\13\3\13\5\13\u0105\n\13\3\13\3\13\5\13\u0109")
- buf.write("\n\13\3\13\3\13\5\13\u010d\n\13\3\f\3\f\3\r\3\r\3\r\7")
- buf.write("\r\u0114\n\r\f\r\16\r\u0117\13\r\3\r\5\r\u011a\n\r\3\16")
- buf.write("\3\16\3\16\3\16\3\16\7\16\u0121\n\16\f\16\16\16\u0124")
- buf.write("\13\16\3\16\3\16\5\16\u0128\n\16\3\17\3\17\3\17\7\17\u012d")
- buf.write("\n\17\f\17\16\17\u0130\13\17\3\17\5\17\u0133\n\17\3\20")
- buf.write("\3\20\3\20\3\20\3\20\3\20\3\20\5\20\u013c\n\20\3\21\3")
- buf.write("\21\3\22\3\22\3\22\3\22\7\22\u0144\n\22\f\22\16\22\u0147")
- buf.write("\13\22\3\22\3\22\3\23\3\23\3\23\3\23\5\23\u014f\n\23\3")
- buf.write("\23\3\23\5\23\u0153\n\23\3\23\5\23\u0156\n\23\3\24\3\24")
- buf.write("\5\24\u015a\n\24\3\25\3\25\3\25\3\25\7\25\u0160\n\25\f")
- buf.write("\25\16\25\u0163\13\25\3\25\3\25\3\26\3\26\5\26\u0169\n")
- buf.write("\26\3\27\3\27\3\27\3\27\7\27\u016f\n\27\f\27\16\27\u0172")
- buf.write("\13\27\3\27\5\27\u0175\n\27\3\30\3\30\3\30\7\30\u017a")
- buf.write("\n\30\f\30\16\30\u017d\13\30\5\30\u017f\n\30\3\31\3\31")
- buf.write("\3\31\5\31\u0184\n\31\3\32\3\32\3\32\7\32\u0189\n\32\f")
- buf.write("\32\16\32\u018c\13\32\3\33\3\33\3\33\3\33\3\34\3\34\3")
- buf.write("\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34")
- buf.write("\3\34\3\34\6\34\u01a1\n\34\r\34\16\34\u01a2\3\34\3\34")
- buf.write("\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34")
- buf.write("\3\34\3\34\5\34\u01b4\n\34\3\34\3\34\3\34\3\34\7\34\u01ba")
- buf.write("\n\34\f\34\16\34\u01bd\13\34\5\34\u01bf\n\34\3\34\3\34")
- buf.write("\3\34\3\34\5\34\u01c5\n\34\3\35\3\35\3\35\3\35\7\35\u01cb")
- buf.write("\n\35\f\35\16\35\u01ce\13\35\3\35\3\35\3\36\3\36\3\36")
- buf.write("\3\36\3\36\3\36\6\36\u01d8\n\36\r\36\16\36\u01d9\3\36")
- buf.write("\3\36\3\36\5\36\u01df\n\36\3\37\3\37\3\37\3\37\3\37\3")
- buf.write("\37\3\37\3\37\3 \3 \3 \3 \3 \3 \5 \u01ef\n \3!\3!\3!\3")
- buf.write("!\3\"\3\"\3\"\5\"\u01f8\n\"\3#\3#\3#\3#\5#\u01fe\n#\3")
- buf.write("#\2\3\20$\2\4\6\b\n\f\16\20\22\24\26\30\32\34\36 \"$&")
- buf.write("(*,.\60\62\64\668:<>@BD\2\b\4\2\6\6//\3\2$%\3\2&\'\3\2")
- buf.write("(+\3\2,-\3\2\32\33\2\u0234\2F\3\2\2\2\4U\3\2\2\2\6]\3")
- buf.write("\2\2\2\b`\3\2\2\2\nc\3\2\2\2\fn\3\2\2\2\16z\3\2\2\2\20")
- buf.write("\u00c2\3\2\2\2\22\u00e0\3\2\2\2\24\u010c\3\2\2\2\26\u010e")
- buf.write("\3\2\2\2\30\u0110\3\2\2\2\32\u011b\3\2\2\2\34\u0129\3")
- buf.write("\2\2\2\36\u0134\3\2\2\2 \u013d\3\2\2\2\"\u013f\3\2\2\2")
- buf.write("$\u0155\3\2\2\2&\u0157\3\2\2\2(\u015b\3\2\2\2*\u0168\3")
- buf.write("\2\2\2,\u0174\3\2\2\2.\u017e\3\2\2\2\60\u0180\3\2\2\2")
- buf.write("\62\u0185\3\2\2\2\64\u018d\3\2\2\2\66\u01c4\3\2\2\28\u01c6")
- buf.write("\3\2\2\2:\u01de\3\2\2\2<\u01e0\3\2\2\2>\u01ee\3\2\2\2")
- buf.write("@\u01f0\3\2\2\2B\u01f7\3\2\2\2D\u01fd\3\2\2\2FN\7\37\2")
- buf.write("\2GI\5\24\13\2HG\3\2\2\2IL\3\2\2\2JH\3\2\2\2JK\3\2\2\2")
- buf.write("KO\3\2\2\2LJ\3\2\2\2MO\5\20\t\2NJ\3\2\2\2NM\3\2\2\2OQ")
- buf.write("\3\2\2\2PR\7\62\2\2QP\3\2\2\2QR\3\2\2\2RS\3\2\2\2ST\7")
- buf.write("\2\2\3T\3\3\2\2\2UZ\7/\2\2VW\7\3\2\2WY\7/\2\2XV\3\2\2")
- buf.write("\2Y\\\3\2\2\2ZX\3\2\2\2Z[\3\2\2\2[\5\3\2\2\2\\Z\3\2\2")
- buf.write("\2]^\7\4\2\2^_\7/\2\2_\7\3\2\2\2`a\7\5\2\2ab\t\2\2\2b")
- buf.write("\t\3\2\2\2cd\7\5\2\2de\7\61\2\2e\13\3\2\2\2fk\5\20\t\2")
- buf.write("gh\7\7\2\2hj\5\20\t\2ig\3\2\2\2jm\3\2\2\2ki\3\2\2\2kl")
- buf.write("\3\2\2\2lo\3\2\2\2mk\3\2\2\2nf\3\2\2\2no\3\2\2\2o\r\3")
- buf.write("\2\2\2p{\5\f\7\2qr\5\20\t\2rs\7\7\2\2su\3\2\2\2tq\3\2")
- buf.write("\2\2ux\3\2\2\2vt\3\2\2\2vw\3\2\2\2wy\3\2\2\2xv\3\2\2\2")
- buf.write("y{\5\62\32\2zp\3\2\2\2zv\3\2\2\2{\17\3\2\2\2|}\b\t\1\2")
- buf.write("}~\7\b\2\2~\177\5\20\t\2\177\u0080\7\t\2\2\u0080\u00c3")
- buf.write("\3\2\2\2\u0081\u0082\7\'\2\2\u0082\u00c3\5\20\t\26\u0083")
- buf.write("\u00c3\5\22\n\2\u0084\u0085\7\b\2\2\u0085\u00c3\7\t\2")
- buf.write("\2\u0086\u0087\7\b\2\2\u0087\u0088\5\20\t\2\u0088\u0089")
- buf.write("\7\7\2\2\u0089\u008a\7\t\2\2\u008a\u00c3\3\2\2\2\u008b")
- buf.write("\u008c\7\b\2\2\u008c\u008f\5\20\t\2\u008d\u008e\7\7\2")
- buf.write("\2\u008e\u0090\5\20\t\2\u008f\u008d\3\2\2\2\u0090\u0091")
- buf.write("\3\2\2\2\u0091\u008f\3\2\2\2\u0091\u0092\3\2\2\2\u0092")
- buf.write("\u0093\3\2\2\2\u0093\u0094\7\t\2\2\u0094\u00c3\3\2\2\2")
- buf.write("\u0095\u009e\7\n\2\2\u0096\u009b\5\20\t\2\u0097\u0098")
- buf.write("\7\7\2\2\u0098\u009a\5\20\t\2\u0099\u0097\3\2\2\2\u009a")
- buf.write("\u009d\3\2\2\2\u009b\u0099\3\2\2\2\u009b\u009c\3\2\2\2")
- buf.write("\u009c\u009f\3\2\2\2\u009d\u009b\3\2\2\2\u009e\u0096\3")
- buf.write("\2\2\2\u009e\u009f\3\2\2\2\u009f\u00a0\3\2\2\2\u00a0\u00c3")
- buf.write("\7\13\2\2\u00a1\u00a2\7\f\2\2\u00a2\u00a3\7\b\2\2\u00a3")
- buf.write("\u00a4\5\20\t\2\u00a4\u00a5\7\t\2\2\u00a5\u00a6\5@!\2")
- buf.write("\u00a6\u00a7\7\r\2\2\u00a7\u00a8\5@!\2\u00a8\u00c3\3\2")
- buf.write("\2\2\u00a9\u00aa\5 \21\2\u00aa\u00ab\5\20\t\2\u00ab\u00ad")
- buf.write("\7\16\2\2\u00ac\u00ae\5\34\17\2\u00ad\u00ac\3\2\2\2\u00ad")
- buf.write("\u00ae\3\2\2\2\u00ae\u00af\3\2\2\2\u00af\u00b0\7\17\2")
- buf.write("\2\u00b0\u00c3\3\2\2\2\u00b1\u00b2\7\20\2\2\u00b2\u00b3")
- buf.write("\5\60\31\2\u00b3\u00b4\7\21\2\2\u00b4\u00b5\5\20\t\2\u00b5")
- buf.write("\u00b6\7\22\2\2\u00b6\u00b7\5\20\t\t\u00b7\u00c3\3\2\2")
- buf.write("\2\u00b8\u00b9\5\n\6\2\u00b9\u00ba\7\21\2\2\u00ba\u00bb")
- buf.write("\5\20\t\2\u00bb\u00bc\7\22\2\2\u00bc\u00bd\5\20\t\7\u00bd")
- buf.write("\u00c3\3\2\2\2\u00be\u00c3\5D#\2\u00bf\u00c3\5B\"\2\u00c0")
- buf.write("\u00c3\5<\37\2\u00c1\u00c3\7#\2\2\u00c2|\3\2\2\2\u00c2")
- buf.write("\u0081\3\2\2\2\u00c2\u0083\3\2\2\2\u00c2\u0084\3\2\2\2")
- buf.write("\u00c2\u0086\3\2\2\2\u00c2\u008b\3\2\2\2\u00c2\u0095\3")
- buf.write("\2\2\2\u00c2\u00a1\3\2\2\2\u00c2\u00a9\3\2\2\2\u00c2\u00b1")
- buf.write("\3\2\2\2\u00c2\u00b8\3\2\2\2\u00c2\u00be\3\2\2\2\u00c2")
- buf.write("\u00bf\3\2\2\2\u00c2\u00c0\3\2\2\2\u00c2\u00c1\3\2\2\2")
- buf.write("\u00c3\u00dd\3\2\2\2\u00c4\u00c5\f\25\2\2\u00c5\u00c6")
- buf.write("\t\3\2\2\u00c6\u00dc\5\20\t\26\u00c7\u00c8\f\24\2\2\u00c8")
- buf.write("\u00c9\t\4\2\2\u00c9\u00dc\5\20\t\25\u00ca\u00cb\f\23")
- buf.write("\2\2\u00cb\u00cc\t\5\2\2\u00cc\u00dc\5\20\t\24\u00cd\u00ce")
- buf.write("\f\22\2\2\u00ce\u00cf\t\6\2\2\u00cf\u00dc\5\20\t\23\u00d0")
- buf.write("\u00d1\f\b\2\2\u00d1\u00d2\7\23\2\2\u00d2\u00dc\5\20\t")
- buf.write("\t\u00d3\u00d4\f\27\2\2\u00d4\u00d5\7\b\2\2\u00d5\u00d6")
- buf.write("\5\16\b\2\u00d6\u00d7\7\t\2\2\u00d7\u00dc\3\2\2\2\u00d8")
- buf.write("\u00d9\f\n\2\2\u00d9\u00da\7\3\2\2\u00da\u00dc\7\61\2")
- buf.write("\2\u00db\u00c4\3\2\2\2\u00db\u00c7\3\2\2\2\u00db\u00ca")
- buf.write("\3\2\2\2\u00db\u00cd\3\2\2\2\u00db\u00d0\3\2\2\2\u00db")
- buf.write("\u00d3\3\2\2\2\u00db\u00d8\3\2\2\2\u00dc\u00df\3\2\2\2")
- buf.write("\u00dd\u00db\3\2\2\2\u00dd\u00de\3\2\2\2\u00de\21\3\2")
- buf.write("\2\2\u00df\u00dd\3\2\2\2\u00e0\u00e2\7\24\2\2\u00e1\u00e3")
- buf.write("\58\35\2\u00e2\u00e1\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3")
- buf.write("\u00e4\3\2\2\2\u00e4\u00e5\7\b\2\2\u00e5\u00e6\5,\27\2")
- buf.write("\u00e6\u00e9\7\t\2\2\u00e7\u00e8\7\25\2\2\u00e8\u00ea")
- buf.write("\5\66\34\2\u00e9\u00e7\3\2\2\2\u00e9\u00ea\3\2\2\2\u00ea")
- buf.write("\u00eb\3\2\2\2\u00eb\u00ec\5@!\2\u00ec\23\3\2\2\2\u00ed")
- buf.write("\u00ee\7\26\2\2\u00ee\u00f0\5\6\4\2\u00ef\u00f1\58\35")
- buf.write("\2\u00f0\u00ef\3\2\2\2\u00f0\u00f1\3\2\2\2\u00f1\u00f2")
- buf.write("\3\2\2\2\u00f2\u00f3\7\b\2\2\u00f3\u00f4\5,\27\2\u00f4")
- buf.write("\u00f7\7\t\2\2\u00f5\u00f6\7\25\2\2\u00f6\u00f8\5\66\34")
- buf.write("\2\u00f7\u00f5\3\2\2\2\u00f7\u00f8\3\2\2\2\u00f8\u00f9")
- buf.write("\3\2\2\2\u00f9\u00fa\5@!\2\u00fa\u010d\3\2\2\2\u00fb\u00fc")
- buf.write("\7\27\2\2\u00fc\u00fd\7\30\2\2\u00fd\u00ff\5\4\3\2\u00fe")
- buf.write("\u0100\58\35\2\u00ff\u00fe\3\2\2\2\u00ff\u0100\3\2\2\2")
- buf.write("\u0100\u010d\3\2\2\2\u0101\u0102\7\30\2\2\u0102\u0104")
- buf.write("\5\4\3\2\u0103\u0105\58\35\2\u0104\u0103\3\2\2\2\u0104")
- buf.write("\u0105\3\2\2\2\u0105\u0106\3\2\2\2\u0106\u0108\7\16\2")
- buf.write("\2\u0107\u0109\5\30\r\2\u0108\u0107\3\2\2\2\u0108\u0109")
- buf.write("\3\2\2\2\u0109\u010a\3\2\2\2\u010a\u010b\7\17\2\2\u010b")
- buf.write("\u010d\3\2\2\2\u010c\u00ed\3\2\2\2\u010c\u00fb\3\2\2\2")
- buf.write("\u010c\u0101\3\2\2\2\u010d\25\3\2\2\2\u010e\u010f\7/\2")
- buf.write("\2\u010f\27\3\2\2\2\u0110\u0115\5\32\16\2\u0111\u0112")
- buf.write("\7\7\2\2\u0112\u0114\5\32\16\2\u0113\u0111\3\2\2\2\u0114")
- buf.write("\u0117\3\2\2\2\u0115\u0113\3\2\2\2\u0115\u0116\3\2\2\2")
- buf.write("\u0116\u0119\3\2\2\2\u0117\u0115\3\2\2\2\u0118\u011a\7")
- buf.write("\7\2\2\u0119\u0118\3\2\2\2\u0119\u011a\3\2\2\2\u011a\31")
- buf.write("\3\2\2\2\u011b\u0127\5\26\f\2\u011c\u011d\7\b\2\2\u011d")
- buf.write("\u0122\5\66\34\2\u011e\u011f\7\7\2\2\u011f\u0121\5\66")
- buf.write("\34\2\u0120\u011e\3\2\2\2\u0121\u0124\3\2\2\2\u0122\u0120")
- buf.write("\3\2\2\2\u0122\u0123\3\2\2\2\u0123\u0125\3\2\2\2\u0124")
- buf.write("\u0122\3\2\2\2\u0125\u0126\7\t\2\2\u0126\u0128\3\2\2\2")
- buf.write("\u0127\u011c\3\2\2\2\u0127\u0128\3\2\2\2\u0128\33\3\2")
- buf.write("\2\2\u0129\u012e\5\36\20\2\u012a\u012b\7\7\2\2\u012b\u012d")
- buf.write("\5\36\20\2\u012c\u012a\3\2\2\2\u012d\u0130\3\2\2\2\u012e")
- buf.write("\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0132\3\2\2\2")
- buf.write("\u0130\u012e\3\2\2\2\u0131\u0133\7\7\2\2\u0132\u0131\3")
- buf.write("\2\2\2\u0132\u0133\3\2\2\2\u0133\35\3\2\2\2\u0134\u0135")
- buf.write("\5$\23\2\u0135\u013b\7\31\2\2\u0136\u0137\7\16\2\2\u0137")
- buf.write("\u0138\5\20\t\2\u0138\u0139\7\17\2\2\u0139\u013c\3\2\2")
- buf.write("\2\u013a\u013c\5\20\t\2\u013b\u0136\3\2\2\2\u013b\u013a")
- buf.write("\3\2\2\2\u013c\37\3\2\2\2\u013d\u013e\t\7\2\2\u013e!\3")
- buf.write("\2\2\2\u013f\u0140\7\b\2\2\u0140\u0145\5$\23\2\u0141\u0142")
- buf.write("\7\7\2\2\u0142\u0144\5$\23\2\u0143\u0141\3\2\2\2\u0144")
- buf.write("\u0147\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2")
- buf.write("\u0146\u0148\3\2\2\2\u0147\u0145\3\2\2\2\u0148\u0149\7")
- buf.write("\t\2\2\u0149#\3\2\2\2\u014a\u0156\7\6\2\2\u014b\u014e")
- buf.write("\5\b\5\2\u014c\u014d\7\34\2\2\u014d\u014f\5\66\34\2\u014e")
- buf.write("\u014c\3\2\2\2\u014e\u014f\3\2\2\2\u014f\u0156\3\2\2\2")
- buf.write("\u0150\u0152\5\26\f\2\u0151\u0153\5\"\22\2\u0152\u0151")
- buf.write("\3\2\2\2\u0152\u0153\3\2\2\2\u0153\u0156\3\2\2\2\u0154")
- buf.write("\u0156\5\"\22\2\u0155\u014a\3\2\2\2\u0155\u014b\3\2\2")
- buf.write("\2\u0155\u0150\3\2\2\2\u0155\u0154\3\2\2\2\u0156%\3\2")
- buf.write("\2\2\u0157\u0159\5\26\f\2\u0158\u015a\5(\25\2\u0159\u0158")
- buf.write("\3\2\2\2\u0159\u015a\3\2\2\2\u015a\'\3\2\2\2\u015b\u015c")
- buf.write("\7\b\2\2\u015c\u0161\5*\26\2\u015d\u015e\7\7\2\2\u015e")
- buf.write("\u0160\5*\26\2\u015f\u015d\3\2\2\2\u0160\u0163\3\2\2\2")
- buf.write("\u0161\u015f\3\2\2\2\u0161\u0162\3\2\2\2\u0162\u0164\3")
- buf.write("\2\2\2\u0163\u0161\3\2\2\2\u0164\u0165\7\t\2\2\u0165)")
- buf.write("\3\2\2\2\u0166\u0169\5\b\5\2\u0167\u0169\5\26\f\2\u0168")
- buf.write("\u0166\3\2\2\2\u0168\u0167\3\2\2\2\u0169+\3\2\2\2\u016a")
- buf.write("\u0175\5.\30\2\u016b\u016c\5\60\31\2\u016c\u016d\7\7\2")
- buf.write("\2\u016d\u016f\3\2\2\2\u016e\u016b\3\2\2\2\u016f\u0172")
- buf.write("\3\2\2\2\u0170\u016e\3\2\2\2\u0170\u0171\3\2\2\2\u0171")
- buf.write("\u0173\3\2\2\2\u0172\u0170\3\2\2\2\u0173\u0175\5\62\32")
- buf.write("\2\u0174\u016a\3\2\2\2\u0174\u0170\3\2\2\2\u0175-\3\2")
- buf.write("\2\2\u0176\u017b\5\60\31\2\u0177\u0178\7\7\2\2\u0178\u017a")
- buf.write("\5\60\31\2\u0179\u0177\3\2\2\2\u017a\u017d\3\2\2\2\u017b")
- buf.write("\u0179\3\2\2\2\u017b\u017c\3\2\2\2\u017c\u017f\3\2\2\2")
- buf.write("\u017d\u017b\3\2\2\2\u017e\u0176\3\2\2\2\u017e\u017f\3")
- buf.write("\2\2\2\u017f/\3\2\2\2\u0180\u0183\5\b\5\2\u0181\u0182")
- buf.write("\7\34\2\2\u0182\u0184\5\66\34\2\u0183\u0181\3\2\2\2\u0183")
- buf.write("\u0184\3\2\2\2\u0184\61\3\2\2\2\u0185\u018a\5\64\33\2")
- buf.write("\u0186\u0187\7\7\2\2\u0187\u0189\5\64\33\2\u0188\u0186")
- buf.write("\3\2\2\2\u0189\u018c\3\2\2\2\u018a\u0188\3\2\2\2\u018a")
- buf.write("\u018b\3\2\2\2\u018b\63\3\2\2\2\u018c\u018a\3\2\2\2\u018d")
- buf.write("\u018e\7/\2\2\u018e\u018f\7\21\2\2\u018f\u0190\5\20\t")
- buf.write("\2\u0190\65\3\2\2\2\u0191\u0192\7\b\2\2\u0192\u01c5\7")
- buf.write("\t\2\2\u0193\u0194\7\b\2\2\u0194\u0195\5\66\34\2\u0195")
- buf.write("\u0196\7\t\2\2\u0196\u01c5\3\2\2\2\u0197\u0198\7\b\2\2")
- buf.write("\u0198\u0199\5\66\34\2\u0199\u019a\7\7\2\2\u019a\u019b")
- buf.write("\7\t\2\2\u019b\u01c5\3\2\2\2\u019c\u019d\7\b\2\2\u019d")
- buf.write("\u01a0\5\66\34\2\u019e\u019f\7\7\2\2\u019f\u01a1\5\66")
- buf.write("\34\2\u01a0\u019e\3\2\2\2\u01a1\u01a2\3\2\2\2\u01a2\u01a0")
- buf.write("\3\2\2\2\u01a2\u01a3\3\2\2\2\u01a3\u01a4\3\2\2\2\u01a4")
- buf.write("\u01a5\7\t\2\2\u01a5\u01c5\3\2\2\2\u01a6\u01a7\5\4\3\2")
- buf.write("\u01a7\u01a8\58\35\2\u01a8\u01c5\3\2\2\2\u01a9\u01c5\5")
- buf.write("\4\3\2\u01aa\u01ab\7\35\2\2\u01ab\u01ac\7\n\2\2\u01ac")
- buf.write("\u01ad\5:\36\2\u01ad\u01ae\7\7\2\2\u01ae\u01af\5\66\34")
- buf.write("\2\u01af\u01b0\7\13\2\2\u01b0\u01c5\3\2\2\2\u01b1\u01b3")
- buf.write("\7\24\2\2\u01b2\u01b4\58\35\2\u01b3\u01b2\3\2\2\2\u01b3")
- buf.write("\u01b4\3\2\2\2\u01b4\u01b5\3\2\2\2\u01b5\u01be\7\b\2\2")
- buf.write("\u01b6\u01bb\5\66\34\2\u01b7\u01b8\7\7\2\2\u01b8\u01ba")
- buf.write("\5\66\34\2\u01b9\u01b7\3\2\2\2\u01ba\u01bd\3\2\2\2\u01bb")
- buf.write("\u01b9\3\2\2\2\u01bb\u01bc\3\2\2\2\u01bc\u01bf\3\2\2\2")
- buf.write("\u01bd\u01bb\3\2\2\2\u01be\u01b6\3\2\2\2\u01be\u01bf\3")
- buf.write("\2\2\2\u01bf\u01c0\3\2\2\2\u01c0\u01c1\7\t\2\2\u01c1\u01c2")
- buf.write("\7\25\2\2\u01c2\u01c5\5\66\34\2\u01c3\u01c5\7\6\2\2\u01c4")
- buf.write("\u0191\3\2\2\2\u01c4\u0193\3\2\2\2\u01c4\u0197\3\2\2\2")
- buf.write("\u01c4\u019c\3\2\2\2\u01c4\u01a6\3\2\2\2\u01c4\u01a9\3")
- buf.write("\2\2\2\u01c4\u01aa\3\2\2\2\u01c4\u01b1\3\2\2\2\u01c4\u01c3")
- buf.write("\3\2\2\2\u01c5\67\3\2\2\2\u01c6\u01c7\7\n\2\2\u01c7\u01cc")
- buf.write("\5\66\34\2\u01c8\u01c9\7\7\2\2\u01c9\u01cb\5\66\34\2\u01ca")
- buf.write("\u01c8\3\2\2\2\u01cb\u01ce\3\2\2\2\u01cc\u01ca\3\2\2\2")
- buf.write("\u01cc\u01cd\3\2\2\2\u01cd\u01cf\3\2\2\2\u01ce\u01cc\3")
- buf.write("\2\2\2\u01cf\u01d0\7\13\2\2\u01d09\3\2\2\2\u01d1\u01d2")
- buf.write("\7\b\2\2\u01d2\u01df\7\t\2\2\u01d3\u01d4\7\b\2\2\u01d4")
- buf.write("\u01d7\5> \2\u01d5\u01d6\7\7\2\2\u01d6\u01d8\5> \2\u01d7")
- buf.write("\u01d5\3\2\2\2\u01d8\u01d9\3\2\2\2\u01d9\u01d7\3\2\2\2")
- buf.write("\u01d9\u01da\3\2\2\2\u01da\u01db\3\2\2\2\u01db\u01dc\7")
- buf.write("\t\2\2\u01dc\u01df\3\2\2\2\u01dd\u01df\5> \2\u01de\u01d1")
- buf.write("\3\2\2\2\u01de\u01d3\3\2\2\2\u01de\u01dd\3\2\2\2\u01df")
- buf.write(";\3\2\2\2\u01e0\u01e1\7\36\2\2\u01e1\u01e2\7\n\2\2\u01e2")
- buf.write("\u01e3\7/\2\2\u01e3\u01e4\7\13\2\2\u01e4\u01e5\7\n\2\2")
- buf.write("\u01e5\u01e6\7\61\2\2\u01e6\u01e7\7\13\2\2\u01e7=\3\2")
- buf.write("\2\2\u01e8\u01ef\5<\37\2\u01e9\u01ea\7\b\2\2\u01ea\u01eb")
- buf.write("\5> \2\u01eb\u01ec\7\t\2\2\u01ec\u01ef\3\2\2\2\u01ed\u01ef")
- buf.write("\7\61\2\2\u01ee\u01e8\3\2\2\2\u01ee\u01e9\3\2\2\2\u01ee")
- buf.write("\u01ed\3\2\2\2\u01ef?\3\2\2\2\u01f0\u01f1\7\16\2\2\u01f1")
- buf.write("\u01f2\5\20\t\2\u01f2\u01f3\7\17\2\2\u01f3A\3\2\2\2\u01f4")
- buf.write("\u01f8\7\60\2\2\u01f5\u01f8\7\61\2\2\u01f6\u01f8\7.\2")
- buf.write("\2\u01f7\u01f4\3\2\2\2\u01f7\u01f5\3\2\2\2\u01f7\u01f6")
- buf.write("\3\2\2\2\u01f8C\3\2\2\2\u01f9\u01fe\5\4\3\2\u01fa\u01fe")
- buf.write("\5\6\4\2\u01fb\u01fe\5\b\5\2\u01fc\u01fe\5\n\6\2\u01fd")
- buf.write("\u01f9\3\2\2\2\u01fd\u01fa\3\2\2\2\u01fd\u01fb\3\2\2\2")
- buf.write("\u01fd\u01fc\3\2\2\2\u01feE\3\2\2\28JNQZknvz\u0091\u009b")
- buf.write("\u009e\u00ad\u00c2\u00db\u00dd\u00e2\u00e9\u00f0\u00f7")
- buf.write("\u00ff\u0104\u0108\u010c\u0115\u0119\u0122\u0127\u012e")
- buf.write("\u0132\u013b\u0145\u014e\u0152\u0155\u0159\u0161\u0168")
- buf.write("\u0170\u0174\u017b\u017e\u0183\u018a\u01a2\u01b3\u01bb")
- buf.write("\u01be\u01c4\u01cc\u01d9\u01de\u01ee\u01f7\u01fd")
- return buf.getvalue()
-
-
-class RelayParser ( Parser ):
-
- grammarFileName = "Relay.g4"
-
- atn = ATNDeserializer().deserialize(serializedATN())
-
- decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ]
-
- sharedContextCache = PredictionContextCache()
-
- literalNames = [ "<INVALID>", "'.'", "'@'", "'%'", "'_'", "','", "'('",
- "')'", "'['", "']'", "'if'", "'else'", "'{'", "'}'",
- "'let'", "'='", "';'", "';;'", "'fn'", "'->'", "'def'",
- "'extern'", "'type'", "'=>'", "'match'", "'match?'",
- "':'", "'Tensor'", "'meta'", "'v0.0.4'", "<INVALID>",
- "<INVALID>", "<INVALID>", "<INVALID>", "'*'", "'/'",
- "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='",
- "'!='" ]
-
- symbolicNames = [ "<INVALID>", "<INVALID>", "<INVALID>", "<INVALID>",
- "<INVALID>", "<INVALID>", "<INVALID>", "<INVALID>",
- "<INVALID>", "<INVALID>", "<INVALID>", "<INVALID>",
- "<INVALID>", "<INVALID>", "<INVALID>", "<INVALID>",
- "<INVALID>", "<INVALID>", "<INVALID>", "<INVALID>",
- "<INVALID>", "<INVALID>", "<INVALID>", "<INVALID>",
- "<INVALID>", "<INVALID>", "<INVALID>", "<INVALID>",
- "<INVALID>", "SEMVER", "COMMENT", "WS", "LINE_COMMENT",
- "QUOTED_STRING", "MUL", "DIV", "ADD", "SUB", "LT",
- "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT", "CNAME",
- "FLOAT", "NAT", "METADATA" ]
-
- RULE_prog = 0
- RULE_generalIdent = 1
- RULE_globalVar = 2
- RULE_localVar = 3
- RULE_graphVar = 4
- RULE_exprList = 5
- RULE_callList = 6
- RULE_expr = 7
- RULE_func = 8
- RULE_defn = 9
- RULE_constructorName = 10
- RULE_adtConsDefnList = 11
- RULE_adtConsDefn = 12
- RULE_matchClauseList = 13
- RULE_matchClause = 14
- RULE_matchType = 15
- RULE_patternList = 16
- RULE_pattern = 17
- RULE_adtCons = 18
- RULE_adtConsParamList = 19
- RULE_adtConsParam = 20
- RULE_argList = 21
- RULE_varList = 22
- RULE_var = 23
- RULE_attrSeq = 24
- RULE_attr = 25
- RULE_typeExpr = 26
- RULE_typeParamList = 27
- RULE_shapeList = 28
- RULE_meta = 29
- RULE_shape = 30
- RULE_body = 31
- RULE_scalar = 32
- RULE_ident = 33
-
- ruleNames = [ "prog", "generalIdent", "globalVar", "localVar", "graphVar",
- "exprList", "callList", "expr", "func", "defn", "constructorName",
- "adtConsDefnList", "adtConsDefn", "matchClauseList",
- "matchClause", "matchType", "patternList", "pattern",
- "adtCons", "adtConsParamList", "adtConsParam", "argList",
- "varList", "var", "attrSeq", "attr", "typeExpr", "typeParamList",
- "shapeList", "meta", "shape", "body", "scalar", "ident" ]
-
- EOF = Token.EOF
- T__0=1
- T__1=2
- T__2=3
- T__3=4
- T__4=5
- T__5=6
- T__6=7
- T__7=8
- T__8=9
- T__9=10
- T__10=11
- T__11=12
- T__12=13
- T__13=14
- T__14=15
- T__15=16
- T__16=17
- T__17=18
- T__18=19
- T__19=20
- T__20=21
- T__21=22
- T__22=23
- T__23=24
- T__24=25
- T__25=26
- T__26=27
- T__27=28
- SEMVER=29
- COMMENT=30
- WS=31
- LINE_COMMENT=32
- QUOTED_STRING=33
- MUL=34
- DIV=35
- ADD=36
- SUB=37
- LT=38
- GT=39
- LE=40
- GE=41
- EQ=42
- NE=43
- BOOL_LIT=44
- CNAME=45
- FLOAT=46
- NAT=47
- METADATA=48
-
- def __init__(self, input:TokenStream, output:TextIO = sys.stdout):
- super().__init__(input, output)
- self.checkVersion("4.7.2")
- self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache)
- self._predicates = None
-
-
-
-
- class ProgContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def SEMVER(self):
- return self.getToken(RelayParser.SEMVER, 0)
-
- def EOF(self):
- return self.getToken(RelayParser.EOF, 0)
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
-
- def METADATA(self):
- return self.getToken(RelayParser.METADATA, 0)
-
- def defn(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.DefnContext)
- else:
- return self.getTypedRuleContext(RelayParser.DefnContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_prog
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitProg" ):
- return visitor.visitProg(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def prog(self):
-
- localctx = RelayParser.ProgContext(self, self._ctx, self.state)
- self.enterRule(localctx, 0, self.RULE_prog)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 68
- self.match(RelayParser.SEMVER)
- self.state = 76
- self._errHandler.sync(self)
- token = self._input.LA(1)
- if token in [RelayParser.EOF, RelayParser.T__19, RelayParser.T__20, RelayParser.T__21, RelayParser.METADATA]:
- self.state = 72
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__19) | (1 << RelayParser.T__20) | (1 << RelayParser.T__21))) != 0):
- self.state = 69
- self.defn()
- self.state = 74
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
- pass
- elif token in [RelayParser.T__1, RelayParser.T__2, RelayParser.T__5, RelayParser.T__7, RelayParser.T__9, RelayParser.T__13, RelayParser.T__17, RelayParser.T__23, RelayParser.T__24, RelayParser.T__27, RelayParser.QUOTED_STRING, RelayParser.SUB, RelayParser.BOOL_LIT, RelayParser.CNAME, RelayParser.FLOAT, RelayParser.NAT]:
- self.state = 75
- self.expr(0)
- pass
- else:
- raise NoViableAltException(self)
-
- self.state = 79
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.METADATA:
- self.state = 78
- self.match(RelayParser.METADATA)
-
-
- self.state = 81
- self.match(RelayParser.EOF)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class GeneralIdentContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def CNAME(self, i:int=None):
- if i is None:
- return self.getTokens(RelayParser.CNAME)
- else:
- return self.getToken(RelayParser.CNAME, i)
-
- def getRuleIndex(self):
- return RelayParser.RULE_generalIdent
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitGeneralIdent" ):
- return visitor.visitGeneralIdent(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def generalIdent(self):
-
- localctx = RelayParser.GeneralIdentContext(self, self._ctx, self.state)
- self.enterRule(localctx, 2, self.RULE_generalIdent)
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 83
- self.match(RelayParser.CNAME)
- self.state = 88
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,3,self._ctx)
- while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER:
- if _alt==1:
- self.state = 84
- self.match(RelayParser.T__0)
- self.state = 85
- self.match(RelayParser.CNAME)
- self.state = 90
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,3,self._ctx)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class GlobalVarContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def CNAME(self):
- return self.getToken(RelayParser.CNAME, 0)
-
- def getRuleIndex(self):
- return RelayParser.RULE_globalVar
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitGlobalVar" ):
- return visitor.visitGlobalVar(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def globalVar(self):
-
- localctx = RelayParser.GlobalVarContext(self, self._ctx, self.state)
- self.enterRule(localctx, 4, self.RULE_globalVar)
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 91
- self.match(RelayParser.T__1)
- self.state = 92
- self.match(RelayParser.CNAME)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class LocalVarContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def CNAME(self):
- return self.getToken(RelayParser.CNAME, 0)
-
- def getRuleIndex(self):
- return RelayParser.RULE_localVar
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitLocalVar" ):
- return visitor.visitLocalVar(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def localVar(self):
-
- localctx = RelayParser.LocalVarContext(self, self._ctx, self.state)
- self.enterRule(localctx, 6, self.RULE_localVar)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 94
- self.match(RelayParser.T__2)
- self.state = 95
- _la = self._input.LA(1)
- if not(_la==RelayParser.T__3 or _la==RelayParser.CNAME):
- self._errHandler.recoverInline(self)
- else:
- self._errHandler.reportMatch(self)
- self.consume()
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class GraphVarContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def NAT(self):
- return self.getToken(RelayParser.NAT, 0)
-
- def getRuleIndex(self):
- return RelayParser.RULE_graphVar
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitGraphVar" ):
- return visitor.visitGraphVar(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def graphVar(self):
-
- localctx = RelayParser.GraphVarContext(self, self._ctx, self.state)
- self.enterRule(localctx, 8, self.RULE_graphVar)
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 97
- self.match(RelayParser.T__2)
- self.state = 98
- self.match(RelayParser.NAT)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class ExprListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def expr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.ExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.ExprContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_exprList
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitExprList" ):
- return visitor.visitExprList(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def exprList(self):
-
- localctx = RelayParser.ExprListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 10, self.RULE_exprList)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 108
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__2) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__9) | (1 << RelayParser.T__13) | (1 << RelayParser.T__17) | (1 << RelayParser.T__23) | (1 << RelayParser.T__24) | (1 << RelayParser.T__27) | (1 << RelayParser.QUOTED_STRING) | (1 << RelayParser.SUB) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.CNAME) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT))) != 0):
- self.state = 100
- self.expr(0)
- self.state = 105
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 101
- self.match(RelayParser.T__4)
- self.state = 102
- self.expr(0)
- self.state = 107
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class CallListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_callList
-
-
- def copyFrom(self, ctx:ParserRuleContext):
- super().copyFrom(ctx)
-
-
-
- class CallWithAttrContext(CallListContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.CallListContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def attrSeq(self):
- return self.getTypedRuleContext(RelayParser.AttrSeqContext,0)
-
- def expr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.ExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.ExprContext,i)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitCallWithAttr" ):
- return visitor.visitCallWithAttr(self)
- else:
- return visitor.visitChildren(self)
-
-
- class CallNoAttrContext(CallListContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.CallListContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def exprList(self):
- return self.getTypedRuleContext(RelayParser.ExprListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitCallNoAttr" ):
- return visitor.visitCallNoAttr(self)
- else:
- return visitor.visitChildren(self)
-
-
-
- def callList(self):
-
- localctx = RelayParser.CallListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 12, self.RULE_callList)
- try:
- self.state = 120
- self._errHandler.sync(self)
- la_ = self._interp.adaptivePredict(self._input,7,self._ctx)
- if la_ == 1:
- localctx = RelayParser.CallNoAttrContext(self, localctx)
- self.enterOuterAlt(localctx, 1)
- self.state = 110
- self.exprList()
- pass
-
- elif la_ == 2:
- localctx = RelayParser.CallWithAttrContext(self, localctx)
- self.enterOuterAlt(localctx, 2)
- self.state = 116
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,6,self._ctx)
- while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER:
- if _alt==1:
- self.state = 111
- self.expr(0)
- self.state = 112
- self.match(RelayParser.T__4)
- self.state = 118
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,6,self._ctx)
-
- self.state = 119
- self.attrSeq()
- pass
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class ExprContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_expr
-
-
- def copyFrom(self, ctx:ParserRuleContext):
- super().copyFrom(ctx)
-
-
- class FuncExprContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def func(self):
- return self.getTypedRuleContext(RelayParser.FuncContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitFuncExpr" ):
- return visitor.visitFuncExpr(self)
- else:
- return visitor.visitChildren(self)
-
-
- class MetaExprContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def meta(self):
- return self.getTypedRuleContext(RelayParser.MetaContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitMetaExpr" ):
- return visitor.visitMetaExpr(self)
- else:
- return visitor.visitChildren(self)
-
-
- class MatchContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def matchType(self):
- return self.getTypedRuleContext(RelayParser.MatchTypeContext,0)
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
- def matchClauseList(self):
- return self.getTypedRuleContext(RelayParser.MatchClauseListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitMatch" ):
- return visitor.visitMatch(self)
- else:
- return visitor.visitChildren(self)
-
-
- class TensorContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def expr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.ExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.ExprContext,i)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTensor" ):
- return visitor.visitTensor(self)
- else:
- return visitor.visitChildren(self)
-
-
- class GraphContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def graphVar(self):
- return self.getTypedRuleContext(RelayParser.GraphVarContext,0)
-
- def expr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.ExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.ExprContext,i)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitGraph" ):
- return visitor.visitGraph(self)
- else:
- return visitor.visitChildren(self)
-
-
- class IdentExprContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def ident(self):
- return self.getTypedRuleContext(RelayParser.IdentContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitIdentExpr" ):
- return visitor.visitIdentExpr(self)
- else:
- return visitor.visitChildren(self)
-
-
- class StringExprContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def QUOTED_STRING(self):
- return self.getToken(RelayParser.QUOTED_STRING, 0)
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitStringExpr" ):
- return visitor.visitStringExpr(self)
- else:
- return visitor.visitChildren(self)
-
-
- class CallContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
- def callList(self):
- return self.getTypedRuleContext(RelayParser.CallListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitCall" ):
- return visitor.visitCall(self)
- else:
- return visitor.visitChildren(self)
-
-
- class NegContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def SUB(self):
- return self.getToken(RelayParser.SUB, 0)
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitNeg" ):
- return visitor.visitNeg(self)
- else:
- return visitor.visitChildren(self)
-
-
- class TupleContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def expr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.ExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.ExprContext,i)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTuple" ):
- return visitor.visitTuple(self)
- else:
- return visitor.visitChildren(self)
-
-
- class ParenContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitParen" ):
- return visitor.visitParen(self)
- else:
- return visitor.visitChildren(self)
-
-
- class ScalarExprContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def scalar(self):
- return self.getTypedRuleContext(RelayParser.ScalarContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitScalarExpr" ):
- return visitor.visitScalarExpr(self)
- else:
- return visitor.visitChildren(self)
-
-
- class LetContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def var(self):
- return self.getTypedRuleContext(RelayParser.VarContext,0)
-
- def expr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.ExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.ExprContext,i)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitLet" ):
- return visitor.visitLet(self)
- else:
- return visitor.visitChildren(self)
-
-
- class ProjectionContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
- def NAT(self):
- return self.getToken(RelayParser.NAT, 0)
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitProjection" ):
- return visitor.visitProjection(self)
- else:
- return visitor.visitChildren(self)
-
-
- class IfElseContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
- def body(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.BodyContext)
- else:
- return self.getTypedRuleContext(RelayParser.BodyContext,i)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitIfElse" ):
- return visitor.visitIfElse(self)
- else:
- return visitor.visitChildren(self)
-
-
- class BinOpContext(ExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext
- super().__init__(parser)
- self.op = None # Token
- self.copyFrom(ctx)
-
- def expr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.ExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.ExprContext,i)
-
- def MUL(self):
- return self.getToken(RelayParser.MUL, 0)
- def DIV(self):
- return self.getToken(RelayParser.DIV, 0)
- def ADD(self):
- return self.getToken(RelayParser.ADD, 0)
- def SUB(self):
- return self.getToken(RelayParser.SUB, 0)
- def LT(self):
- return self.getToken(RelayParser.LT, 0)
- def GT(self):
- return self.getToken(RelayParser.GT, 0)
- def LE(self):
- return self.getToken(RelayParser.LE, 0)
- def GE(self):
- return self.getToken(RelayParser.GE, 0)
- def EQ(self):
- return self.getToken(RelayParser.EQ, 0)
- def NE(self):
- return self.getToken(RelayParser.NE, 0)
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitBinOp" ):
- return visitor.visitBinOp(self)
- else:
- return visitor.visitChildren(self)
-
-
-
- def expr(self, _p:int=0):
- _parentctx = self._ctx
- _parentState = self.state
- localctx = RelayParser.ExprContext(self, self._ctx, _parentState)
- _prevctx = localctx
- _startState = 14
- self.enterRecursionRule(localctx, 14, self.RULE_expr, _p)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 192
- self._errHandler.sync(self)
- la_ = self._interp.adaptivePredict(self._input,12,self._ctx)
- if la_ == 1:
- localctx = RelayParser.ParenContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
-
- self.state = 123
- self.match(RelayParser.T__5)
- self.state = 124
- self.expr(0)
- self.state = 125
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 2:
- localctx = RelayParser.NegContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 127
- self.match(RelayParser.SUB)
- self.state = 128
- self.expr(20)
- pass
-
- elif la_ == 3:
- localctx = RelayParser.FuncExprContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 129
- self.func()
- pass
-
- elif la_ == 4:
- localctx = RelayParser.TupleContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 130
- self.match(RelayParser.T__5)
- self.state = 131
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 5:
- localctx = RelayParser.TupleContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 132
- self.match(RelayParser.T__5)
- self.state = 133
- self.expr(0)
- self.state = 134
- self.match(RelayParser.T__4)
- self.state = 135
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 6:
- localctx = RelayParser.TupleContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 137
- self.match(RelayParser.T__5)
- self.state = 138
- self.expr(0)
- self.state = 141
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while True:
- self.state = 139
- self.match(RelayParser.T__4)
- self.state = 140
- self.expr(0)
- self.state = 143
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if not (_la==RelayParser.T__4):
- break
-
- self.state = 145
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 7:
- localctx = RelayParser.TensorContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 147
- self.match(RelayParser.T__7)
- self.state = 156
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__2) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__9) | (1 << RelayParser.T__13) | (1 << RelayParser.T__17) | (1 << RelayParser.T__23) | (1 << RelayParser.T__24) | (1 << RelayParser.T__27) | (1 << RelayParser.QUOTED_STRING) | (1 << RelayParser.SUB) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.CNAME) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT))) != 0):
- self.state = 148
- self.expr(0)
- self.state = 153
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 149
- self.match(RelayParser.T__4)
- self.state = 150
- self.expr(0)
- self.state = 155
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
-
-
- self.state = 158
- self.match(RelayParser.T__8)
- pass
-
- elif la_ == 8:
- localctx = RelayParser.IfElseContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 159
- self.match(RelayParser.T__9)
- self.state = 160
- self.match(RelayParser.T__5)
- self.state = 161
- self.expr(0)
- self.state = 162
- self.match(RelayParser.T__6)
- self.state = 163
- self.body()
- self.state = 164
- self.match(RelayParser.T__10)
- self.state = 165
- self.body()
- pass
-
- elif la_ == 9:
- localctx = RelayParser.MatchContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 167
- self.matchType()
- self.state = 168
- self.expr(0)
- self.state = 169
- self.match(RelayParser.T__11)
- self.state = 171
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__2) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.CNAME))) != 0):
- self.state = 170
- self.matchClauseList()
-
-
- self.state = 173
- self.match(RelayParser.T__12)
- pass
-
- elif la_ == 10:
- localctx = RelayParser.LetContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 175
- self.match(RelayParser.T__13)
- self.state = 176
- self.var()
- self.state = 177
- self.match(RelayParser.T__14)
- self.state = 178
- self.expr(0)
- self.state = 179
- self.match(RelayParser.T__15)
- self.state = 180
- self.expr(7)
- pass
-
- elif la_ == 11:
- localctx = RelayParser.GraphContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 182
- self.graphVar()
- self.state = 183
- self.match(RelayParser.T__14)
- self.state = 184
- self.expr(0)
- self.state = 185
- self.match(RelayParser.T__15)
- self.state = 186
- self.expr(5)
- pass
-
- elif la_ == 12:
- localctx = RelayParser.IdentExprContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 188
- self.ident()
- pass
-
- elif la_ == 13:
- localctx = RelayParser.ScalarExprContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 189
- self.scalar()
- pass
-
- elif la_ == 14:
- localctx = RelayParser.MetaExprContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 190
- self.meta()
- pass
-
- elif la_ == 15:
- localctx = RelayParser.StringExprContext(self, localctx)
- self._ctx = localctx
- _prevctx = localctx
- self.state = 191
- self.match(RelayParser.QUOTED_STRING)
- pass
-
-
- self._ctx.stop = self._input.LT(-1)
- self.state = 219
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,14,self._ctx)
- while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER:
- if _alt==1:
- if self._parseListeners is not None:
- self.triggerExitRuleEvent()
- _prevctx = localctx
- self.state = 217
- self._errHandler.sync(self)
- la_ = self._interp.adaptivePredict(self._input,13,self._ctx)
- if la_ == 1:
- localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState))
- self.pushNewRecursionContext(localctx, _startState, self.RULE_expr)
- self.state = 194
- if not self.precpred(self._ctx, 19):
- from antlr4.error.Errors import FailedPredicateException
- raise FailedPredicateException(self, "self.precpred(self._ctx, 19)")
- self.state = 195
- localctx.op = self._input.LT(1)
- _la = self._input.LA(1)
- if not(_la==RelayParser.MUL or _la==RelayParser.DIV):
- localctx.op = self._errHandler.recoverInline(self)
- else:
- self._errHandler.reportMatch(self)
- self.consume()
- self.state = 196
- self.expr(20)
- pass
-
- elif la_ == 2:
- localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState))
- self.pushNewRecursionContext(localctx, _startState, self.RULE_expr)
- self.state = 197
- if not self.precpred(self._ctx, 18):
- from antlr4.error.Errors import FailedPredicateException
- raise FailedPredicateException(self, "self.precpred(self._ctx, 18)")
- self.state = 198
- localctx.op = self._input.LT(1)
- _la = self._input.LA(1)
- if not(_la==RelayParser.ADD or _la==RelayParser.SUB):
- localctx.op = self._errHandler.recoverInline(self)
- else:
- self._errHandler.reportMatch(self)
- self.consume()
- self.state = 199
- self.expr(19)
- pass
-
- elif la_ == 3:
- localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState))
- self.pushNewRecursionContext(localctx, _startState, self.RULE_expr)
- self.state = 200
- if not self.precpred(self._ctx, 17):
- from antlr4.error.Errors import FailedPredicateException
- raise FailedPredicateException(self, "self.precpred(self._ctx, 17)")
- self.state = 201
- localctx.op = self._input.LT(1)
- _la = self._input.LA(1)
- if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)):
- localctx.op = self._errHandler.recoverInline(self)
- else:
- self._errHandler.reportMatch(self)
- self.consume()
- self.state = 202
- self.expr(18)
- pass
-
- elif la_ == 4:
- localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState))
- self.pushNewRecursionContext(localctx, _startState, self.RULE_expr)
- self.state = 203
- if not self.precpred(self._ctx, 16):
- from antlr4.error.Errors import FailedPredicateException
- raise FailedPredicateException(self, "self.precpred(self._ctx, 16)")
- self.state = 204
- localctx.op = self._input.LT(1)
- _la = self._input.LA(1)
- if not(_la==RelayParser.EQ or _la==RelayParser.NE):
- localctx.op = self._errHandler.recoverInline(self)
- else:
- self._errHandler.reportMatch(self)
- self.consume()
- self.state = 205
- self.expr(17)
- pass
-
- elif la_ == 5:
- localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState))
- self.pushNewRecursionContext(localctx, _startState, self.RULE_expr)
- self.state = 206
- if not self.precpred(self._ctx, 6):
- from antlr4.error.Errors import FailedPredicateException
- raise FailedPredicateException(self, "self.precpred(self._ctx, 6)")
- self.state = 207
- self.match(RelayParser.T__16)
- self.state = 208
- self.expr(7)
- pass
-
- elif la_ == 6:
- localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState))
- self.pushNewRecursionContext(localctx, _startState, self.RULE_expr)
- self.state = 209
- if not self.precpred(self._ctx, 21):
- from antlr4.error.Errors import FailedPredicateException
- raise FailedPredicateException(self, "self.precpred(self._ctx, 21)")
- self.state = 210
- self.match(RelayParser.T__5)
- self.state = 211
- self.callList()
- self.state = 212
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 7:
- localctx = RelayParser.ProjectionContext(self, RelayParser.ExprContext(self, _parentctx, _parentState))
- self.pushNewRecursionContext(localctx, _startState, self.RULE_expr)
- self.state = 214
- if not self.precpred(self._ctx, 8):
- from antlr4.error.Errors import FailedPredicateException
- raise FailedPredicateException(self, "self.precpred(self._ctx, 8)")
- self.state = 215
- self.match(RelayParser.T__0)
- self.state = 216
- self.match(RelayParser.NAT)
- pass
-
-
- self.state = 221
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,14,self._ctx)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.unrollRecursionContexts(_parentctx)
- return localctx
-
-
- class FuncContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def argList(self):
- return self.getTypedRuleContext(RelayParser.ArgListContext,0)
-
-
- def body(self):
- return self.getTypedRuleContext(RelayParser.BodyContext,0)
-
-
- def typeParamList(self):
- return self.getTypedRuleContext(RelayParser.TypeParamListContext,0)
-
-
- def typeExpr(self):
- return self.getTypedRuleContext(RelayParser.TypeExprContext,0)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_func
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitFunc" ):
- return visitor.visitFunc(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def func(self):
-
- localctx = RelayParser.FuncContext(self, self._ctx, self.state)
- self.enterRule(localctx, 16, self.RULE_func)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 222
- self.match(RelayParser.T__17)
- self.state = 224
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__7:
- self.state = 223
- self.typeParamList()
-
-
- self.state = 226
- self.match(RelayParser.T__5)
- self.state = 227
- self.argList()
- self.state = 228
- self.match(RelayParser.T__6)
- self.state = 231
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__18:
- self.state = 229
- self.match(RelayParser.T__18)
- self.state = 230
- self.typeExpr()
-
-
- self.state = 233
- self.body()
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class DefnContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_defn
-
-
- def copyFrom(self, ctx:ParserRuleContext):
- super().copyFrom(ctx)
-
-
-
- class ExternAdtDefnContext(DefnContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def generalIdent(self):
- return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0)
-
- def typeParamList(self):
- return self.getTypedRuleContext(RelayParser.TypeParamListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitExternAdtDefn" ):
- return visitor.visitExternAdtDefn(self)
- else:
- return visitor.visitChildren(self)
-
-
- class FuncDefnContext(DefnContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def globalVar(self):
- return self.getTypedRuleContext(RelayParser.GlobalVarContext,0)
-
- def argList(self):
- return self.getTypedRuleContext(RelayParser.ArgListContext,0)
-
- def body(self):
- return self.getTypedRuleContext(RelayParser.BodyContext,0)
-
- def typeParamList(self):
- return self.getTypedRuleContext(RelayParser.TypeParamListContext,0)
-
- def typeExpr(self):
- return self.getTypedRuleContext(RelayParser.TypeExprContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitFuncDefn" ):
- return visitor.visitFuncDefn(self)
- else:
- return visitor.visitChildren(self)
-
-
- class AdtDefnContext(DefnContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def generalIdent(self):
- return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0)
-
- def typeParamList(self):
- return self.getTypedRuleContext(RelayParser.TypeParamListContext,0)
-
- def adtConsDefnList(self):
- return self.getTypedRuleContext(RelayParser.AdtConsDefnListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitAdtDefn" ):
- return visitor.visitAdtDefn(self)
- else:
- return visitor.visitChildren(self)
-
-
-
- def defn(self):
-
- localctx = RelayParser.DefnContext(self, self._ctx, self.state)
- self.enterRule(localctx, 18, self.RULE_defn)
- self._la = 0 # Token type
- try:
- self.state = 266
- self._errHandler.sync(self)
- token = self._input.LA(1)
- if token in [RelayParser.T__19]:
- localctx = RelayParser.FuncDefnContext(self, localctx)
- self.enterOuterAlt(localctx, 1)
- self.state = 235
- self.match(RelayParser.T__19)
- self.state = 236
- self.globalVar()
- self.state = 238
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__7:
- self.state = 237
- self.typeParamList()
-
-
- self.state = 240
- self.match(RelayParser.T__5)
- self.state = 241
- self.argList()
- self.state = 242
- self.match(RelayParser.T__6)
- self.state = 245
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__18:
- self.state = 243
- self.match(RelayParser.T__18)
- self.state = 244
- self.typeExpr()
-
-
- self.state = 247
- self.body()
- pass
- elif token in [RelayParser.T__20]:
- localctx = RelayParser.ExternAdtDefnContext(self, localctx)
- self.enterOuterAlt(localctx, 2)
- self.state = 249
- self.match(RelayParser.T__20)
- self.state = 250
- self.match(RelayParser.T__21)
- self.state = 251
- self.generalIdent()
- self.state = 253
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__7:
- self.state = 252
- self.typeParamList()
-
-
- pass
- elif token in [RelayParser.T__21]:
- localctx = RelayParser.AdtDefnContext(self, localctx)
- self.enterOuterAlt(localctx, 3)
- self.state = 255
- self.match(RelayParser.T__21)
- self.state = 256
- self.generalIdent()
- self.state = 258
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__7:
- self.state = 257
- self.typeParamList()
-
-
- self.state = 260
- self.match(RelayParser.T__11)
- self.state = 262
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.CNAME:
- self.state = 261
- self.adtConsDefnList()
-
-
- self.state = 264
- self.match(RelayParser.T__12)
- pass
- else:
- raise NoViableAltException(self)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class ConstructorNameContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def CNAME(self):
- return self.getToken(RelayParser.CNAME, 0)
-
- def getRuleIndex(self):
- return RelayParser.RULE_constructorName
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitConstructorName" ):
- return visitor.visitConstructorName(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def constructorName(self):
-
- localctx = RelayParser.ConstructorNameContext(self, self._ctx, self.state)
- self.enterRule(localctx, 20, self.RULE_constructorName)
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 268
- self.match(RelayParser.CNAME)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class AdtConsDefnListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def adtConsDefn(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.AdtConsDefnContext)
- else:
- return self.getTypedRuleContext(RelayParser.AdtConsDefnContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_adtConsDefnList
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitAdtConsDefnList" ):
- return visitor.visitAdtConsDefnList(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def adtConsDefnList(self):
-
- localctx = RelayParser.AdtConsDefnListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 22, self.RULE_adtConsDefnList)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 270
- self.adtConsDefn()
- self.state = 275
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,23,self._ctx)
- while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER:
- if _alt==1:
- self.state = 271
- self.match(RelayParser.T__4)
- self.state = 272
- self.adtConsDefn()
- self.state = 277
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,23,self._ctx)
-
- self.state = 279
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__4:
- self.state = 278
- self.match(RelayParser.T__4)
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class AdtConsDefnContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def constructorName(self):
- return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0)
-
-
- def typeExpr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.TypeExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.TypeExprContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_adtConsDefn
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitAdtConsDefn" ):
- return visitor.visitAdtConsDefn(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def adtConsDefn(self):
-
- localctx = RelayParser.AdtConsDefnContext(self, self._ctx, self.state)
- self.enterRule(localctx, 24, self.RULE_adtConsDefn)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 281
- self.constructorName()
- self.state = 293
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__5:
- self.state = 282
- self.match(RelayParser.T__5)
- self.state = 283
- self.typeExpr()
- self.state = 288
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 284
- self.match(RelayParser.T__4)
- self.state = 285
- self.typeExpr()
- self.state = 290
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
- self.state = 291
- self.match(RelayParser.T__6)
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class MatchClauseListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def matchClause(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.MatchClauseContext)
- else:
- return self.getTypedRuleContext(RelayParser.MatchClauseContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_matchClauseList
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitMatchClauseList" ):
- return visitor.visitMatchClauseList(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def matchClauseList(self):
-
- localctx = RelayParser.MatchClauseListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 26, self.RULE_matchClauseList)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 295
- self.matchClause()
- self.state = 300
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,27,self._ctx)
- while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER:
- if _alt==1:
- self.state = 296
- self.match(RelayParser.T__4)
- self.state = 297
- self.matchClause()
- self.state = 302
- self._errHandler.sync(self)
- _alt = self._interp.adaptivePredict(self._input,27,self._ctx)
-
- self.state = 304
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__4:
- self.state = 303
- self.match(RelayParser.T__4)
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class MatchClauseContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def pattern(self):
- return self.getTypedRuleContext(RelayParser.PatternContext,0)
-
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_matchClause
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitMatchClause" ):
- return visitor.visitMatchClause(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def matchClause(self):
-
- localctx = RelayParser.MatchClauseContext(self, self._ctx, self.state)
- self.enterRule(localctx, 28, self.RULE_matchClause)
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 306
- self.pattern()
- self.state = 307
- self.match(RelayParser.T__22)
- self.state = 313
- self._errHandler.sync(self)
- token = self._input.LA(1)
- if token in [RelayParser.T__11]:
- self.state = 308
- self.match(RelayParser.T__11)
- self.state = 309
- self.expr(0)
- self.state = 310
- self.match(RelayParser.T__12)
- pass
- elif token in [RelayParser.T__1, RelayParser.T__2, RelayParser.T__5, RelayParser.T__7, RelayParser.T__9, RelayParser.T__13, RelayParser.T__17, RelayParser.T__23, RelayParser.T__24, RelayParser.T__27, RelayParser.QUOTED_STRING, RelayParser.SUB, RelayParser.BOOL_LIT, RelayParser.CNAME, RelayParser.FLOAT, RelayParser.NAT]:
- self.state = 312
- self.expr(0)
- pass
- else:
- raise NoViableAltException(self)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class MatchTypeContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_matchType
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitMatchType" ):
- return visitor.visitMatchType(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def matchType(self):
-
- localctx = RelayParser.MatchTypeContext(self, self._ctx, self.state)
- self.enterRule(localctx, 30, self.RULE_matchType)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 315
- _la = self._input.LA(1)
- if not(_la==RelayParser.T__23 or _la==RelayParser.T__24):
- self._errHandler.recoverInline(self)
- else:
- self._errHandler.reportMatch(self)
- self.consume()
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class PatternListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def pattern(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.PatternContext)
- else:
- return self.getTypedRuleContext(RelayParser.PatternContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_patternList
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitPatternList" ):
- return visitor.visitPatternList(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def patternList(self):
-
- localctx = RelayParser.PatternListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 32, self.RULE_patternList)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 317
- self.match(RelayParser.T__5)
- self.state = 318
- self.pattern()
- self.state = 323
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 319
- self.match(RelayParser.T__4)
- self.state = 320
- self.pattern()
- self.state = 325
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
- self.state = 326
- self.match(RelayParser.T__6)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class PatternContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_pattern
-
-
- def copyFrom(self, ctx:ParserRuleContext):
- super().copyFrom(ctx)
-
-
-
- class WildcardPatternContext(PatternContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitWildcardPattern" ):
- return visitor.visitWildcardPattern(self)
- else:
- return visitor.visitChildren(self)
-
-
- class ConstructorPatternContext(PatternContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def constructorName(self):
- return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0)
-
- def patternList(self):
- return self.getTypedRuleContext(RelayParser.PatternListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitConstructorPattern" ):
- return visitor.visitConstructorPattern(self)
- else:
- return visitor.visitChildren(self)
-
-
- class TuplePatternContext(PatternContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def patternList(self):
- return self.getTypedRuleContext(RelayParser.PatternListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTuplePattern" ):
- return visitor.visitTuplePattern(self)
- else:
- return visitor.visitChildren(self)
-
-
- class VarPatternContext(PatternContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def localVar(self):
- return self.getTypedRuleContext(RelayParser.LocalVarContext,0)
-
- def typeExpr(self):
- return self.getTypedRuleContext(RelayParser.TypeExprContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitVarPattern" ):
- return visitor.visitVarPattern(self)
- else:
- return visitor.visitChildren(self)
-
-
-
- def pattern(self):
-
- localctx = RelayParser.PatternContext(self, self._ctx, self.state)
- self.enterRule(localctx, 34, self.RULE_pattern)
- self._la = 0 # Token type
- try:
- self.state = 339
- self._errHandler.sync(self)
- token = self._input.LA(1)
- if token in [RelayParser.T__3]:
- localctx = RelayParser.WildcardPatternContext(self, localctx)
- self.enterOuterAlt(localctx, 1)
- self.state = 328
- self.match(RelayParser.T__3)
- pass
- elif token in [RelayParser.T__2]:
- localctx = RelayParser.VarPatternContext(self, localctx)
- self.enterOuterAlt(localctx, 2)
- self.state = 329
- self.localVar()
- self.state = 332
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__25:
- self.state = 330
- self.match(RelayParser.T__25)
- self.state = 331
- self.typeExpr()
-
-
- pass
- elif token in [RelayParser.CNAME]:
- localctx = RelayParser.ConstructorPatternContext(self, localctx)
- self.enterOuterAlt(localctx, 3)
- self.state = 334
- self.constructorName()
- self.state = 336
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__5:
- self.state = 335
- self.patternList()
-
-
- pass
- elif token in [RelayParser.T__5]:
- localctx = RelayParser.TuplePatternContext(self, localctx)
- self.enterOuterAlt(localctx, 4)
- self.state = 338
- self.patternList()
- pass
- else:
- raise NoViableAltException(self)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class AdtConsContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def constructorName(self):
- return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0)
-
-
- def adtConsParamList(self):
- return self.getTypedRuleContext(RelayParser.AdtConsParamListContext,0)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_adtCons
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitAdtCons" ):
- return visitor.visitAdtCons(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def adtCons(self):
-
- localctx = RelayParser.AdtConsContext(self, self._ctx, self.state)
- self.enterRule(localctx, 36, self.RULE_adtCons)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 341
- self.constructorName()
- self.state = 343
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__5:
- self.state = 342
- self.adtConsParamList()
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class AdtConsParamListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def adtConsParam(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.AdtConsParamContext)
- else:
- return self.getTypedRuleContext(RelayParser.AdtConsParamContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_adtConsParamList
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitAdtConsParamList" ):
- return visitor.visitAdtConsParamList(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def adtConsParamList(self):
-
- localctx = RelayParser.AdtConsParamListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 38, self.RULE_adtConsParamList)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 345
- self.match(RelayParser.T__5)
- self.state = 346
- self.adtConsParam()
- self.state = 351
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 347
- self.match(RelayParser.T__4)
- self.state = 348
- self.adtConsParam()
- self.state = 353
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
- self.state = 354
- self.match(RelayParser.T__6)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class AdtConsParamContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def localVar(self):
- return self.getTypedRuleContext(RelayParser.LocalVarContext,0)
-
-
- def constructorName(self):
- return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_adtConsParam
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitAdtConsParam" ):
- return visitor.visitAdtConsParam(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def adtConsParam(self):
-
- localctx = RelayParser.AdtConsParamContext(self, self._ctx, self.state)
- self.enterRule(localctx, 40, self.RULE_adtConsParam)
- try:
- self.state = 358
- self._errHandler.sync(self)
- token = self._input.LA(1)
- if token in [RelayParser.T__2]:
- self.enterOuterAlt(localctx, 1)
- self.state = 356
- self.localVar()
- pass
- elif token in [RelayParser.CNAME]:
- self.enterOuterAlt(localctx, 2)
- self.state = 357
- self.constructorName()
- pass
- else:
- raise NoViableAltException(self)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class ArgListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_argList
-
-
- def copyFrom(self, ctx:ParserRuleContext):
- super().copyFrom(ctx)
-
-
-
- class ArgNoAttrContext(ArgListContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ArgListContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def varList(self):
- return self.getTypedRuleContext(RelayParser.VarListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitArgNoAttr" ):
- return visitor.visitArgNoAttr(self)
- else:
- return visitor.visitChildren(self)
-
-
- class ArgWithAttrContext(ArgListContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ArgListContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def attrSeq(self):
- return self.getTypedRuleContext(RelayParser.AttrSeqContext,0)
-
- def var(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.VarContext)
- else:
- return self.getTypedRuleContext(RelayParser.VarContext,i)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitArgWithAttr" ):
- return visitor.visitArgWithAttr(self)
- else:
- return visitor.visitChildren(self)
-
-
-
- def argList(self):
-
- localctx = RelayParser.ArgListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 42, self.RULE_argList)
- self._la = 0 # Token type
- try:
- self.state = 370
- self._errHandler.sync(self)
- la_ = self._interp.adaptivePredict(self._input,38,self._ctx)
- if la_ == 1:
- localctx = RelayParser.ArgNoAttrContext(self, localctx)
- self.enterOuterAlt(localctx, 1)
- self.state = 360
- self.varList()
- pass
-
- elif la_ == 2:
- localctx = RelayParser.ArgWithAttrContext(self, localctx)
- self.enterOuterAlt(localctx, 2)
- self.state = 366
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__2:
- self.state = 361
- self.var()
- self.state = 362
- self.match(RelayParser.T__4)
- self.state = 368
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
- self.state = 369
- self.attrSeq()
- pass
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class VarListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def var(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.VarContext)
- else:
- return self.getTypedRuleContext(RelayParser.VarContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_varList
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitVarList" ):
- return visitor.visitVarList(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def varList(self):
-
- localctx = RelayParser.VarListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 44, self.RULE_varList)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 380
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__2:
- self.state = 372
- self.var()
- self.state = 377
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 373
- self.match(RelayParser.T__4)
- self.state = 374
- self.var()
- self.state = 379
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class VarContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def localVar(self):
- return self.getTypedRuleContext(RelayParser.LocalVarContext,0)
-
-
- def typeExpr(self):
- return self.getTypedRuleContext(RelayParser.TypeExprContext,0)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_var
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitVar" ):
- return visitor.visitVar(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def var(self):
-
- localctx = RelayParser.VarContext(self, self._ctx, self.state)
- self.enterRule(localctx, 46, self.RULE_var)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 382
- self.localVar()
- self.state = 385
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__25:
- self.state = 383
- self.match(RelayParser.T__25)
- self.state = 384
- self.typeExpr()
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class AttrSeqContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def attr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.AttrContext)
- else:
- return self.getTypedRuleContext(RelayParser.AttrContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_attrSeq
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitAttrSeq" ):
- return visitor.visitAttrSeq(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def attrSeq(self):
-
- localctx = RelayParser.AttrSeqContext(self, self._ctx, self.state)
- self.enterRule(localctx, 48, self.RULE_attrSeq)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 387
- self.attr()
- self.state = 392
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 388
- self.match(RelayParser.T__4)
- self.state = 389
- self.attr()
- self.state = 394
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class AttrContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def CNAME(self):
- return self.getToken(RelayParser.CNAME, 0)
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_attr
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitAttr" ):
- return visitor.visitAttr(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def attr(self):
-
- localctx = RelayParser.AttrContext(self, self._ctx, self.state)
- self.enterRule(localctx, 50, self.RULE_attr)
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 395
- self.match(RelayParser.CNAME)
- self.state = 396
- self.match(RelayParser.T__14)
- self.state = 397
- self.expr(0)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class TypeExprContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_typeExpr
-
-
- def copyFrom(self, ctx:ParserRuleContext):
- super().copyFrom(ctx)
-
-
-
- class TypeParenContext(TypeExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def typeExpr(self):
- return self.getTypedRuleContext(RelayParser.TypeExprContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTypeParen" ):
- return visitor.visitTypeParen(self)
- else:
- return visitor.visitChildren(self)
-
-
- class TupleTypeContext(TypeExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def typeExpr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.TypeExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.TypeExprContext,i)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTupleType" ):
- return visitor.visitTupleType(self)
- else:
- return visitor.visitChildren(self)
-
-
- class TypeCallTypeContext(TypeExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def generalIdent(self):
- return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0)
-
- def typeParamList(self):
- return self.getTypedRuleContext(RelayParser.TypeParamListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTypeCallType" ):
- return visitor.visitTypeCallType(self)
- else:
- return visitor.visitChildren(self)
-
-
- class TypeIdentTypeContext(TypeExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def generalIdent(self):
- return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTypeIdentType" ):
- return visitor.visitTypeIdentType(self)
- else:
- return visitor.visitChildren(self)
-
-
- class IncompleteTypeContext(TypeExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitIncompleteType" ):
- return visitor.visitIncompleteType(self)
- else:
- return visitor.visitChildren(self)
-
-
- class TensorTypeContext(TypeExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def shapeList(self):
- return self.getTypedRuleContext(RelayParser.ShapeListContext,0)
-
- def typeExpr(self):
- return self.getTypedRuleContext(RelayParser.TypeExprContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTensorType" ):
- return visitor.visitTensorType(self)
- else:
- return visitor.visitChildren(self)
-
-
- class FuncTypeContext(TypeExprContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def typeExpr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.TypeExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.TypeExprContext,i)
-
- def typeParamList(self):
- return self.getTypedRuleContext(RelayParser.TypeParamListContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitFuncType" ):
- return visitor.visitFuncType(self)
- else:
- return visitor.visitChildren(self)
-
-
-
- def typeExpr(self):
-
- localctx = RelayParser.TypeExprContext(self, self._ctx, self.state)
- self.enterRule(localctx, 52, self.RULE_typeExpr)
- self._la = 0 # Token type
- try:
- self.state = 450
- self._errHandler.sync(self)
- la_ = self._interp.adaptivePredict(self._input,47,self._ctx)
- if la_ == 1:
- localctx = RelayParser.TupleTypeContext(self, localctx)
- self.enterOuterAlt(localctx, 1)
- self.state = 399
- self.match(RelayParser.T__5)
- self.state = 400
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 2:
- localctx = RelayParser.TypeParenContext(self, localctx)
- self.enterOuterAlt(localctx, 2)
- self.state = 401
- self.match(RelayParser.T__5)
- self.state = 402
- self.typeExpr()
- self.state = 403
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 3:
- localctx = RelayParser.TupleTypeContext(self, localctx)
- self.enterOuterAlt(localctx, 3)
- self.state = 405
- self.match(RelayParser.T__5)
- self.state = 406
- self.typeExpr()
- self.state = 407
- self.match(RelayParser.T__4)
- self.state = 408
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 4:
- localctx = RelayParser.TupleTypeContext(self, localctx)
- self.enterOuterAlt(localctx, 4)
- self.state = 410
- self.match(RelayParser.T__5)
- self.state = 411
- self.typeExpr()
- self.state = 414
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while True:
- self.state = 412
- self.match(RelayParser.T__4)
- self.state = 413
- self.typeExpr()
- self.state = 416
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if not (_la==RelayParser.T__4):
- break
-
- self.state = 418
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 5:
- localctx = RelayParser.TypeCallTypeContext(self, localctx)
- self.enterOuterAlt(localctx, 5)
- self.state = 420
- self.generalIdent()
- self.state = 421
- self.typeParamList()
- pass
-
- elif la_ == 6:
- localctx = RelayParser.TypeIdentTypeContext(self, localctx)
- self.enterOuterAlt(localctx, 6)
- self.state = 423
- self.generalIdent()
- pass
-
- elif la_ == 7:
- localctx = RelayParser.TensorTypeContext(self, localctx)
- self.enterOuterAlt(localctx, 7)
- self.state = 424
- self.match(RelayParser.T__26)
- self.state = 425
- self.match(RelayParser.T__7)
- self.state = 426
- self.shapeList()
- self.state = 427
- self.match(RelayParser.T__4)
- self.state = 428
- self.typeExpr()
- self.state = 429
- self.match(RelayParser.T__8)
- pass
-
- elif la_ == 8:
- localctx = RelayParser.FuncTypeContext(self, localctx)
- self.enterOuterAlt(localctx, 8)
- self.state = 431
- self.match(RelayParser.T__17)
- self.state = 433
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if _la==RelayParser.T__7:
- self.state = 432
- self.typeParamList()
-
-
- self.state = 435
- self.match(RelayParser.T__5)
- self.state = 444
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__17) | (1 << RelayParser.T__26) | (1 << RelayParser.CNAME))) != 0):
- self.state = 436
- self.typeExpr()
- self.state = 441
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 437
- self.match(RelayParser.T__4)
- self.state = 438
- self.typeExpr()
- self.state = 443
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
-
-
- self.state = 446
- self.match(RelayParser.T__6)
- self.state = 447
- self.match(RelayParser.T__18)
- self.state = 448
- self.typeExpr()
- pass
-
- elif la_ == 9:
- localctx = RelayParser.IncompleteTypeContext(self, localctx)
- self.enterOuterAlt(localctx, 9)
- self.state = 449
- self.match(RelayParser.T__3)
- pass
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class TypeParamListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def typeExpr(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.TypeExprContext)
- else:
- return self.getTypedRuleContext(RelayParser.TypeExprContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_typeParamList
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitTypeParamList" ):
- return visitor.visitTypeParamList(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def typeParamList(self):
-
- localctx = RelayParser.TypeParamListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 54, self.RULE_typeParamList)
- self._la = 0 # Token type
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 452
- self.match(RelayParser.T__7)
- self.state = 453
- self.typeExpr()
- self.state = 458
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while _la==RelayParser.T__4:
- self.state = 454
- self.match(RelayParser.T__4)
- self.state = 455
- self.typeExpr()
- self.state = 460
- self._errHandler.sync(self)
- _la = self._input.LA(1)
-
- self.state = 461
- self.match(RelayParser.T__8)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class ShapeListContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def shape(self, i:int=None):
- if i is None:
- return self.getTypedRuleContexts(RelayParser.ShapeContext)
- else:
- return self.getTypedRuleContext(RelayParser.ShapeContext,i)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_shapeList
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitShapeList" ):
- return visitor.visitShapeList(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def shapeList(self):
-
- localctx = RelayParser.ShapeListContext(self, self._ctx, self.state)
- self.enterRule(localctx, 56, self.RULE_shapeList)
- self._la = 0 # Token type
- try:
- self.state = 476
- self._errHandler.sync(self)
- la_ = self._interp.adaptivePredict(self._input,50,self._ctx)
- if la_ == 1:
- self.enterOuterAlt(localctx, 1)
- self.state = 463
- self.match(RelayParser.T__5)
- self.state = 464
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 2:
- self.enterOuterAlt(localctx, 2)
- self.state = 465
- self.match(RelayParser.T__5)
- self.state = 466
- self.shape()
- self.state = 469
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- while True:
- self.state = 467
- self.match(RelayParser.T__4)
- self.state = 468
- self.shape()
- self.state = 471
- self._errHandler.sync(self)
- _la = self._input.LA(1)
- if not (_la==RelayParser.T__4):
- break
-
- self.state = 473
- self.match(RelayParser.T__6)
- pass
-
- elif la_ == 3:
- self.enterOuterAlt(localctx, 3)
- self.state = 475
- self.shape()
- pass
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class MetaContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def CNAME(self):
- return self.getToken(RelayParser.CNAME, 0)
-
- def NAT(self):
- return self.getToken(RelayParser.NAT, 0)
-
- def getRuleIndex(self):
- return RelayParser.RULE_meta
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitMeta" ):
- return visitor.visitMeta(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def meta(self):
-
- localctx = RelayParser.MetaContext(self, self._ctx, self.state)
- self.enterRule(localctx, 58, self.RULE_meta)
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 478
- self.match(RelayParser.T__27)
- self.state = 479
- self.match(RelayParser.T__7)
- self.state = 480
- self.match(RelayParser.CNAME)
- self.state = 481
- self.match(RelayParser.T__8)
- self.state = 482
- self.match(RelayParser.T__7)
- self.state = 483
- self.match(RelayParser.NAT)
- self.state = 484
- self.match(RelayParser.T__8)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class ShapeContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_shape
-
-
- def copyFrom(self, ctx:ParserRuleContext):
- super().copyFrom(ctx)
-
-
-
- class ParensShapeContext(ShapeContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def shape(self):
- return self.getTypedRuleContext(RelayParser.ShapeContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitParensShape" ):
- return visitor.visitParensShape(self)
- else:
- return visitor.visitChildren(self)
-
-
- class MetaShapeContext(ShapeContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def meta(self):
- return self.getTypedRuleContext(RelayParser.MetaContext,0)
-
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitMetaShape" ):
- return visitor.visitMetaShape(self)
- else:
- return visitor.visitChildren(self)
-
-
- class IntShapeContext(ShapeContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def NAT(self):
- return self.getToken(RelayParser.NAT, 0)
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitIntShape" ):
- return visitor.visitIntShape(self)
- else:
- return visitor.visitChildren(self)
-
-
-
- def shape(self):
-
- localctx = RelayParser.ShapeContext(self, self._ctx, self.state)
- self.enterRule(localctx, 60, self.RULE_shape)
- try:
- self.state = 492
- self._errHandler.sync(self)
- token = self._input.LA(1)
- if token in [RelayParser.T__27]:
- localctx = RelayParser.MetaShapeContext(self, localctx)
- self.enterOuterAlt(localctx, 1)
- self.state = 486
- self.meta()
- pass
- elif token in [RelayParser.T__5]:
- localctx = RelayParser.ParensShapeContext(self, localctx)
- self.enterOuterAlt(localctx, 2)
- self.state = 487
- self.match(RelayParser.T__5)
- self.state = 488
- self.shape()
- self.state = 489
- self.match(RelayParser.T__6)
- pass
- elif token in [RelayParser.NAT]:
- localctx = RelayParser.IntShapeContext(self, localctx)
- self.enterOuterAlt(localctx, 3)
- self.state = 491
- self.match(RelayParser.NAT)
- pass
- else:
- raise NoViableAltException(self)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class BodyContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def expr(self):
- return self.getTypedRuleContext(RelayParser.ExprContext,0)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_body
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitBody" ):
- return visitor.visitBody(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def body(self):
-
- localctx = RelayParser.BodyContext(self, self._ctx, self.state)
- self.enterRule(localctx, 62, self.RULE_body)
- try:
- self.enterOuterAlt(localctx, 1)
- self.state = 494
- self.match(RelayParser.T__11)
- self.state = 495
- self.expr(0)
- self.state = 496
- self.match(RelayParser.T__12)
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class ScalarContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_scalar
-
-
- def copyFrom(self, ctx:ParserRuleContext):
- super().copyFrom(ctx)
-
-
-
- class ScalarFloatContext(ScalarContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def FLOAT(self):
- return self.getToken(RelayParser.FLOAT, 0)
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitScalarFloat" ):
- return visitor.visitScalarFloat(self)
- else:
- return visitor.visitChildren(self)
-
-
- class ScalarBoolContext(ScalarContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def BOOL_LIT(self):
- return self.getToken(RelayParser.BOOL_LIT, 0)
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitScalarBool" ):
- return visitor.visitScalarBool(self)
- else:
- return visitor.visitChildren(self)
-
-
- class ScalarIntContext(ScalarContext):
-
- def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext
- super().__init__(parser)
- self.copyFrom(ctx)
-
- def NAT(self):
- return self.getToken(RelayParser.NAT, 0)
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitScalarInt" ):
- return visitor.visitScalarInt(self)
- else:
- return visitor.visitChildren(self)
-
-
-
- def scalar(self):
-
- localctx = RelayParser.ScalarContext(self, self._ctx, self.state)
- self.enterRule(localctx, 64, self.RULE_scalar)
- try:
- self.state = 501
- self._errHandler.sync(self)
- token = self._input.LA(1)
- if token in [RelayParser.FLOAT]:
- localctx = RelayParser.ScalarFloatContext(self, localctx)
- self.enterOuterAlt(localctx, 1)
- self.state = 498
- self.match(RelayParser.FLOAT)
- pass
- elif token in [RelayParser.NAT]:
- localctx = RelayParser.ScalarIntContext(self, localctx)
- self.enterOuterAlt(localctx, 2)
- self.state = 499
- self.match(RelayParser.NAT)
- pass
- elif token in [RelayParser.BOOL_LIT]:
- localctx = RelayParser.ScalarBoolContext(self, localctx)
- self.enterOuterAlt(localctx, 3)
- self.state = 500
- self.match(RelayParser.BOOL_LIT)
- pass
- else:
- raise NoViableAltException(self)
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
- class IdentContext(ParserRuleContext):
-
- def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1):
- super().__init__(parent, invokingState)
- self.parser = parser
-
- def generalIdent(self):
- return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0)
-
-
- def globalVar(self):
- return self.getTypedRuleContext(RelayParser.GlobalVarContext,0)
-
-
- def localVar(self):
- return self.getTypedRuleContext(RelayParser.LocalVarContext,0)
-
-
- def graphVar(self):
- return self.getTypedRuleContext(RelayParser.GraphVarContext,0)
-
-
- def getRuleIndex(self):
- return RelayParser.RULE_ident
-
- def accept(self, visitor:ParseTreeVisitor):
- if hasattr( visitor, "visitIdent" ):
- return visitor.visitIdent(self)
- else:
- return visitor.visitChildren(self)
-
-
-
-
- def ident(self):
-
- localctx = RelayParser.IdentContext(self, self._ctx, self.state)
- self.enterRule(localctx, 66, self.RULE_ident)
- try:
- self.state = 507
- self._errHandler.sync(self)
- la_ = self._interp.adaptivePredict(self._input,53,self._ctx)
- if la_ == 1:
- self.enterOuterAlt(localctx, 1)
- self.state = 503
- self.generalIdent()
- pass
-
- elif la_ == 2:
- self.enterOuterAlt(localctx, 2)
- self.state = 504
- self.globalVar()
- pass
-
- elif la_ == 3:
- self.enterOuterAlt(localctx, 3)
- self.state = 505
- self.localVar()
- pass
-
- elif la_ == 4:
- self.enterOuterAlt(localctx, 4)
- self.state = 506
- self.graphVar()
- pass
-
-
- except RecognitionException as re:
- localctx.exception = re
- self._errHandler.reportError(self, re)
- self._errHandler.recover(self, re)
- finally:
- self.exitRule()
- return localctx
-
-
-
- def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int):
- if self._predicates == None:
- self._predicates = dict()
- self._predicates[7] = self.expr_sempred
- pred = self._predicates.get(ruleIndex, None)
- if pred is None:
- raise Exception("No predicate with index:" + str(ruleIndex))
- else:
- return pred(localctx, predIndex)
-
- def expr_sempred(self, localctx:ExprContext, predIndex:int):
- if predIndex == 0:
- return self.precpred(self._ctx, 19)
-
-
- if predIndex == 1:
- return self.precpred(self._ctx, 18)
-
-
- if predIndex == 2:
- return self.precpred(self._ctx, 17)
-
-
- if predIndex == 3:
- return self.precpred(self._ctx, 16)
-
-
- if predIndex == 4:
- return self.precpred(self._ctx, 6)
-
-
- if predIndex == 5:
- return self.precpred(self._ctx, 21)
-
-
- if predIndex == 6:
- return self.precpred(self._ctx, 8)
-
-
-
-
-
+++ /dev/null
-# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
-from antlr4 import *
-if __name__ is not None and "." in __name__:
- from .RelayParser import RelayParser
-else:
- from RelayParser import RelayParser
-
-# This class defines a complete generic visitor for a parse tree produced by RelayParser.
-
-class RelayVisitor(ParseTreeVisitor):
-
- # Visit a parse tree produced by RelayParser#prog.
- def visitProg(self, ctx:RelayParser.ProgContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#generalIdent.
- def visitGeneralIdent(self, ctx:RelayParser.GeneralIdentContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#globalVar.
- def visitGlobalVar(self, ctx:RelayParser.GlobalVarContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#localVar.
- def visitLocalVar(self, ctx:RelayParser.LocalVarContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#graphVar.
- def visitGraphVar(self, ctx:RelayParser.GraphVarContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#exprList.
- def visitExprList(self, ctx:RelayParser.ExprListContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#callNoAttr.
- def visitCallNoAttr(self, ctx:RelayParser.CallNoAttrContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#callWithAttr.
- def visitCallWithAttr(self, ctx:RelayParser.CallWithAttrContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#funcExpr.
- def visitFuncExpr(self, ctx:RelayParser.FuncExprContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#metaExpr.
- def visitMetaExpr(self, ctx:RelayParser.MetaExprContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#match.
- def visitMatch(self, ctx:RelayParser.MatchContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#tensor.
- def visitTensor(self, ctx:RelayParser.TensorContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#graph.
- def visitGraph(self, ctx:RelayParser.GraphContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#identExpr.
- def visitIdentExpr(self, ctx:RelayParser.IdentExprContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#stringExpr.
- def visitStringExpr(self, ctx:RelayParser.StringExprContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#call.
- def visitCall(self, ctx:RelayParser.CallContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#neg.
- def visitNeg(self, ctx:RelayParser.NegContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#tuple.
- def visitTuple(self, ctx:RelayParser.TupleContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#paren.
- def visitParen(self, ctx:RelayParser.ParenContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#scalarExpr.
- def visitScalarExpr(self, ctx:RelayParser.ScalarExprContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#let.
- def visitLet(self, ctx:RelayParser.LetContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#projection.
- def visitProjection(self, ctx:RelayParser.ProjectionContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#ifElse.
- def visitIfElse(self, ctx:RelayParser.IfElseContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#binOp.
- def visitBinOp(self, ctx:RelayParser.BinOpContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#func.
- def visitFunc(self, ctx:RelayParser.FuncContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#funcDefn.
- def visitFuncDefn(self, ctx:RelayParser.FuncDefnContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#externAdtDefn.
- def visitExternAdtDefn(self, ctx:RelayParser.ExternAdtDefnContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#adtDefn.
- def visitAdtDefn(self, ctx:RelayParser.AdtDefnContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#constructorName.
- def visitConstructorName(self, ctx:RelayParser.ConstructorNameContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#adtConsDefnList.
- def visitAdtConsDefnList(self, ctx:RelayParser.AdtConsDefnListContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#adtConsDefn.
- def visitAdtConsDefn(self, ctx:RelayParser.AdtConsDefnContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#matchClauseList.
- def visitMatchClauseList(self, ctx:RelayParser.MatchClauseListContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#matchClause.
- def visitMatchClause(self, ctx:RelayParser.MatchClauseContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#matchType.
- def visitMatchType(self, ctx:RelayParser.MatchTypeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#patternList.
- def visitPatternList(self, ctx:RelayParser.PatternListContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#wildcardPattern.
- def visitWildcardPattern(self, ctx:RelayParser.WildcardPatternContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#varPattern.
- def visitVarPattern(self, ctx:RelayParser.VarPatternContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#constructorPattern.
- def visitConstructorPattern(self, ctx:RelayParser.ConstructorPatternContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#tuplePattern.
- def visitTuplePattern(self, ctx:RelayParser.TuplePatternContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#adtCons.
- def visitAdtCons(self, ctx:RelayParser.AdtConsContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#adtConsParamList.
- def visitAdtConsParamList(self, ctx:RelayParser.AdtConsParamListContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#adtConsParam.
- def visitAdtConsParam(self, ctx:RelayParser.AdtConsParamContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#argNoAttr.
- def visitArgNoAttr(self, ctx:RelayParser.ArgNoAttrContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#argWithAttr.
- def visitArgWithAttr(self, ctx:RelayParser.ArgWithAttrContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#varList.
- def visitVarList(self, ctx:RelayParser.VarListContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#var.
- def visitVar(self, ctx:RelayParser.VarContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#attrSeq.
- def visitAttrSeq(self, ctx:RelayParser.AttrSeqContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#attr.
- def visitAttr(self, ctx:RelayParser.AttrContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#tupleType.
- def visitTupleType(self, ctx:RelayParser.TupleTypeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#typeParen.
- def visitTypeParen(self, ctx:RelayParser.TypeParenContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#typeCallType.
- def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#typeIdentType.
- def visitTypeIdentType(self, ctx:RelayParser.TypeIdentTypeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#tensorType.
- def visitTensorType(self, ctx:RelayParser.TensorTypeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#funcType.
- def visitFuncType(self, ctx:RelayParser.FuncTypeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#incompleteType.
- def visitIncompleteType(self, ctx:RelayParser.IncompleteTypeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#typeParamList.
- def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#shapeList.
- def visitShapeList(self, ctx:RelayParser.ShapeListContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#meta.
- def visitMeta(self, ctx:RelayParser.MetaContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#metaShape.
- def visitMetaShape(self, ctx:RelayParser.MetaShapeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#parensShape.
- def visitParensShape(self, ctx:RelayParser.ParensShapeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#intShape.
- def visitIntShape(self, ctx:RelayParser.IntShapeContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#body.
- def visitBody(self, ctx:RelayParser.BodyContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#scalarFloat.
- def visitScalarFloat(self, ctx:RelayParser.ScalarFloatContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#scalarInt.
- def visitScalarInt(self, ctx:RelayParser.ScalarIntContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#scalarBool.
- def visitScalarBool(self, ctx:RelayParser.ScalarBoolContext):
- return self.visitChildren(ctx)
-
-
- # Visit a parse tree produced by RelayParser#ident.
- def visitIdent(self, ctx:RelayParser.IdentContext):
- return self.visitChildren(ctx)
-
-
-
-del RelayParser
\ No newline at end of file
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""A parser for Relay's text format."""
-from __future__ import absolute_import
-from .. import register_func
-
-
-@register_func("relay.fromtext")
-def fromtext(data, source_name=None):
- """Parse a Relay program."""
- # pylint: disable=import-outside-toplevel
- from tvm.relay import _parser
- x = _parser.fromtext(data + "\n", source_name)
- if x is None:
- raise Exception("cannot parse: ", data)
- return x
* specific language governing permissions and limitations
* under the License.
*/
-v0.0.4
+
+#[version = "0.0.5"]
extern type Storage
* specific language governing permissions and limitations
* under the License.
*/
-v0.0.4
+
+#[version = "0.0.5"]
/*
* Store the Gradient Value of a Tensor of type T.
* specific language governing permissions and limitations
* under the License.
*/
-v0.0.4
-
-// TODO(weberlo): should we add sugar for scalar types (e.g., `int32` => `Tensor[(), int32]`)?
+#[version = "0.0.5"]
def @id[A](%x: A) -> A {
%x
* Takes a number n and a function f; returns a closure that takes an argument
* and applies f n times to its argument.
*/
-def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> (fn(A) -> A) {
+def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> fn(A) -> A {
if (%n == 0) {
@id
} else {
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: We calls into relay's analysis module to verify correctness.
+#include <tvm/parser/parser.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; }
IRModule IRModule::FromText(const String& text, const String& source_path) {
- auto* f = tvm::runtime::Registry::Get("relay.fromtext");
- CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
- IRModule mod = (*f)(text, source_path);
- return mod;
+ return tvm::parser::ParseModule(source_path, text);
}
TVM_REGISTER_NODE_TYPE(IRModuleNode);
return static_cast<const SourceNameNode*>(n)->name;
});
-Span::Span(SourceName source, int lineno, int col_offset) {
+Span::Span(SourceName source_name, int line, int end_line, int column, int end_column) {
auto n = make_object<SpanNode>();
- n->source = std::move(source);
- n->line = lineno;
- n->column = col_offset;
+ n->source_name = std::move(source_name);
+ n->line = line;
+ n->end_line = end_line;
+ n->column = column;
+ n->end_column = end_column;
data_ = std::move(n);
}
+Span Span::Merge(const Span& other) {
+ CHECK((*this)->source_name == other->source_name);
+ return Span((*this)->source_name, std::min((*this)->line, other->line),
+ std::max((*this)->end_line, other->end_line),
+ std::min((*this)->column, other->column),
+ std::max((*this)->end_column, other->end_column));
+}
+
TVM_REGISTER_NODE_TYPE(SpanNode);
-TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) {
- return Span(source, lineno, col_offset);
+TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line,
+ int column, int end_column) {
+ return Span(source_name, line, end_line, column, end_column);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
- p->stream << "Span(" << node->source << ", " << node->line << ", " << node->column << ")";
+ p->stream << "Span(" << node->source_name << ", " << node->line << ", " << node->end_line
+ << ", " << node->column << ", " << node->end_column << ")";
});
} // namespace tvm
#define TVM_PARSER_DIAGNOSTIC_H_
#include <tvm/ir/span.h>
+#include <tvm/parser/source_map.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
namespace tvm {
namespace parser {
-/*! \brief A program source in any language.
- *
- * Could represent the source from an ML framework or the internal
- * source of a TVM program.
- */
-struct Source {
- /*! \brief The raw source. */
- std::string source;
- /*! \brief A mapping of line breaks into the raw source. */
- std::vector<std::pair<int, int>> line_map;
-
- /*! \brief An empty source. */
- Source() : source(), line_map() {}
-
- /*! \brief Construct a source from a string. */
- explicit Source(const std::string& source) : source(source) {
- int index = 0;
- int length = 0;
- line_map.push_back({index, length});
- for (auto c : source) {
- if (c == '\n') {
- // Record the length of the line.
- line_map.back().second = length;
- // Bump past the newline.
- index += 1;
- // Record the start of the next line, and put placeholder for length.
- line_map.push_back({index, 0});
- // Reset length to zero.
- length = 0;
- } else {
- length += 1;
- index += 1;
- }
- }
- line_map.back().second = length;
- }
-
- Source(const Source& source) : source(source.source), line_map(source.line_map) {}
-
- /*! \brief Generate an error message at a specific line and column with the
- * annotated message.
- *
- * The error is written directly to the `out` std::ostream.
- *
- * \param out The output ostream.
- * \param line The line at which to report a diagnostic.
- * \param line The column at which to report a diagnostic.
- * \param msg The message to attach.
- */
- void ReportAt(std::ostream& out, int line, int column, const std::string& msg) const {
- CHECK(line - 1 <= static_cast<int64_t>(line_map.size()))
- << "requested line: " << (line - 1) << "line_map size: " << line_map.size()
- << "source: " << source;
-
- // Adjust for zero indexing, now have (line_start, line_length);
- auto range = line_map.at(line - 1);
- int line_start = range.first;
- int line_length = range.second;
- out << "file:" << line << ":" << column << ": parse error: " << msg << std::endl;
- out << " " << source.substr(line_start, line_length) << std::endl;
- out << " ";
- std::stringstream marker;
- for (int i = 1; i <= line_length; i++) {
- if (i == column) {
- marker << "^";
- } else if ((column - i) < 3) {
- marker << "~";
- } else if ((i - column) < 3) {
- marker << "~";
- } else {
- marker << " ";
- }
- }
- out << marker.str();
- out << std::endl;
- }
-};
-
/*! \brief The diagnostic level, controls the printing of the message. */
-enum DiagnosticLevel {
- Bug,
- Error,
- Warning,
- Note,
- Help,
+enum class DiagnosticLevel {
+ kBug,
+ kError,
+ kWarning,
+ kNote,
+ kHelp,
};
+struct DiagnosticBuilder;
+
/*! \brief A diagnostic message. */
struct Diagnostic {
/*! \brief The level. */
/*! \brief The diagnostic message. */
std::string message;
- Diagnostic(int line, int column, const std::string& message)
- : level(DiagnosticLevel::Error), span(SourceName(), line, column), message(message) {}
+ Diagnostic(DiagnosticLevel level, Span span, const std::string& message)
+ : level(level), span(span), message(message) {}
+
+ static DiagnosticBuilder Bug(Span span);
+ static DiagnosticBuilder Error(Span span);
+ static DiagnosticBuilder Warning(Span span);
+ static DiagnosticBuilder Note(Span span);
+ static DiagnosticBuilder Help(Span span);
};
+/*!
+ * \brief A wrapper around std::stringstream to build a diagnostic.
+ *
+ * \code
+ *
+ * void ReportError(const Error& err);
+ *
+ * void Test(int number) {
+ * // Use error reporter to construct an error.
+ * ReportError(ErrorBuilder() << "This is an error number=" << number);
+ * }
+ *
+ * \endcode
+ */
+struct DiagnosticBuilder {
+ public:
+ /*! \brief The level. */
+ DiagnosticLevel level;
+
+ /*! \brief The source name. */
+ SourceName source_name;
+
+ /*! \brief The span of the diagnostic. */
+ Span span;
+
+ template <typename T>
+ DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*)
+ stream_ << val;
+ return *this;
+ }
+
+ DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {}
+
+ DiagnosticBuilder(const DiagnosticBuilder& builder)
+ : level(builder.level), source_name(builder.source_name), span(builder.span) {}
+
+ DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {}
+
+ operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); }
+
+ private:
+ std::stringstream stream_;
+ friend struct Diagnostic;
+};
+
+DiagnosticBuilder Diagnostic::Bug(Span span) {
+ return DiagnosticBuilder(DiagnosticLevel::kBug, span);
+}
+
+DiagnosticBuilder Diagnostic::Error(Span span) {
+ return DiagnosticBuilder(DiagnosticLevel::kError, span);
+}
+
+DiagnosticBuilder Diagnostic::Warning(Span span) {
+ return DiagnosticBuilder(DiagnosticLevel::kWarning, span);
+}
+
+DiagnosticBuilder Diagnostic::Note(Span span) {
+ return DiagnosticBuilder(DiagnosticLevel::kNote, span);
+}
+
+DiagnosticBuilder Diagnostic::Help(Span span) {
+ return DiagnosticBuilder(DiagnosticLevel::kHelp, span);
+}
+
/*! \brief A diagnostic context for recording errors against a source file.
* TODO(@jroesch): convert source map and improve in follow up PR, the parser
* assumes a single global file for now.
/*! \brief Emit a diagnostic. */
void Emit(const Diagnostic& diagnostic) { diagnostics.push_back(diagnostic); }
+  /*! \brief Emit a diagnostic and immediately render all recorded diagnostics.
+   *  Render aborts via LOG(FATAL) whenever any diagnostic was recorded, so this
+   *  call does not return. */
+  void EmitFatal(const Diagnostic& diagnostic) {
+    diagnostics.push_back(diagnostic);
+    Render(std::cout);
+  }
+
// TODO(@jroesch): eventually modularize the rendering interface to provide control of how to
// format errors.
void Render(std::ostream& ostream) {
for (auto diagnostic : diagnostics) {
- source.ReportAt(ostream, diagnostic.span->line, diagnostic.span->column, diagnostic.message);
+ source.ReportAt(ostream, diagnostic.span, diagnostic.message);
}
if (diagnostics.size()) {
- LOG(FATAL) << "parse error occured";
+ LOG(FATAL) << "DiagnosticError: one or more error diagnostics were "
+ << "emitted, please check diagnostic render for output.";
}
}
};
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/parser/meta_ref.cc
+ * \brief An operator which allows forward referencing a yet-to-be parsed meta table reference.
+ */
+
+#include "./meta_ref.h"
+
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace parser {
+
+using tvm::relay::transform::CreateFunctionPass;
+using tvm::transform::PassContext;
+
+/* Set to arbitrary high number, since we should never schedule in normal pass manager flow. */
+static int kMetaExpandOptLevel = 1337;
+
+TVM_REGISTER_NODE_TYPE(MetaRefAttrs);
+
+/*!
+ * \brief Type relation for the parser.MetaRef placeholder operator.
+ *
+ * MetaRef calls must be expanded (see ExpandMetaRefs) before type checking;
+ * reaching this relation is an internal error and aborts via LOG(FATAL).
+ */
+bool MetaRefRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  LOG(FATAL) << "need to expand before type checking";
+  return true;
+}
+
+// Register the placeholder operator that stands in for unresolved `meta[...]`
+// references. It takes no inputs, is marked non-computational, and its type
+// relation (MetaRefRel) aborts if it is ever reached un-expanded.
+RELAY_REGISTER_OP("parser.MetaRef")
+    .describe(R"code(A reference into the meta table.)code" TVM_ADD_FILELINE)
+    .set_attrs_type<MetaRefAttrs>()
+    .set_num_inputs(0)
+    .set_support_level(10)
+    .add_type_rel("MetaRef", MetaRefRel)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true);
+
+/*!
+ * \brief Build a call to parser.MetaRef representing `meta[type_key][node_index]`.
+ *
+ * \param type_key The type key of the referenced node table.
+ * \param node_index The index into that table.
+ * \return A zero-argument Call whose MetaRefAttrs carry the reference.
+ */
+Expr MetaRef(std::string type_key, uint64_t node_index) {
+  static const Op& op = Op::Get("parser.MetaRef");
+  auto attrs = make_object<MetaRefAttrs>();
+  attrs->node_type_key = tvm::String(type_key);
+  attrs->node_index = node_index;
+  return Call(op, {}, Attrs(attrs), {});
+}
+
+/*! \brief Expression mutator that replaces parser.MetaRef calls with the
+ *  actual AST nodes recorded in the meta table. */
+struct MetaRefExpander : public ExprMutator {
+  MetaTable table;
+
+  explicit MetaRefExpander(const MetaTable& table) : table(table) {}
+
+  Expr VisitExpr_(const CallNode* call) final {
+    if (auto op_node = call->op.as<OpNode>()) {
+      if (op_node->name == "parser.MetaRef") {
+        auto meta_attrs = call->attrs.as<MetaRefAttrs>();
+        CHECK(meta_attrs) << "an internal error has occurred";
+        // NOTE(review): table.at aborts if the type key is absent — confirm the
+        // parser guarantees every referenced key exists by expansion time.
+        auto nodes = table.at(meta_attrs->node_type_key);
+        CHECK_LT(meta_attrs->node_index, nodes.size());
+        return Downcast<Expr>(nodes[meta_attrs->node_index]);
+      }
+    }
+
+    return ExprMutator::VisitExpr_(call);
+  }
+};
+
+/*! \brief Expand all parser.MetaRef calls in a single function. */
+Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func) {
+  MetaRefExpander expander(meta_table);
+  return Downcast<Function>(expander.VisitExpr(func));
+}
+
+/*! \brief Expand all parser.MetaRef calls in every function of a module.
+ *
+ * Implemented as an ad-hoc function pass run immediately; kMetaExpandOptLevel
+ * is deliberately high so the pass is never picked up by the normal pass
+ * manager flow. Note the lambda captures meta_table by reference, which is
+ * safe only because the pass runs before this function returns.
+ */
+IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) {
+  auto pass = CreateFunctionPass([&](Function func, IRModule module,
+                                     PassContext ctx) { return ExpandMetaRefs(meta_table, func); },
+                                 kMetaExpandOptLevel, "ExpandMetaRefs", {});
+
+  return pass(mod, PassContext::Create());
+}
+
+} // namespace parser
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file meta_ref.h
+ * \brief A reference into the metadata section of the Relay text format.
+ */
+
+#ifndef TVM_PARSER_META_REF_H_
+#define TVM_PARSER_META_REF_H_
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+
+namespace tvm {
+namespace parser {
+
+using namespace relay;
+
+using MetaTable = Map<String, Array<ObjectRef>>;
+
+/*!
+ * \brief Attributes carried by a parser.MetaRef call: they identify which
+ * entry of the metadata section the reference points at.
+ */
+struct MetaRefAttrs : public tvm::AttrsNode<MetaRefAttrs> {
+  tvm::String node_type_key;
+  uint64_t node_index;
+
+  TVM_DECLARE_ATTRS(MetaRefAttrs, "relay.attrs.MetaRefAttrs") {
+    TVM_ATTR_FIELD(node_type_key)
+        .describe("The type_key representing the type of the node referenced.");
+    TVM_ATTR_FIELD(node_index).describe("The index into the type specific node array.");
+  }
+};
+
+/*! \brief A reference to a "meta-expression".
+ *
+ * In the text format we allow referencing metadata which
+ * uses a compact serialization that precedes the main
+ * program body.
+ *
+ * We can reference this table using an expression of
+ * the form `meta[Type][index]`.
+ *
+ * We must later resolve these references to actual in-memory
+ * AST nodes but this requires first parsing the full program
+ * then expanding these temporary AST nodes into their corresponding
+ * nodes.
+ *
+ * For example the nth large constant will be pretty-printed as meta[relay.Constant][n]
+ * with its compact binary serialization residing in the metadata section at the end
+ * of the program.
+ *
+ * \param type_key The type key of the object in the meta section.
+ * \param node_index The index into that subfield.
+ * \returns The meta table reference.
+ */
+Expr MetaRef(std::string type_key, uint64_t node_index);
+
+relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func);
+IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod);
+
+} // namespace parser
+} // namespace tvm
+
+#endif // TVM_PARSER_META_REF_H_
OperatorTable DefaultOpTable() {
  // Default binary-operator table: each Rule maps a token pattern to a Relay
  // op with a precedence level, arity 2, and left associativity.
  return OperatorTable(
-      {Rule({TokenType::Star}, Op::Get("multiply"), 12, 2, true),
-       Rule({TokenType::Division}, Op::Get("divide"), 12, 2, true),
-       Rule({TokenType::Plus}, Op::Get("add"), 10, 2, true),
-       Rule({TokenType::Minus}, Op::Get("subtract"), 10, 2, true),
-       Rule({TokenType::LAngle}, Op::Get("less"), 8, 2, true),
-       Rule({TokenType::LAngle, TokenType::Equal}, Op::Get("less_equal"), 8, 2, true),
-       Rule({TokenType::RAngle}, Op::Get("greater"), 8, 2, true),
-       Rule({TokenType::RAngle, TokenType::Equal}, Op::Get("greater_equal"), 8, 2, true),
-       Rule({TokenType::Equal, TokenType::Equal}, Op::Get("equal"), 7, 2, true),
-       Rule({TokenType::Bang, TokenType::Equal}, Op::Get("not_equal"), 7, 2, true)});
+      {Rule({TokenType::kStar}, Op::Get("multiply"), 12, 2, true),
+       Rule({TokenType::kDivision}, Op::Get("divide"), 12, 2, true),
+       Rule({TokenType::kPlus}, Op::Get("add"), 10, 2, true),
+       Rule({TokenType::kMinus}, Op::Get("subtract"), 10, 2, true),
+       Rule({TokenType::kLAngle}, Op::Get("less"), 8, 2, true),
+       Rule({TokenType::kLAngle, TokenType::kEqual}, Op::Get("less_equal"), 8, 2, true),
+       Rule({TokenType::kRAngle}, Op::Get("greater"), 8, 2, true),
+       Rule({TokenType::kRAngle, TokenType::kEqual}, Op::Get("greater_equal"), 8, 2, true),
+       Rule({TokenType::kEqual, TokenType::kEqual}, Op::Get("equal"), 7, 2, true),
+       Rule({TokenType::kBang, TokenType::kEqual}, Op::Get("not_equal"), 7, 2, true)});
}
} // namespace parser
#include <fstream>
#include "./diagnostic.h"
+#include "./meta_ref.h"
#include "./op_table.h"
#include "./tokenizer.h"
using namespace relay;
using Expr = relay::Expr;
+/*! \brief The meta table maps from type key to a sequence of objects. */
+using MetaTable = Map<String, Array<ObjectRef>>;
+
/*! \brief A wrapper structure for capturing the result of parsing
* a global definition *before* we add it to the IRModule.
*
patch_version(other.patch_version) {}
};
-/*! \brief A reference to a "meta-expression".
- *
- * In the text format we allow referencing metadata which
- * uses a compact serialization that proceeds the main
- * program body.
- *
- * We can reference this table using an expression of
- * the form `meta[Type][index]`.
- *
- * We must later resolve these references to actual in-memory
- * AST nodes but this requires first parsing the full program
- * then expanding these temporary AST nodes into their corresponding
- * nodes.
- *
- * For example the nth large constant will be pretty-printed as meta[relay.Constant][n]
- * with its compact binary serialization residing in the metadata section at the end
- * of the program.
- */
-class MetaRefExprNode : public TempExprNode {
- public:
- /*! \brief The type key of the meta expression. */
- std::string type_key;
- /*! \brief The index into the type key's table. */
- uint64_t node_index;
-
- void VisitAttrs(tvm::AttrVisitor* v) {}
-
- // TODO(@jroesch): we probably will need to manually
- // expand these with a pass.
- Expr Realize() const final { return Expr(); }
-
- static constexpr const char* _type_key = "relay.MetaRefExpr";
- TVM_DECLARE_FINAL_OBJECT_INFO(MetaRefExprNode, TempExprNode);
-};
-
-class MetaRefExpr : public TempExpr {
- public:
- /*!
- * \brief The constructor for MetaRefExpr
- * \param type_key The type key of the object in the meta section.
- * \param kind The index into that subfield.
- */
- TVM_DLL MetaRefExpr(std::string type_key, uint64_t node_index);
-
- TVM_DEFINE_OBJECT_REF_METHODS(MetaRefExpr, TempExpr, MetaRefExprNode);
-};
-
-MetaRefExpr::MetaRefExpr(std::string type_key, uint64_t node_index) {
- auto rnode = make_object<MetaRefExprNode>();
- rnode->type_key = type_key;
- rnode->node_index = node_index;
- data_ = std::move(rnode);
-}
-
/*! \brief A simple wrapper around a mapping from raw string names
* to a TVM variable, type variable or other binder type.
*/
class ScopeStack {
private:
std::vector<Scope<T>> scope_stack;
+ std::unordered_map<std::string, T> free_vars;
public:
/*! \brief Adds a variable binding to the current scope. */
this->scope_stack.back().name_map.insert({name, value});
}
+  /*! \brief Record a binding for a `free_var`-declared name. Free variables are
+   *  consulted by Lookup only after every lexical scope fails to resolve. */
+  void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); }
+
/*! \brief Looks up a variable name in the scope stack returning the matching variable
* in most recent scope. */
T Lookup(const std::string& name) {
return it->second;
}
}
+
+ // Check if we bound a free variable declaration.
+ auto it = free_vars.find(name);
+ if (it != free_vars.end()) {
+ return it->second;
+ }
+
return T();
}
void PopStack() { this->scope_stack.pop_back(); }
};
+/*! \brief Thrown by InternTable::Add when the same name is interned twice,
+ *  so callers can report a duplicate-definition diagnostic with a span. */
+struct DuplicateKeyError : public dmlc::Error {
+  explicit DuplicateKeyError(const std::string& msg) : dmlc::Error(msg) {}
+};
+
/*! \brief A table of interning strings as global function and type names. */
template <typename T>
struct InternTable {
/*! \brief The internal table mapping strings to a unique allocation. */
std::unordered_map<std::string, T> table;
+ DiagnosticContext* ctx;
/*! \brief Add the unique allocation. */
void Add(const std::string& name, const T& t) {
auto it = table.find(name);
if (it != table.end()) {
- LOG(FATAL) << "duplicate name";
+ throw DuplicateKeyError("duplicate key name in intern table");
} else {
table.insert({name, t});
}
SemVer version;
/*! \brief The diagnostic context used for error reporting. */
- DiagnosticContext diag_ctx;
+ DiagnosticContext* diag_ctx;
+
+ const SourceName& source_name;
/*! \brief The current position in the token stream. */
int pos;
/*! \brief The set of expression scopes used for lexical scope. */
ScopeStack<Var> expr_scopes;
- Parser(std::vector<Token> tokens, OperatorTable op_table, Source source)
- : diag_ctx(source), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true) {}
+ /*! \brief The metadata section. */
+ MetaTable meta_table;
+
+ Parser(DiagnosticContext* ctx, const SourceName& source_name, std::vector<Token> tokens,
+ OperatorTable op_table, Source source, MetaTable table)
+ : diag_ctx(ctx),
+ source_name(source_name),
+ pos(0),
+ tokens(tokens),
+ op_table(op_table),
+ ignore_whitespace(true),
+ meta_table(table) {}
/*! \brief Examine the next token in the stream, the current parser is configured to be
* whitespace insensitive so we will skip all whitespace or comment tokens. */
// For now we ignore all whitespace tokens and comments.
// We can tweak this behavior later to enable white space sensitivity in the parser.
while (pos < static_cast<int64_t>(tokens.size()) && ignore_whitespace &&
- (tokens.at(pos)->token_type == TokenType::Whitespace ||
- tokens.at(pos)->token_type == TokenType::Newline ||
- tokens.at(pos)->token_type == TokenType::LineComment ||
- tokens.at(pos)->token_type == TokenType::Comment)) {
+ (tokens.at(pos)->token_type == TokenType::kWhitespace ||
+ tokens.at(pos)->token_type == TokenType::kNewline ||
+ tokens.at(pos)->token_type == TokenType::kLineComment ||
+ tokens.at(pos)->token_type == TokenType::kComment)) {
pos++;
}
*/
void Consume(const TokenType& token_type) {
if (tokens[pos]->token_type != token_type) {
- std::string message =
- "expected a " + Pretty(token_type) + " found " + Pretty(Peek()->token_type);
- this->diag_ctx.Emit({tokens[pos]->line, tokens[pos]->column, message});
- this->diag_ctx.Render(std::cout);
+ this->diag_ctx->EmitFatal(Diagnostic::Error(tokens[pos]->span)
+ << "expected a " << Pretty(token_type) << " found "
+ << Pretty(Peek()->token_type));
}
pos++;
}
return var;
}
+  /*! \brief Bind a free variable (declared via `free_var`) in the expression scope.
+   *
+   * "x" -> Var("x"); unlike BindVar, the binding is recorded in the scope
+   * stack's free-variable map, which Lookup consults only after all lexical
+   * scopes fail to resolve the name.
+   */
+  Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) {
+    auto var = Var(name, type_annotation);
+    this->expr_scopes.AddFreeVar(name, var);
+    return var;
+  }
+
/*! \brief Bind a type variable in the type scope.
*
* "A" -> TypeVar("A", ...), these are needed to map from raw string names
Var LookupLocal(const Token& local) {
auto var = this->expr_scopes.Lookup(local.ToString());
if (!var.defined()) {
- diag_ctx.Emit(
- {local->line, local->column, "this local variable has not been previously declared"});
+ diag_ctx->Emit(Diagnostic::Error(local->span)
+ << "this local variable has not been previously declared");
}
return var;
}
TypeVar LookupTypeVar(const Token& ident) {
auto var = this->type_scopes.Lookup(ident.ToString());
if (!var.defined()) {
- diag_ctx.Emit(
- {ident->line, ident->column,
- "this type variable has not been previously declared anywhere, perhaps a typo?"});
+ diag_ctx->Emit(
+ Diagnostic::Error(ident->span)
+ << "this type variable has not been previously declared anywhere, perhaps a typo?");
}
return var;
}
/*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */
NDArray NumberToNDArray(const Token& token) {
- if (token->token_type == TokenType::Integer) {
+ if (token->token_type == TokenType::kInteger) {
DLContext ctx = {DLDeviceType::kDLCPU, 0};
auto dtype = String2DLDataType("int32");
auto data = NDArray::Empty({}, dtype, ctx);
int64_t value = Downcast<tvm::Integer>(token->data);
array[0] = (int32_t)value;
return data;
- } else if (token->token_type == TokenType::Float) {
+ } else if (token->token_type == TokenType::kFloat) {
DLContext ctx = {DLDeviceType::kDLCPU, 0};
auto dtype = String2DLDataType("float32");
auto data = NDArray::Empty({}, dtype, ctx);
/*! \brief Parse `(` parser() `)`. */
template <typename R>
R Parens(std::function<R()> parser) {
- return Bracket(TokenType::OpenParen, TokenType::CloseParen, parser);
+ return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser);
}
/*! \brief Parse `{` parser() `}`. */
template <typename R>
R Block(std::function<R()> parser) {
- return Bracket(TokenType::LCurly, TokenType::RCurly, parser);
+ return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser);
}
+  /*! \brief Resolve a meta-reference token (`meta[Type][index]`) against the
+   *  metadata table supplied to the parser.
+   *
+   * Emits an error diagnostic and returns an undefined ObjectRef when the type
+   * key is missing from the table or the index is out of bounds.
+   */
+  ObjectRef ParseMetaRef() {
+    auto meta_ref = Match(TokenType::kMetaReference);
+    Call ref = Downcast<Call>(meta_ref->data);
+    // NOTE(review): attrs is dereferenced without a null check — confirm the
+    // tokenizer always attaches MetaRefAttrs to kMetaReference tokens.
+    auto attrs = ref->attrs.as<MetaRefAttrs>();
+    auto type_key = attrs->node_type_key;
+    auto index = attrs->node_index;
+    auto it = this->meta_table.find(type_key);
+    if (it != this->meta_table.end()) {
+      auto nodes = (*it).second;
+      if (index < nodes.size()) {
+        return nodes[index];
+      } else {
+        this->diag_ctx->Emit(Diagnostic::Error(meta_ref->span)
+                             << "the node index `" << index << "` is out of bounds for `"
+                             << type_key << "`");
+        return ObjectRef();
+      }
+    } else {
+      this->diag_ctx->Emit(Diagnostic::Error(meta_ref->span)
+                           << "no entry in the meta table for `" << type_key << "`");
+      return ObjectRef();
+    }
+  }
/*! \brief Parses a sequence beginning with a start token, separated by a separator token, and
* ending with a stop token.
*
*/
template <typename T>
Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse,
- std::function<void()> before_stop = nullptr) {
+ std::function<bool()> before_stop = nullptr) {
+ DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << "sep=" << ToString(sep)
+ << "stop=" << ToString(stop);
Match(start);
+
+ // This is for the empty arguments list case, if we have <start> <leftovers> <stop> token stream
+ // we must parse leftovers, then match a stop token.
+ if (before_stop) {
+ auto did_parse = before_stop();
+ if (did_parse) {
+ Match(stop);
+ return {};
+ }
+ }
+
+ // This is the case in which we find an empty arguments lists and no leftovers.
if (WhenMatch(stop)) {
return Array<T>();
} else {
auto data = parse();
Array<T> elements = {data};
- // parse '(' expr ')'
- // if we are at the end invoke leftover parser
- if (Peek()->token_type == stop && before_stop) {
- before_stop();
- }
if (WhenMatch(stop)) {
return elements;
// parse '( expr ',' * ')'
} else if (WhenMatch(sep)) {
- // if we are at the end invoke leftover parser
- if (Peek()->token_type == stop && before_stop) {
- before_stop();
- }
while (true) {
if (WhenMatch(stop)) {
break;
} else {
+          // If before_stop is provided, let it try to parse any trailing
+          // section (e.g. an attribute list) before we require the stop token.
+ if (before_stop) {
+ auto did_parse = before_stop();
+ if (did_parse) {
+ Match(stop);
+ return elements;
+ }
+ }
auto data = parse();
WhenMatch(sep);
elements.push_back(data);
}
return elements;
} else {
- LOG(FATAL) << "issue";
+ auto next = Peek();
+ this->diag_ctx->EmitFatal(Diagnostic::Error(next->span)
+ << "expected a " << Pretty(stop) << " found "
+ << Pretty(next->token_type));
return Array<T>(nullptr);
}
}
auto defs = ParseDefinitions();
// Parse the metadata section at the end.
auto metadata = ParseMetadata();
- Match(TokenType::EndOfFile);
+
+ Match(TokenType::kEndOfFile);
Map<tvm::GlobalVar, BaseFunc> funcs;
Map<tvm::GlobalTypeVar, TypeData> types;
}
/*! \brief Parse the semantic versioning header. */
- SemVer ParseSemVer() {
- // TODO(@jroesch): convert semver to module level attribute.
- auto id = Peek();
- if (id->token_type == TokenType::Identifier && id.ToString() == "v0") {
- auto id = Match(TokenType::Identifier);
- Consume(TokenType::Period);
- Consume(TokenType::Float);
+ SemVer ParseSemVer(bool required = true) {
+ if (Peek()->token_type == TokenType::kVersion) {
+ auto version = Match(TokenType::kVersion);
+ // TODO(@jroesch): we currently only support 0.0.5.
+ if (version.ToString() != "\"0.0.5\"") {
+ this->diag_ctx->Emit(Diagnostic::Error(version->span)
+ << "invalid semantic version `" << version.ToString() << "`");
+ }
+ } else if (required) {
+ this->diag_ctx->Emit(Diagnostic::Error(Peek()->span)
+ << "expected text format semantic version, found a "
+ << PrettyPrint(Peek())
+ << "you can annotate it as #[version = \"0.0.5\"]");
}
- // TODO(@jroesch): the current lexing makes it hard to parse this
- // in a way that doesnt feel like a hack.
- //
- // We should move to module level attributes instead
- // so we can tag modules with top-level data.
- //
- // #[text_version = "0.0.4"]
- //
- // For now we only support current version.
- return SemVer(0, 0, 4);
+ return SemVer(0, 0, 5);
}
/*! \brief Parse zero or more Relay definitions. */
while (true) {
auto next = Peek();
switch (next->token_type) {
- case TokenType::Defn: {
- Consume(TokenType::Defn);
- auto global_name = Match(TokenType::Global).ToString();
+ case TokenType::kDefn: {
+ Consume(TokenType::kDefn);
+ auto global_tok = Match(TokenType::kGlobal);
+ auto global_name = global_tok.ToString();
auto global = GlobalVar(global_name);
- global_names.Add(global_name, global);
+ try {
+ global_names.Add(global_name, global);
+ } catch (DuplicateKeyError e) {
+ this->diag_ctx->Emit(Diagnostic::Error(global_tok->span) << "a function with the name "
+ << "`@" << global_name << "` "
+ << "was previously defined");
+ }
auto func = ParseFunctionDef();
defs.funcs.push_back(GlobalFunc(global, func));
continue;
}
- case TokenType::TypeDef: {
+ case TokenType::kTypeDef: {
defs.types.push_back(ParseTypeDef());
continue;
}
- case TokenType::Extern: {
- Consume(TokenType::Extern);
+ case TokenType::kExtern: {
+ Consume(TokenType::kExtern);
auto type_def = ParseTypeDef();
if (type_def->constructors.size()) {
- diag_ctx.Emit(
- {next->line, next->column, "an external type may not have any constructors"});
+ diag_ctx->Emit(Diagnostic::Error(next->span)
+ << "an external type may not have any constructors");
}
defs.types.push_back(type_def);
}
/*! \brief Parse zero or more Relay type definitions. */
TypeData ParseTypeDef() {
// Match the `type` keyword.
- Match(TokenType::TypeDef);
+ Match(TokenType::kTypeDef);
// Parse the type's identifier.
- auto type_id = Match(TokenType::Identifier).ToString();
+ auto type_tok = Match(TokenType::kIdentifier);
+ auto type_id = type_tok.ToString();
auto type_global = tvm::GlobalTypeVar(type_id, TypeKind::kAdtHandle);
- type_names.Add(type_id, type_global);
+
+ try {
+ type_names.Add(type_id, type_global);
+ } catch (DuplicateKeyError e) {
+ this->diag_ctx->Emit(Diagnostic::Error(type_tok->span) << "a type definition with the name "
+ << "`" << type_id << "` "
+ << "was previously defined");
+ }
Array<TypeVar> generics;
bool should_pop = false;
- if (Peek()->token_type == TokenType::LSquare) {
+ if (Peek()->token_type == TokenType::kLSquare) {
// If we have generics we need to add a type scope.
PushTypeScope();
should_pop = true;
- generics =
- ParseSequence<TypeVar>(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() {
- auto type_var_name = Match(TokenType::Identifier).ToString();
+ generics = ParseSequence<TypeVar>(
+ TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
+ auto type_var_name = Match(TokenType::kIdentifier).ToString();
return BindTypeVar(type_var_name, TypeKind::kType);
});
}
Array<tvm::Constructor> ctors;
- if (Peek()->token_type == TokenType::LCurly) {
+ if (Peek()->token_type == TokenType::kLCurly) {
// Parse the list of constructors.
ctors = ParseSequence<tvm::Constructor>(
- TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&]() {
+ TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() {
// First match the name of the constructor.
- auto ctor_name = Match(TokenType::Identifier).ToString();
+ auto ctor_tok = Match(TokenType::kIdentifier);
+ auto ctor_name = ctor_tok.ToString();
Constructor ctor;
// Match the optional field list.
- if (Peek()->token_type != TokenType::OpenParen) {
+ if (Peek()->token_type != TokenType::kOpenParen) {
ctor = tvm::Constructor(ctor_name, {}, type_global);
} else {
auto arg_types =
- ParseSequence<Type>(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen,
- [&]() { return ParseType(); });
+ ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma,
+ TokenType::kCloseParen, [&]() { return ParseType(); });
ctor = tvm::Constructor(ctor_name, arg_types, type_global);
}
CHECK(ctor.defined());
- this->ctors.Add(ctor_name, ctor);
+ try {
+ this->ctors.Add(ctor_name, ctor);
+ } catch (DuplicateKeyError e) {
+ this->diag_ctx->EmitFatal(Diagnostic::Error(ctor_tok->span)
+ << "a constructor with the name "
+ << "`" << ctor_name << "` "
+ << "was previously defined");
+ }
return ctor;
});
/*! \brief Parse a single Relay expression. */
Expr ParseExpr() {
+ DLOG(INFO) << "Parser::ParseExpr";
return ConsumeWhitespace<Expr>([this] {
std::vector<Expr> exprs;
while (true) {
+ DLOG(INFO) << "Parser::ParseExpr: parsing a single expression";
auto next = Peek();
switch (next->token_type) {
// For graph or let, match first rhs, then invoke ParseBindingExpr
// ParseBindingExpression then parse_lhs() parse_rhs() ';' continue
- case TokenType::LCurly: {
+ case TokenType::kLCurly: {
// NB: Might need to optimize to remove deep recursion.
// Stack should only grow proportionally to the number of
// nested scopes.
- return Bracket<Expr>(TokenType::LCurly, TokenType::RCurly, [&]() {
+ // Parses `{` expression `}`.
+ auto block = Bracket<Expr>(TokenType::kLCurly, TokenType::kRCurly, [&]() {
PushScope();
auto expr = ParseExpr();
PopScopes(1);
return expr;
});
+ exprs.push_back(block);
+ break;
}
- case TokenType::Let:
+ case TokenType::kFreeVar: {
+ Consume(TokenType::kFreeVar);
+ auto var_token = Match(TokenType::kLocal);
+
+ Type type;
+ if (WhenMatch(TokenType::kColon)) {
+ type = ParseType();
+ } else {
+ type = IncompleteType();
+ }
+
+ BindFreeVar(var_token.ToString(), type);
+ break;
+ }
+ // Parses `let ...`;
+ case TokenType::kLet:
exprs.push_back(ParseBindingExpr());
break;
- case TokenType::Match:
- case TokenType::PartialMatch: {
- bool is_total = next->token_type == TokenType::Match;
+ case TokenType::kMatch:
+ case TokenType::kPartialMatch: {
+ bool is_total = next->token_type == TokenType::kMatch;
Consume(next->token_type);
exprs.push_back(ParseMatch(is_total));
break;
}
- case TokenType::If: {
+ case TokenType::kIf: {
exprs.push_back(ParseIf());
break;
}
- case TokenType::Graph:
- if (Lookahead(2)->token_type == TokenType::Equal) {
+ // %x ...
+ case TokenType::kGraph:
+ if (Lookahead(2)->token_type == TokenType::kEqual) {
exprs.push_back(ParseBindingExpr());
break;
}
}
}
- if (!WhenMatch(TokenType::Semicolon)) {
+ if (!WhenMatch(TokenType::kSemicolon)) {
break;
}
}
// This ensures for n sequential bindings
// the call depth will be the same before
// and after parsing the n bindings.
+ DLOG(INFO) << "Parser::ParseBindingExpr";
std::vector<std::pair<Var, Expr>> bindings;
int scopes = 0;
while (true) {
auto next = Peek();
- if (next->token_type == TokenType::Graph && Lookahead(2)->token_type == TokenType::Equal) {
- Match(TokenType::Graph);
- Match(TokenType::Equal);
+ if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) {
+ Match(TokenType::kGraph);
+ Match(TokenType::kEqual);
auto val = this->ParseExprBinOp();
- Match(TokenType::Semicolon);
+ Match(TokenType::kSemicolon);
AddGraphBinding(next, val);
- } else if (next->token_type == TokenType::Let) {
+ } else if (next->token_type == TokenType::kLet) {
// Parse the 'let'.
- Consume(TokenType::Let);
+ Consume(TokenType::kLet);
// Parse the local '%<id>'.
- auto local_tok = Match(TokenType::Local);
+ auto local_tok = Match(TokenType::kLocal);
auto string = local_tok.ToString();
// Parse the optional type annotation (':' <type>).
Type type;
- if (WhenMatch(TokenType::Colon)) {
+ if (WhenMatch(TokenType::kColon)) {
type = ParseType();
}
auto var = BindVar(string, type);
// Parse the '=';
- Match(TokenType::Equal);
+ Match(TokenType::kEqual);
// Parse the body, and the ';'.
auto val = this->ParseExprBinOp();
- Consume(TokenType::Semicolon);
+ Consume(TokenType::kSemicolon);
// Add the bindings to the local data structure.
bindings.push_back({var, val});
/*! Parse a function definition without a leading keyword or identifier.
*
- * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN, UN) -> Ret { body }.
+ * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }.
*/
Function ParseFunctionDef() {
+ DLOG(INFO) << "Parser::ParseFunctionDef";
PushScope();
PushTypeScope();
Array<TypeVar> generics;
- if (Peek()->token_type == TokenType::LSquare) {
+ if (Peek()->token_type == TokenType::kLSquare) {
// If we have generics we need to add a type scope.
PushTypeScope();
- generics =
- ParseSequence<TypeVar>(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() {
- auto type_var_name = Match(TokenType::Identifier).ToString();
+ generics = ParseSequence<TypeVar>(
+ TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
+ auto type_var_name = Match(TokenType::kIdentifier).ToString();
return BindTypeVar(type_var_name, TypeKind::kType);
});
}
- auto params =
- ParseSequence<Var>(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, [&]() {
- auto token = Match(TokenType::Local);
+ Map<String, ObjectRef> raw_attrs;
+
+ auto params = ParseSequence<Var>(
+ TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
+ [&]() {
+ auto token = Match(TokenType::kLocal);
auto string = token.ToString();
Type type;
- if (WhenMatch(TokenType::Colon)) {
+ if (WhenMatch(TokenType::kColon)) {
type = ParseType();
}
return BindVar(string, type);
+ },
+ [&] {
+ auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
+ auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
+
+ if (is_ident && next_is_equal) {
+ raw_attrs = ParseAttrs();
+ return true;
+ }
+
+ return false;
});
Type ret_type;
- if (WhenMatch(TokenType::Minus)) {
- Match(TokenType::RAngle);
+ if (WhenMatch(TokenType::kMinus)) {
+ Match(TokenType::kRAngle);
ret_type = ParseType();
}
PopTypeScopes(1);
PopScopes(1);
- return relay::Function(params, body, ret_type, generics);
+ // TODO(@jroesch): attributes should never be null, they should always be empty.
+ if (raw_attrs.size()) {
+ return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
+ } else {
+ return relay::Function(params, body, ret_type, generics);
+ }
}
/*! \brief Parse an if-expression. */
Expr ParseIf() {
- Consume(TokenType::If);
+ DLOG(INFO) << "Parser::ParseIf";
+ Consume(TokenType::kIf);
auto guard = Parens<Expr>([&] { return ParseExpr(); });
- auto true_branch = Block<Expr>([&] { return ParseExpr(); });
+ auto true_branch = Block<Expr>([&] {
+ this->PushScope();
+ auto expr = ParseExpr();
+ this->PopScopes(1);
+ return expr;
+ });
- Match(TokenType::Else);
+ Match(TokenType::kElse);
- auto false_branch = Block<Expr>([&] { return ParseExpr(); });
+ auto false_branch = Block<Expr>([&] {
+ this->PushScope();
+ auto expr = ParseExpr();
+ this->PopScopes(1);
+ return expr;
+ });
return relay::If(guard, true_branch, false_branch);
}
/* This factors parsing a list of patterns for both tuples, and constructors. */
Array<Pattern> ParsePatternList() {
- return ParseSequence<Pattern>(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen,
+ return ParseSequence<Pattern>(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
[&] { return ParsePattern(); });
}
* This function recursively parses a pattern.
*/
Pattern ParsePattern() {
+ DLOG(INFO) << "Parser::ParsePattern";
auto next = Peek();
switch (next->token_type) {
- case TokenType::Underscore: {
- Match(TokenType::Underscore);
+ case TokenType::kUnderscore: {
+ Match(TokenType::kUnderscore);
return PatternWildcard();
}
- case TokenType::Local: {
- auto id = Match(TokenType::Local);
+ case TokenType::kLocal: {
+ auto id = Match(TokenType::kLocal);
Type type_annotation;
- if (WhenMatch(TokenType::Colon)) {
+ if (WhenMatch(TokenType::kColon)) {
type_annotation = ParseType();
}
auto var = BindVar(id.ToString(), type_annotation);
return PatternVar(var);
}
- case TokenType::Identifier: {
- auto id = Match(TokenType::Identifier);
+ case TokenType::kIdentifier: {
+ auto id = Match(TokenType::kIdentifier);
auto ctor = ctors.Get(id.ToString());
CHECK(ctor) << "undefined identifier";
- if (Peek()->token_type == TokenType::OpenParen) {
+ if (Peek()->token_type == TokenType::kOpenParen) {
auto fields = ParsePatternList();
return PatternConstructor(ctor.value(), fields);
} else {
Clause ParseMatchArm() {
PushScope();
auto pattern = ParsePattern();
- Match(TokenType::Equal);
- Consume(TokenType::RAngle);
+ Match(TokenType::kEqual);
+ Consume(TokenType::kRAngle);
auto expr = ParseExpr();
PopScopes(1);
return Clause(pattern, expr);
Expr scrutinee = ParseExpr();
Array<Clause> clauses = ParseSequence<Clause>(
- TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&] { return ParseMatchArm(); });
+ TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&] { return ParseMatchArm(); });
return relay::Match(scrutinee, clauses, is_total);
}
Expr ParseExprBinOp() {
+ DLOG(INFO) << "Parser::ParseExprBinOp";
return ConsumeWhitespace<Expr>([this] {
// We must parse at least one expression, the default
// case is that there is no operator and we will fall
});
}
- Attrs ParseAttrs(const std::string& type_key) {
+ ObjectRef ParseAttributeValue() {
+ DLOG(INFO) << "Parser::ParseAttributeValue";
+ auto next = Peek();
+ switch (next->token_type) {
+ case TokenType::kFloat:
+ case TokenType::kInteger:
+ case TokenType::kBoolean:
+ case TokenType::kStringLiteral:
+ return Match(next->token_type)->data;
+ case TokenType::kLSquare: {
+ return ParseSequence<ObjectRef>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
+ [&]() { return ParseAttributeValue(); });
+ }
+ case TokenType::kOpenParen: {
+ // TODO(@jroesch: need to figure out bracket vs. sequence)
+ // return ParseSequence<ObjectRef>(TokenType::kOpenParen, TokenType::kComma,
+ // TokenType::kCloseParen,
+ // [&]() { return ParseAttributeValue(); });
+ return Bracket<ObjectRef>(TokenType::kOpenParen, TokenType::kCloseParen,
+ [&]() { return ParseAttributeValue(); });
+ }
+ // TODO(@jroesch): not sure about this being the right way to handle nulls.
+ case TokenType::kIdentifier: {
+ if (auto text = next->data.as<tvm::StringObj>()) {
+ std::string id = GetRef<String>(text);
+ if (id == "nullptr") {
+ Match(TokenType::kIdentifier);
+ return ObjectRef();
+ }
+ }
+ }
+ default:
+ return ParseAtomicExpr();
+ }
+ }
+
+ Map<String, ObjectRef> ParseAttrs() {
+ DLOG(INFO) << "Parser::ParseAttrs";
Map<String, ObjectRef> kwargs;
- auto attrs = tvm::ReflectionVTable::Global()->CreateObject(type_key, kwargs);
- LOG(FATAL) << Attrs();
- return Attrs();
+ while (Peek()->token_type == TokenType::kIdentifier) {
+ auto key = Match(TokenType::kIdentifier).ToString();
+ Match(TokenType::kEqual);
+ // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side.
+ auto value = ParseAttributeValue();
+ // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text
+ // format is bad.
+ kwargs.Set(key, value);
+ WhenMatch(TokenType::kComma);
+ }
+ DLOG(INFO) << "Parser::ParseAttrs: kwargs=" << kwargs;
+ return kwargs;
}
Expr ParseCallArgs(Expr op) {
- Attrs call_attrs;
- if (Peek()->token_type == TokenType::OpenParen) {
- Array<Expr> args = ParseSequence<Expr>(
- TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen,
- [&] { return ParseExpr(); },
- [&] {
- auto is_ident = Lookahead(1)->token_type == TokenType::Identifier;
- auto next_is_equal = Lookahead(2)->token_type == TokenType::Equal;
-
- if (is_ident && next_is_equal) {
- if (auto op_node = op.as<OpNode>()) {
- call_attrs = ParseAttrs(op_node->attrs_type_key);
+ try {
+ DLOG(INFO) << "Parser::ParseCallArgs";
+ Map<String, ObjectRef> raw_attrs;
+ std::string op_key;
+ bool is_op = false;
+
+ if (auto op_node = op.as<OpNode>()) {
+ is_op = true;
+ op_key = op_node->attrs_type_key;
+ }
+
+ if (Peek()->token_type == TokenType::kOpenParen) {
+ Array<Expr> args = ParseSequence<Expr>(
+ TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
+ [&] { return ParseExpr(); },
+ [&] {
+ auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
+ auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
+
+ if (is_op && is_ident && next_is_equal) {
+ raw_attrs = ParseAttrs();
+ return true;
}
- }
- });
- return Expr(Call(op, args, call_attrs, {}));
- } else {
- return Expr();
+
+ return false;
+ });
+
+ Attrs attrs;
+
+ if (is_op && op_key.size()) {
+ auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
+ CHECK(attr_obj.defined());
+ attrs = Downcast<Attrs>(attr_obj);
+ }
+
+ return Expr(Call(op, args, attrs, {}));
+ } else {
+ return Expr();
+ }
+ } catch (...) {
+ // TODO(@jroesch): AttrErrors should have fields
+ this->diag_ctx->Emit(Diagnostic::Error(Peek()->span));
+ // << err.what());
}
+
+ return Expr();
}
Expr ParseCallExpr() {
+ DLOG(INFO) << "Parser::ParseCallExpr";
return ConsumeWhitespace<Expr>([this] {
Expr expr = ParseAtomicExpr();
// Parse as many call args as possible, building up expression
//
// NB(@jroesch): this seems like a hack but in order to parse curried functions
// and avoid complex grammar we will parse multiple call lists in a row.
- while (true) {
- auto new_expr = ParseCallArgs(expr);
- if (new_expr.defined()) {
- expr = new_expr;
- } else {
- break;
+ while (Peek()->token_type == TokenType::kOpenParen) {
+ try {
+ auto new_expr = ParseCallArgs(expr);
+
+ if (new_expr.defined()) {
+ expr = new_expr;
+ } else {
+ break;
+ }
+ } catch (...) {
+ // TODO(@jroesch): AttrErrors should have fields
+ this->diag_ctx->EmitFatal(Diagnostic::Error(Peek()->span));
+ // << err.what());
}
}
// We need a zero-arity case for constructors.
- if (expr.as<ConstructorNode>()) {
- return Expr(Call(expr, {}));
- } else {
- return expr;
+ if (auto ctor_node = expr.as<ConstructorNode>()) {
+ if (ctor_node->inputs.size() == 0) {
+ return Expr(Call(expr, {}));
+ }
}
+
+ return expr;
});
}
+ Expr GetOp(const std::string& op_name, const Token& tok) {
+ DLOG(INFO) << "op_name=" << op_name << " token=" << tok;
+ try {
+ return Op::Get(op_name);
+ } catch (dmlc::Error e) {
+ this->diag_ctx->Emit(Diagnostic::Error(tok->span)
+ << "operator `" << op_name
+ << "` not found, perhaps you forgot to register it?");
+ return Expr();
+ }
+ }
+
Expr ParseAtomicExpr() {
- return ConsumeWhitespace<Expr>([this] {
+ DLOG(INFO) << "Parser::ParseAtomicExpr";
+ auto expr = ConsumeWhitespace<Expr>([this] {
auto next = Peek();
switch (next->token_type) {
- case TokenType::Integer:
- case TokenType::Float: {
+ case TokenType::kInteger:
+ case TokenType::kFloat: {
Consume(next->token_type);
auto number = NumberToNDArray(next);
- Expr e = Constant(number);
+ Expr e = Constant(number, next->span);
return e;
}
- case TokenType::Boolean: {
- Consume(TokenType::Boolean);
+ case TokenType::kBoolean: {
+ Consume(TokenType::kBoolean);
int value = Downcast<tvm::Integer>(next->data);
auto boolean = BooleanToNDarray(value);
- Expr e = Constant(boolean);
+ Expr e = Constant(boolean, next->span);
return e;
}
- case TokenType::Local: {
- Consume(TokenType::Local);
+ // Parse a local of the form `%x`.
+ case TokenType::kLocal: {
+ Consume(TokenType::kLocal);
return Expr(LookupLocal(next));
}
- case TokenType::Global: {
+ // Parse a local of the form `@x`.
+ case TokenType::kGlobal: {
auto string = next.ToString();
- Consume(TokenType::Global);
+ Consume(TokenType::kGlobal);
auto global = global_names.Get(string);
if (!global) {
+ // TODO(@jroesch): fix global's needing span information
auto global_var = GlobalVar(string);
global_names.Add(string, global_var);
return Expr(global_var);
return Expr(global.value());
}
}
- case TokenType::Identifier: {
- auto string = next.ToString();
- Consume(TokenType::Identifier);
- auto ctor = ctors.Get(string);
+ // Parse a local of the form `x`.
+ // Right now we fail to parse `x.y`.
+ case TokenType::kIdentifier: {
+ auto ctor = ctors.Get(next.ToString());
if (ctor) {
+ Consume(TokenType::kIdentifier);
return Expr(ctor.value());
} else {
- return Expr(Op::Get(string));
+ auto idents = ParseHierarchicalName();
+ CHECK_NE(idents.size(), 0);
+ std::stringstream op_name;
+ int i = 0;
+ int periods = idents.size() - 1;
+ for (auto ident : idents) {
+ op_name << ident;
+ if (i < periods) {
+ op_name << ".";
+ i++;
+ }
+ }
+ return GetOp(op_name.str(), next);
}
}
- case TokenType::Graph: {
- Consume(TokenType::Graph);
+ case TokenType::kGraph: {
+ Consume(TokenType::kGraph);
return LookupGraphBinding(next);
}
- case TokenType::Fn: {
- Consume(TokenType::Fn);
+ case TokenType::kMetaReference: {
+ return Downcast<Expr>(ParseMetaRef());
+ }
+ case TokenType::kFn: {
+ Consume(TokenType::kFn);
return Expr(ParseFunctionDef());
}
- case TokenType::OpenParen: {
- Consume(TokenType::OpenParen);
+ case TokenType::kOpenParen: {
+ Consume(TokenType::kOpenParen);
// parse '(' ')'
- if (WhenMatch(TokenType::CloseParen)) {
+ if (WhenMatch(TokenType::kCloseParen)) {
return Expr(Tuple(Array<Expr>()));
} else {
auto expr = ParseExpr();
// parse '(' expr ')'
- if (WhenMatch(TokenType::CloseParen)) {
+ if (WhenMatch(TokenType::kCloseParen)) {
return expr;
// parse '( expr ',' * ')'
- } else if (WhenMatch(TokenType::Comma)) {
+ } else if (WhenMatch(TokenType::kComma)) {
Array<Expr> exprs = {expr};
while (true) {
- if (WhenMatch(TokenType::CloseParen)) {
+ if (WhenMatch(TokenType::kCloseParen)) {
break;
} else {
auto expr = ParseExpr();
- WhenMatch(TokenType::Comma);
+ WhenMatch(TokenType::kComma);
exprs.push_back(expr);
}
}
}
}
default: {
- std::stringstream msg;
- msg << "expected an expression found " << Pretty(next->token_type);
- diag_ctx.Emit({next->line, next->column, msg.str()});
- diag_ctx.Render(std::cout);
+ this->diag_ctx->EmitFatal(Diagnostic::Error(next->span)
+ << "expected an expression found "
+ << Pretty(next->token_type));
return Expr();
}
}
});
+
+ if (WhenMatch(TokenType::kPeriod)) {
+ auto index = Match(TokenType::kInteger).ToNumber();
+ expr = relay::TupleGetItem(expr, index);
+ }
+
+ return expr;
+ }
+
+ /*! \brief Parse a hierarchical name.
+ *
+ * The tokenizer produces a token stream of <id1> . <id2>
+ * and so on for names of the form `nn.conv2d`.
+ * Currently we only use string names everywhere instead
+ * of a notion of a hierarchical name.
+ *
+ * The below utility reassembles a token stream into a
+ * single stream inserting the required periods needed
+ * to look up registered names.
+ */
+ Array<String> ParseHierarchicalName() {
+ Array<String> idents;
+ while (Peek()->token_type == TokenType::kIdentifier) {
+ auto name = Peek().ToString();
+ idents.push_back(name);
+ Consume(TokenType::kIdentifier);
+
+ // Keep parsing while we see a trailing period.
+ if (Peek()->token_type == TokenType::kPeriod) {
+ Consume(TokenType::kPeriod);
+ continue;
+ } else {
+ // No more periods means we are done!
+ break;
+ }
+ }
+
+ return idents;
}
/*! \brief Parse a shape. */
Array<tvm::PrimExpr> ParseShape() {
- auto dims = ParseSequence<tvm::PrimExpr>(TokenType::OpenParen, TokenType::Comma,
- TokenType::CloseParen, [&]() {
- auto tok = Match(TokenType::Integer);
- return Downcast<tvm::PrimExpr>(tok->data);
- });
+ auto dims = ParseSequence<tvm::PrimExpr>(
+ TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() {
+ tvm::PrimExpr dim;
+ if (Peek()->token_type == TokenType::kMetaReference) {
+ dim = Downcast<tvm::PrimExpr>(ParseMetaRef());
+ } else {
+ dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data);
+ }
+
+ return dim;
+ });
return dims;
}
/*! \brief Parse a function type. */
Type ParseFunctionType() {
- auto ty_params = ParseSequence<Type>(TokenType::OpenParen, TokenType::Comma,
- TokenType::CloseParen, [&]() { return ParseType(); });
+ auto ty_params = ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma,
+ TokenType::kCloseParen, [&]() { return ParseType(); });
- Match(TokenType::Minus);
- Match(TokenType::RAngle);
+ Match(TokenType::kMinus);
+ Match(TokenType::kRAngle);
auto ret_type = ParseType();
return relay::FuncType(ty_params, ret_type, {}, {});
CHECK(head_type.defined()) << "internal error: head type must be defined";
Array<Type> arg_types;
- if (Peek()->token_type == TokenType::LSquare) {
- arg_types = ParseSequence<Type>(TokenType::LSquare, TokenType::Comma, TokenType::RSquare,
+ if (Peek()->token_type == TokenType::kLSquare) {
+ arg_types = ParseSequence<Type>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
[&]() { return ParseType(); });
}
Type ParseType() {
auto tok = Peek();
- if (tok->token_type == TokenType::OpenParen) {
- auto tys = ParseSequence<relay::Type>(TokenType::OpenParen, TokenType::Comma,
- TokenType::CloseParen, [&]() { return ParseType(); });
+ if (tok->token_type == TokenType::kOpenParen) {
+ auto tys = ParseSequence<relay::Type>(TokenType::kOpenParen, TokenType::kComma,
+ TokenType::kCloseParen, [&]() { return ParseType(); });
return relay::TupleType(tys);
- } else if (WhenMatch(TokenType::Fn)) {
+ } else if (WhenMatch(TokenType::kFn)) {
return ParseFunctionType();
- } else if (WhenMatch(TokenType::Identifier)) {
+ } else if (WhenMatch(TokenType::kIdentifier)) {
auto id = tok.ToString();
if (id == "Tensor") {
- Match(TokenType::LSquare);
+ Match(TokenType::kLSquare);
auto shape = ParseShape();
- Match(TokenType::Comma);
- auto dtype_tok = Match(TokenType::Identifier);
+ Match(TokenType::kComma);
+ auto dtype_tok = Match(TokenType::kIdentifier);
auto dtype = DataType(String2DLDataType(dtype_tok.ToString()));
- Match(TokenType::RSquare);
+ Match(TokenType::kRSquare);
return TensorType(shape, dtype);
} else {
auto ty = tok.ToString();
}
}
}
- if (WhenMatch(TokenType::Underscore)) {
+ if (WhenMatch(TokenType::kUnderscore)) {
return IncompleteType();
} else {
- std::stringstream msg;
- msg << "failed to parse type found ";
- msg << tok;
- diag_ctx.Emit({tok->line, tok->column, msg.str()});
- diag_ctx.Render(std::cout);
+ this->diag_ctx->EmitFatal(Diagnostic::Error(tok->span)
+ << "failed to parse type found " << tok);
return Type();
}
}
R ConsumeWhitespace(std::function<R()> func) {
auto old = this->ignore_whitespace;
this->ignore_whitespace = true;
- while (tokens[pos]->token_type == TokenType::Whitespace) {
+ while (tokens[pos]->token_type == TokenType::kWhitespace) {
pos++;
}
auto res = func();
return res;
}
- // TODO(@jroesch): this is the final remaining feature.
- ObjectRef ParseMetadata() { return ObjectRef(); }
+ Map<String, Array<ObjectRef>> ParseMetadata() {
+ if (Peek()->token_type == TokenType::kMetadata) {
+ return Match(TokenType::kMetadata).ToMetadata();
+ } else {
+ return Map<String, Array<ObjectRef>>();
+ }
+ }
/*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. */
void DisplayNextN(int n) {
};
IRModule ParseModule(std::string file_name, std::string file_content) {
- auto tokens = Tokenize(file_content);
- Parser parser(tokens, DefaultOpTable(), Source(file_content));
- return parser.ParseModule();
+ DLOG(INFO) << "ParseModule";
+ SourceName src_name = SourceName::Get(file_name);
+ Source src(src_name, file_content);
+ DiagnosticContext ctx(src);
+ auto tokens_and_table = Tokenize(&ctx, src_name, file_content);
+ auto tokens = tokens_and_table.first;
+ auto meta_data_table = tokens_and_table.second;
+ Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata());
+ auto mod = parser.ParseModule();
+ // NB(@jroesch): it is very important that we render any errors before we procede
+ // if there were any errors which allow the parser to procede we must render them
+ // here.
+ parser.diag_ctx->Render(std::cout);
+ return mod;
}
Expr ParseExpr(std::string file_name, std::string file_content) {
- auto tokens = Tokenize(file_content);
- Parser parser(tokens, DefaultOpTable(), Source(file_content));
+ DLOG(INFO) << "ParseExpr";
+ SourceName src_name = SourceName::Get(file_name);
+ Source src(src_name, file_content);
+ DiagnosticContext ctx(src);
+ auto tokens_and_table = Tokenize(&ctx, src_name, file_content);
+ auto tokens = tokens_and_table.first;
+ auto meta_data_table = tokens_and_table.second;
+ Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata());
+ parser.ParseSemVer(false);
parser.PushScope();
auto expr = parser.ParseExpr();
- parser.Match(TokenType::EndOfFile);
+ parser.Match(TokenType::kEndOfFile);
+ // NB(@jroesch): it is very important that we render any errors before we procede
+ // if there were any errors which allow the parser to procede we must render them
+ // here.
+ parser.diag_ctx->Render(std::cout);
return expr;
}
TVM_REGISTER_GLOBAL("parser.ParseModule")
- .set_body_typed([](std::string file_name, std::string file_content) {
+ .set_body_typed([](tvm::String file_name, tvm::String file_content) {
return ParseModule(file_name, file_content);
});
TVM_REGISTER_GLOBAL("parser.ParseExpr")
- .set_body_typed([](std::string file_name, std::string file_content) {
+ .set_body_typed([](tvm::String file_name, tvm::String file_content) {
return ParseExpr(file_name, file_content);
});
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file source_map.cc
+ * \brief The implementation of the source map data structure.
+ */
+#include <tvm/parser/source_map.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace parser {
+
+/*! \brief Construct a source from a string. */
+Source::Source(const SourceName& src_name, const std::string& source)
+ : source_name(src_name), source(source) {
+ int index = 0;
+ int length = 0;
+ line_map.push_back({index, length});
+ for (auto c : source) {
+ if (c == '\n') {
+ // Record the length of the line.
+ line_map.back().second = length;
+ // Bump past the newline.
+ index += 1;
+ // Record the start of the next line, and put placeholder for length.
+ line_map.push_back({index, 0});
+ // Reset length to zero.
+ length = 0;
+ } else {
+ length += 1;
+ index += 1;
+ }
+ }
+ line_map.back().second = length;
+}
+
+/*! \brief Generate an error message at a specific line and column with the
+ * annotated message.
+ *
+ * The error is written directly to the `out` std::ostream.
+ *
+ * \param out The output ostream.
+ * \param line The line at which to report a diagnostic.
+ * \param line The column at which to report a diagnostic.
+ * \param msg The message to attach.
+ */
+void Source::ReportAt(std::ostream& out, const Span& span, const std::string& msg) const {
+ DLOG(INFO) << "Source::ReportAt"
+ << "span = " << span << "msg = " << msg;
+ int line = span->line;
+ int column = span->column;
+
+ CHECK(line - 1 <= static_cast<int64_t>(line_map.size()))
+ << "requested line: " << (line - 1) << "line_map size: " << line_map.size()
+ << "source: " << source;
+
+ // Adjust for zero indexing, now have (line_start, line_length);
+ auto range = line_map.at(line - 1);
+ int line_start = range.first;
+ int line_length = range.second;
+ out << "file:" << line << ":" << column << ": parse error: " << msg << std::endl;
+ out << " " << source.substr(line_start, line_length) << std::endl;
+ out << " ";
+ std::stringstream marker;
+ for (int i = 1; i <= line_length; i++) {
+ if (i == column) {
+ marker << "^";
+ } else if ((column - i) < 3) {
+ marker << "~";
+ } else if ((i - column) < 3) {
+ marker << "~";
+ } else {
+ marker << " ";
+ }
+ }
+ out << marker.str();
+ out << std::endl;
+}
+
+// TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get);
+
+// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+// .set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
+// auto* node = static_cast<const SourceNameNode*>(ref.get());
+// p->stream << "SourceName(" << node->name << ", " << node << ")";
+// });
+
+TVM_REGISTER_NODE_TYPE(SourceMapNode);
+
+SourceMap::SourceMap(Map<SourceName, tvm::String> source_map) {
+ auto n = make_object<SourceMapNode>();
+ n->source_map = std::move(source_map);
+ data_ = std::move(n);
+}
+
+} // namespace parser
+} // namespace tvm
#ifndef TVM_PARSER_TOKEN_H_
#define TVM_PARSER_TOKEN_H_
+#include <tvm/ir/span.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
using namespace runtime;
-enum TokenType {
- CommentStart,
- CommentEnd,
- LineComment,
- Comment,
- Whitespace,
- Newline,
- StringLiteral,
- Identifier,
- Local,
- Global,
- Op,
- Graph,
- OpenParen,
- CloseParen,
- AtSymbol,
- Percent,
- Comma,
- Period,
- Equal,
- Semicolon,
- Colon,
- Integer,
- Float,
- Division,
- Boolean,
- Plus,
- Star,
- Minus,
- RAngle,
- LAngle,
- RCurly,
- LCurly,
- RSquare,
- LSquare,
- Bang,
- At,
- Question,
- If,
- Else,
- Underscore,
- Let,
- Fn,
- Defn,
- TypeDef,
- Extern,
- Match,
- PartialMatch,
- Unknown,
- EndOfFile,
- Null,
+enum class TokenType {
+ kCommentStart,
+ kCommentEnd,
+ kLineComment,
+ kComment,
+ kWhitespace,
+ kNewline,
+ kStringLiteral,
+ kIdentifier,
+ kLocal,
+ kGlobal,
+ kOp,
+ kGraph,
+ kOpenParen,
+ kCloseParen,
+ kAtSymbol,
+ kPercent,
+ kComma,
+ kPeriod,
+ kEqual,
+ kSemicolon,
+ kColon,
+ kInteger,
+ kFloat,
+ kDivision,
+ kBoolean,
+ kPlus,
+ kStar,
+ kMinus,
+ kRAngle,
+ kLAngle,
+ kRCurly,
+ kLCurly,
+ kRSquare,
+ kLSquare,
+ kBang,
+ kAt,
+ kQuestion,
+ kIf,
+ kElse,
+ kUnderscore,
+ kLet,
+ kFn,
+ kDefn,
+ kTypeDef,
+ kExtern,
+ kMatch,
+ kPartialMatch,
+ kMetadata,
+ kMetaReference,
+ kFreeVar,
+ kVersion,
+ kUnknown,
+ kEndOfFile,
+ kNull,
};
std::string ToString(const TokenType& token_type) {
switch (token_type) {
- case TokenType::CommentStart:
+ case TokenType::kCommentStart:
return "CommentStart";
- case TokenType::CommentEnd:
+ case TokenType::kCommentEnd:
return "CommentEnd";
- case TokenType::LineComment:
+ case TokenType::kLineComment:
return "LineComment";
- case TokenType::Comment:
+ case TokenType::kComment:
return "Comment";
- case TokenType::Whitespace:
+ case TokenType::kWhitespace:
return "WhiteSpace";
- case TokenType::Newline:
+ case TokenType::kNewline:
return "Newline";
- case TokenType::StringLiteral:
+ case TokenType::kStringLiteral:
return "StringLiteral";
- case TokenType::Identifier:
+ case TokenType::kIdentifier:
return "Identifier";
- case TokenType::Local:
+ case TokenType::kLocal:
return "Local";
- case TokenType::Global:
+ case TokenType::kGlobal:
return "Global";
- case TokenType::Graph:
+ case TokenType::kGraph:
return "Graph";
- case TokenType::Op:
+ case TokenType::kOp:
return "Op";
- case TokenType::OpenParen:
+ case TokenType::kOpenParen:
return "OpenParen";
- case TokenType::CloseParen:
+ case TokenType::kCloseParen:
return "CloseParen";
- case TokenType::AtSymbol:
+ case TokenType::kAtSymbol:
return "AtSymbol";
- case TokenType::Percent:
+ case TokenType::kPercent:
return "Percent";
- case TokenType::Comma:
+ case TokenType::kComma:
return "Comma";
- case TokenType::Colon:
+ case TokenType::kColon:
return "Colon";
- case TokenType::Semicolon:
+ case TokenType::kSemicolon:
return "Semicolon";
- case TokenType::Period:
+ case TokenType::kPeriod:
return "Period";
- case TokenType::Equal:
+ case TokenType::kEqual:
return "Equal";
- case TokenType::Integer:
+ case TokenType::kInteger:
return "Integer";
- case TokenType::Float:
+ case TokenType::kFloat:
return "Float";
- case TokenType::Plus:
+ case TokenType::kPlus:
return "Plus";
- case TokenType::Star:
+ case TokenType::kStar:
return "Star";
- case TokenType::Minus:
+ case TokenType::kMinus:
return "Minus";
- case TokenType::Division:
+ case TokenType::kDivision:
return "Division";
- case TokenType::RAngle:
+ case TokenType::kRAngle:
return "RAngle";
- case TokenType::LAngle:
+ case TokenType::kLAngle:
return "LAngle";
- case TokenType::RCurly:
+ case TokenType::kRCurly:
return "RCurly";
- case TokenType::LCurly:
+ case TokenType::kLCurly:
return "LCurly";
- case TokenType::RSquare:
+ case TokenType::kRSquare:
return "RSquare";
- case TokenType::LSquare:
+ case TokenType::kLSquare:
return "LSquare";
- case TokenType::Bang:
+ case TokenType::kBang:
return "Bang";
- case TokenType::Underscore:
+ case TokenType::kUnderscore:
return "Underscore";
- case TokenType::At:
+ case TokenType::kAt:
return "At";
- case TokenType::Let:
+ case TokenType::kLet:
return "Let";
- case TokenType::If:
+ case TokenType::kIf:
return "If";
- case TokenType::Else:
+ case TokenType::kElse:
return "Else";
- case TokenType::Fn:
+ case TokenType::kFn:
return "Fn";
- case TokenType::Defn:
+ case TokenType::kDefn:
return "Defn";
- case TokenType::TypeDef:
+ case TokenType::kTypeDef:
return "TypeDef";
- case TokenType::Extern:
+ case TokenType::kExtern:
return "Extern";
- case TokenType::Match:
+ case TokenType::kMatch:
return "Match";
- case TokenType::PartialMatch:
+ case TokenType::kPartialMatch:
return "PartialMatch";
- case TokenType::Question:
+ case TokenType::kQuestion:
return "Question";
- case TokenType::Boolean:
+ case TokenType::kBoolean:
return "Boolean";
- case TokenType::Unknown:
+ case TokenType::kMetadata:
+ return "Metadata";
+ case TokenType::kMetaReference:
+ return "MetaReference";
+ case TokenType::kFreeVar:
+ return "FreeVar";
+ case TokenType::kVersion:
+ return "Version";
+ case TokenType::kUnknown:
return "Unknown";
- case TokenType::EndOfFile:
+ case TokenType::kEndOfFile:
return "EndOfFile";
- case TokenType::Null:
+ case TokenType::kNull:
return "Null";
// Older compilers warn even though the above code is exhaustive.
default:
std::string Pretty(const TokenType& token_type) {
switch (token_type) {
- case TokenType::CommentStart:
+ case TokenType::kCommentStart:
return "`/*`";
- case TokenType::CommentEnd:
+ case TokenType::kCommentEnd:
return "`*/`";
- case TokenType::LineComment:
+ case TokenType::kLineComment:
return "`//`";
- case TokenType::Comment:
+ case TokenType::kComment:
return "comment";
- case TokenType::Whitespace:
+ case TokenType::kWhitespace:
return "whitespace";
- case TokenType::Newline:
+ case TokenType::kNewline:
return "newline";
- case TokenType::StringLiteral:
+ case TokenType::kStringLiteral:
return "string literal";
- case TokenType::Identifier:
+ case TokenType::kIdentifier:
return "identifier";
- case TokenType::Local:
+ case TokenType::kLocal:
return "local variable";
- case TokenType::Global:
+ case TokenType::kGlobal:
return "global variable";
- case TokenType::Graph:
+ case TokenType::kGraph:
return "graph variable";
- case TokenType::Op:
+ case TokenType::kOp:
return "operator";
- case TokenType::OpenParen:
+ case TokenType::kOpenParen:
return "`(`";
- case TokenType::CloseParen:
+ case TokenType::kCloseParen:
return "`)`";
- case TokenType::AtSymbol:
+ case TokenType::kAtSymbol:
return "`@`";
- case TokenType::Percent:
+ case TokenType::kPercent:
return "`%`";
- case TokenType::Comma:
+ case TokenType::kComma:
return "`,`";
- case TokenType::Colon:
+ case TokenType::kColon:
return "`:`";
- case TokenType::Semicolon:
+ case TokenType::kSemicolon:
return "`;`";
- case TokenType::Period:
+ case TokenType::kPeriod:
return "`.`";
- case TokenType::Equal:
+ case TokenType::kEqual:
return "`=`";
- case TokenType::Integer:
+ case TokenType::kInteger:
return "integer";
- case TokenType::Float:
+ case TokenType::kFloat:
return "float";
- case TokenType::Plus:
+ case TokenType::kPlus:
return "`+`";
- case TokenType::Star:
+ case TokenType::kStar:
return "`*`";
- case TokenType::Minus:
+ case TokenType::kMinus:
return "`-`";
- case TokenType::Division:
+ case TokenType::kDivision:
return "`/`";
- case TokenType::RAngle:
+ case TokenType::kRAngle:
return "`<`";
- case TokenType::LAngle:
+ case TokenType::kLAngle:
return "`>`";
- case TokenType::RCurly:
+ case TokenType::kRCurly:
return "`}`";
- case TokenType::LCurly:
+ case TokenType::kLCurly:
return "`{`";
- case TokenType::RSquare:
+ case TokenType::kRSquare:
return "`]`";
- case TokenType::LSquare:
+ case TokenType::kLSquare:
return "`[`";
- case TokenType::Bang:
+ case TokenType::kBang:
return "`!`";
- case TokenType::Underscore:
+ case TokenType::kUnderscore:
return "`_`";
- case TokenType::At:
+ case TokenType::kAt:
return "`@`";
- case TokenType::Let:
+ case TokenType::kLet:
return "`let`";
- case TokenType::If:
+ case TokenType::kIf:
return "`if`";
- case TokenType::Else:
+ case TokenType::kElse:
return "`else`";
- case TokenType::Fn:
+ case TokenType::kFn:
return "`fn`";
- case TokenType::Defn:
+ case TokenType::kDefn:
return "`def`";
- case TokenType::TypeDef:
+ case TokenType::kTypeDef:
return "`type`";
- case TokenType::Extern:
+ case TokenType::kExtern:
return "`extern`";
- case TokenType::Boolean:
+ case TokenType::kBoolean:
return "boolean";
- case TokenType::Match:
+ case TokenType::kMetadata:
+ return "metadata section";
+ case TokenType::kMetaReference:
+ return "`meta`";
+ case TokenType::kFreeVar:
+ return "`free_var`";
+ case TokenType::kMatch:
return "`match`";
- case TokenType::PartialMatch:
+ case TokenType::kPartialMatch:
return "`match?`";
- case TokenType::Question:
+ case TokenType::kQuestion:
return "`?`";
- case TokenType::Unknown:
+ case TokenType::kUnknown:
return "unknown";
- case TokenType::EndOfFile:
+ case TokenType::kEndOfFile:
return "end of file";
- case TokenType::Null:
+ case TokenType::kNull:
return "null";
+ case TokenType::kVersion:
+ return "version attribute";
// Older compilers warn even though the above code is exhaustive.
default:
LOG(FATAL) << "unreachable code";
class TokenNode : public Object {
public:
- int line;
- int column;
+ Span span;
TokenType token_type;
mutable runtime::ObjectRef data;
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TokenNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TokenNode*>(ref.get());
- p->stream << "Token(line=" << node->line << ", column=" << node->column
- << ", token_type=" << ToString(node->token_type) << ", data=" << node->data << ")";
+ p->stream << "Token(span=" << node->span << ", token_type=" << ToString(node->token_type)
+ << ", data=" << node->data << ")";
});
TVM_REGISTER_NODE_TYPE(TokenNode);
class Token : public ObjectRef {
public:
- TVM_DLL explicit Token(int line, int column, TokenType token_type, ObjectRef data = ObjectRef());
+ TVM_DLL explicit Token(Span span, TokenType token_type, ObjectRef data = ObjectRef());
static Token Null();
int64_t ToNumber() const;
std::string ToString() const;
+ Map<String, Array<ObjectRef>> ToMetadata() const;
TVM_DEFINE_OBJECT_REF_METHODS(Token, ObjectRef, TokenNode);
};
-Token::Token(int line, int column, TokenType token_type, ObjectRef data) {
+Token::Token(Span span, TokenType token_type, ObjectRef data) {
ObjectPtr<TokenNode> n = make_object<TokenNode>();
- n->line = line;
- n->column = column;
+ n->span = span;
n->token_type = token_type;
n->data = data;
data_ = std::move(n);
}
-Token Token::Null() { return Token(0, 0, TokenType::Null); }
+Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); }
int64_t Token::ToNumber() const { return Downcast<tvm::Integer>(this->operator->()->data); }
std::string Token::ToString() const { return Downcast<tvm::String>(this->operator->()->data); }
+Map<String, Array<ObjectRef>> Token::ToMetadata() const {
+ ObjectRef data = this->operator->()->data;
+ if (data.defined()) {
+ return Downcast<Map<String, Array<ObjectRef>>>(data);
+ } else {
+ return Map<String, Array<ObjectRef>>({});
+ }
+}
+
} // namespace parser
} // namespace tvm
#endif // TVM_PARSER_TOKEN_H_
#ifndef TVM_PARSER_TOKENIZER_H_
#define TVM_PARSER_TOKENIZER_H_
+#include <tvm/node/serialization.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <fstream>
#include <string>
#include <unordered_map>
+#include <utility>
#include <vector>
+#include "./meta_ref.h"
#include "./token.h"
namespace tvm {
using namespace runtime;
+// trim from start (in place)
+static inline void ltrim(std::string& s) { // NOLINT(*)
+ s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); }));
+}
+
+// trim from end (in place)
+static inline void rtrim(std::string& s) { // NOLINT(*)
+ s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(),
+ s.end());
+}
+
bool IsDigit(char c) { return '0' <= c && c <= '9'; }
bool IsWhitespace(char c) { return ' ' == c || c == '\t' || c == '\n'; }
bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); }
static std::unordered_map<std::string, TokenType> KEYWORD_TABLE = {
- {"let", TokenType::Let}, {"fn", TokenType::Fn}, {"def", TokenType::Defn},
- {"if", TokenType::If}, {"else", TokenType::Else}, {"type", TokenType::TypeDef},
- {"match", TokenType::Match}, {"extern", TokenType::Extern}};
+ {"let", TokenType::kLet}, {"fn", TokenType::kFn},
+ {"def", TokenType::kDefn}, {"if", TokenType::kIf},
+ {"else", TokenType::kElse}, {"type", TokenType::kTypeDef},
+ {"match", TokenType::kMatch}, {"extern", TokenType::kExtern},
+ {"free_var", TokenType::kFreeVar}};
struct Tokenizer {
+ DiagnosticContext* diag_ctx;
+ const SourceName& source_name;
+
size_t pos;
int col;
int line;
return this->source.at(this->pos);
}
- Token NewToken(TokenType token_type, ObjectRef data = ObjectRef()) {
- return Token(this->line, this->col, token_type, data);
+ Token NewToken(TokenType token_type, ObjectRef data = ObjectRef(), int lines = 0, int cols = 1) {
+ auto span =
+ Span(this->source_name, this->line, this->line + lines, this->col, this->col + cols);
+ return Token(span, token_type, data);
+ }
+
+ Span SpanFrom(int line, int column) {
+ int end_line = this->line;
+ int end_column = this->col;
+ return Span(this->source_name, line, end_line, column, end_column);
}
enum CommentParserState {
CommentParserState state = CommentParserState::Proceed;
int nesting = 1;
- while (true) {
+ while (More()) {
switch (state) {
case CommentParserState::Proceed: {
if (Peek() == '/') {
Next();
buffer->pop_back();
return;
- } else {
- buffer->operator+=(Next());
- state = CommentParserState::Proceed;
}
}
+
+ buffer->operator+=(Next());
+ state = CommentParserState::Proceed;
continue;
}
}
if (is_float) {
throw std::invalid_argument("is_float");
}
- auto token = NewToken(TokenType::Integer);
+ auto token = NewToken(TokenType::kInteger);
size_t index = 0;
int value = std::stoi(number, &index);
if (number.size() > index) {
token->data = tvm::Integer(value);
return token;
} catch (const std::invalid_argument& ia) {
- auto token = NewToken(TokenType::Float);
+ auto token = NewToken(TokenType::kFloat);
if (number.back() == 'f') {
number.pop_back();
}
}
+ bool MatchString(const std::string& string) {
+ int start = this->pos;
+
+ for (auto c : string) {
+ if (Peek() != c) {
+ this->pos = start;
+ return false;
+ } else {
+ Next();
+ }
+ }
+
+ return true;
+ }
+
+ Token TokenizeMetaRef() {
+ int line = this->line;
+ int column = this->col;
+
+ CHECK_EQ(Peek(), '[');
+ Next();
+ std::stringstream type_key;
+ while (More() && Peek() != ']') {
+ type_key << Next();
+ }
+ CHECK_EQ(Peek(), ']');
+ Next();
+
+ CHECK_EQ(Peek(), '[');
+ Next();
+ std::stringstream str_index;
+ while (More() && Peek() != ']') {
+ str_index << Next();
+ }
+ CHECK_EQ(Peek(), ']');
+ Next();
+    // TODO: add error handling around bad indices
+ auto index = ParseNumber(true, false, str_index.str()).ToNumber();
+ auto span = SpanFrom(line, column);
+ return Token(span, TokenType::kMetaReference, MetaRef(type_key.str(), index));
+ }
+
+ Token TokenizeAttr() {
+ int line = this->line;
+ int column = this->col;
+ Next();
+ if (Peek() == '[') {
+ Next();
+ std::stringstream raw_attribute;
+
+ while (More() && Peek() != ']') {
+ raw_attribute << Next();
+ }
+
+ CHECK_EQ(Next(), ']');
+
+ auto attribute = raw_attribute.str();
+ // Clean up the white-space on both sides.
+ ltrim(attribute);
+ rtrim(attribute);
+
+ // Metadata can only appear at the bottom of a file and goes to EOF.
+ if (attribute == "metadata") {
+ std::stringstream metadata;
+ while (More()) {
+ metadata << Next();
+ }
+ ObjectRef metadata_map = tvm::LoadJSON(metadata.str());
+ auto span = SpanFrom(line, column);
+ return Token(span, TokenType::kMetadata, metadata_map);
+ }
+ if (attribute.rfind("version", 0) == 0) {
+ std::string version = attribute.substr(attribute.find("=") + 1);
+ ltrim(version);
+ rtrim(version);
+ auto span = SpanFrom(line, column);
+ return Token(span, TokenType::kVersion, tvm::String(version));
+ } else {
+      // TODO(@jroesch): maybe make this a warning and continue parsing?
+ auto span = SpanFrom(line, column);
+ this->diag_ctx->EmitFatal(Diagnostic::Error(span) << "unsupported attribute " << attribute);
+ return Token();
+ }
+ } else {
+ auto span = SpanFrom(line, column);
+ this->diag_ctx
+ ->EmitFatal(Diagnostic::Error(span)
+ << "`#` denotes the start of an attribute can only be followed by `[`"
+ << " found `" << Peek() << "`");
+ return Token();
+ }
+ }
+
inline Token TokenizeOnce() {
+ int line = this->line;
+ int col = this->col;
auto next = Peek();
+ DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next;
if (next == '\n') {
- auto token = NewToken(TokenType::Newline);
+ auto token = NewToken(TokenType::kNewline);
Next();
return token;
} else if (next == '\r') {
Next();
if (More() && Peek() == '\n') {
- auto token = NewToken(TokenType::Newline);
+ auto token = NewToken(TokenType::kNewline);
return token;
} else {
- // TODO(@jroesch): have lexer use diagnostic context too.
- LOG(FATAL) << "lexer error";
+ auto span = SpanFrom(line, col);
+ this->diag_ctx->EmitFatal(
+ Diagnostic::Error(span)
+ << "\\r carriage returns must be followed by a \\n in the TVM text format");
return Token();
}
} else if (next == '"') {
- LOG(FATAL) << "string not working yet";
- return NewToken(TokenType::Unknown);
+ // TODO(@jroesch): Properly tokenize escape sequences in strings.
+ // see https://github.com/apache/incubator-tvm/issues/6153.
+ Next();
+ std::stringstream string_content;
+ while (More() && Peek() != '"') {
+ string_content << Next();
+ }
+ Next();
+ return NewToken(TokenType::kStringLiteral, tvm::String(string_content.str()));
} else if (IsWhitespace(next)) {
- auto token = NewToken(TokenType::Whitespace);
+ auto token = NewToken(TokenType::kWhitespace);
Next();
return token;
} else if (IsDigit(next) || next == '-') {
// with multi-token return or something.
if (negs && !IsDigit(Peek())) {
pos = pos - (negs - 1);
- return NewToken(TokenType::Minus);
+ return NewToken(TokenType::kMinus);
}
bool is_neg = negs % 2 == 1;
return ParseNumber(!is_neg, is_float, ss.str());
} else if (next == '.') {
- auto token = NewToken(TokenType::Period);
+ auto token = NewToken(TokenType::kPeriod);
Next();
return token;
} else if (next == ',') {
- auto token = NewToken(TokenType::Comma);
+ auto token = NewToken(TokenType::kComma);
Next();
return token;
} else if (next == '=') {
- auto token = NewToken(TokenType::Equal);
+ auto token = NewToken(TokenType::kEqual);
Next();
return token;
} else if (next == ';') {
- auto token = NewToken(TokenType::Semicolon);
+ auto token = NewToken(TokenType::kSemicolon);
Next();
return token;
} else if (next == ':') {
- auto token = NewToken(TokenType::Colon);
+ auto token = NewToken(TokenType::kColon);
Next();
return token;
} else if (next == '(') {
- auto token = NewToken(TokenType::OpenParen);
+ auto token = NewToken(TokenType::kOpenParen);
Next();
return token;
} else if (next == ')') {
- auto token = NewToken(TokenType::CloseParen);
+ auto token = NewToken(TokenType::kCloseParen);
Next();
return token;
} else if (next == '+') {
- auto token = NewToken(TokenType::Plus);
+ auto token = NewToken(TokenType::kPlus);
Next();
return token;
} else if (next == '-') {
- auto token = NewToken(TokenType::Minus);
+ auto token = NewToken(TokenType::kMinus);
Next();
return token;
} else if (next == '*') {
- auto token = NewToken(TokenType::Star);
+ auto token = NewToken(TokenType::kStar);
Next();
return token;
} else if (next == '<') {
- auto token = NewToken(TokenType::LAngle);
+ auto token = NewToken(TokenType::kLAngle);
Next();
return token;
} else if (next == '>') {
- auto token = NewToken(TokenType::RAngle);
+ auto token = NewToken(TokenType::kRAngle);
Next();
return token;
} else if (next == '{') {
- auto token = NewToken(TokenType::LCurly);
+ auto token = NewToken(TokenType::kLCurly);
Next();
return token;
} else if (next == '}') {
- auto token = NewToken(TokenType::RCurly);
+ auto token = NewToken(TokenType::kRCurly);
Next();
return token;
} else if (next == '[') {
- auto token = NewToken(TokenType::LSquare);
+ auto token = NewToken(TokenType::kLSquare);
Next();
return token;
} else if (next == ']') {
- auto token = NewToken(TokenType::RSquare);
+ auto token = NewToken(TokenType::kRSquare);
Next();
return token;
} else if (next == '!') {
- auto token = NewToken(TokenType::Bang);
+ auto token = NewToken(TokenType::kBang);
Next();
return token;
} else if (next == '@') {
- auto token = NewToken(TokenType::At);
+ auto token = NewToken(TokenType::kAt);
Next();
return token;
} else if (next == '?') {
- auto token = NewToken(TokenType::Question);
+ auto token = NewToken(TokenType::kQuestion);
Next();
return token;
+ } else if (MatchString("meta")) {
+ return TokenizeMetaRef();
+ } else if (next == '#') {
+ return TokenizeAttr();
} else if (next == '%') {
- auto token = NewToken(TokenType::Percent);
+ auto token = NewToken(TokenType::kPercent);
Next();
+
+ std::stringstream number;
+ while (More() && IsDigit(Peek())) {
+ number << Next();
+ }
+
+ auto number_str = number.str();
+ if (number_str.size()) {
+ auto num_tok = ParseNumber(true, false, number_str);
+ auto span = SpanFrom(token->span->line, token->span->column);
+ token = Token(span, TokenType::kGraph, num_tok->data);
+ }
+
return token;
} else if (next == '/') {
Next();
if (Peek() == '/') {
- auto token = NewToken(TokenType::LineComment);
+ auto token = NewToken(TokenType::kLineComment);
// Consume the /
Next();
std::stringstream comment;
Next();
std::string comment;
MatchComment(&comment);
- auto token = NewToken(TokenType::Comment, tvm::String(comment));
+ auto token = NewToken(TokenType::kComment, tvm::String(comment));
return token;
} else {
- return NewToken(TokenType::Division);
+ return NewToken(TokenType::kDivision);
}
} else if (IsIdentLetter(next)) {
std::stringstream ss;
if (it != KEYWORD_TABLE.end()) {
token_type = it->second;
- if (token_type == TokenType::Match) {
+ if (token_type == TokenType::kMatch) {
if (More() && Peek() == '?') {
Next();
- token_type = TokenType::PartialMatch;
+ token_type = TokenType::kPartialMatch;
}
}
} else {
- token_type = TokenType::Identifier;
+ token_type = TokenType::kIdentifier;
}
- return Token(line, col, token_type, tvm::String(ss.str()));
+ auto span = SpanFrom(line, col);
+ return Token(span, token_type, tvm::String(ss.str()));
} else {
std::stringstream ss;
while (More() && !IsWhitespace(Peek())) {
ss << Next();
}
- auto token = NewToken(TokenType::Unknown);
+ auto token = NewToken(TokenType::kUnknown);
token->data = tvm::String(ss.str());
return token;
}
}
void Tokenize() {
+ DLOG(INFO) << "tvm::parser::Tokenize";
while (this->More()) {
auto token = TokenizeOnce();
CHECK(token.defined());
this->tokens.push_back(token);
}
- this->tokens.push_back(NewToken(TokenType::EndOfFile));
+ this->tokens.push_back(NewToken(TokenType::kEndOfFile));
}
- explicit Tokenizer(std::string& source) : pos(0), col(1), line(1), source(source), tokens() {}
+ explicit Tokenizer(DiagnosticContext* ctx, const SourceName& source_name,
+ const std::string& source)
+ : diag_ctx(ctx),
+ source_name(source_name),
+ pos(0),
+ col(1),
+ line(1),
+ source(source),
+ tokens() {}
};
-std::vector<Token> Condense(const std::vector<Token>& tokens) {
+std::vector<Token> Condense(const std::vector<Token>& tokens, Token* table) {
std::vector<Token> out;
+ bool found_metadata = false;
for (size_t i = 0; i < tokens.size(); i++) {
auto current = tokens.at(i);
switch (current->token_type) {
- case TokenType::Percent: {
+ case TokenType::kMetadata: {
+ if (!found_metadata) {
+ found_metadata = true;
+ *table = current;
+ } else {
+ LOG(FATAL) << "duplicate metadata section";
+ }
+ continue;
+ }
+ case TokenType::kPercent: {
auto next = tokens.at(i + 1);
- if (next->token_type == TokenType::Identifier) {
+ if (next->token_type == TokenType::kIdentifier) {
// Match this token.
i += 1;
- auto tok = Token(current->line, current->column, TokenType::Local, next->data);
+ // TODO(@jroesch): merge spans
+ auto tok = Token(current->span, TokenType::kLocal, next->data);
CHECK(tok.defined());
out.push_back(tok);
- } else if (next->token_type == TokenType::Integer) {
+ } else if (next->token_type == TokenType::kInteger) {
i += 1;
- auto tok = Token(current->line, current->column, TokenType::Graph, next->data);
+ auto tok = Token(current->span, TokenType::kGraph, next->data);
CHECK(tok.defined());
out.push_back(tok);
} else {
}
continue;
}
- case TokenType::At: {
+ case TokenType::kAt: {
auto next = tokens.at(i + 1);
- if (next->token_type == TokenType::Identifier) {
+ if (next->token_type == TokenType::kIdentifier) {
// Match this token.
i += 1;
- auto tok = Token(current->line, current->column, TokenType::Global, next->data);
+ // TODO(@jroesch): merge spans
+ auto tok = Token(current->span, TokenType::kGlobal, next->data);
CHECK(tok.defined());
out.push_back(tok);
} else {
}
continue;
}
- case TokenType::Identifier: {
+ case TokenType::kIdentifier: {
std::string str = Downcast<tvm::String>(current->data);
Token tok;
+ // TODO(@jroesch): merge spans
if (str == "True") {
auto data = tvm::Integer(1);
- tok = Token(current->line, current->column, TokenType::Boolean, data);
+ tok = Token(current->span, TokenType::kBoolean, data);
} else if (str == "False") {
auto data = tvm::Integer(0);
- tok = Token(current->line, current->column, TokenType::Boolean, data);
+ tok = Token(current->span, TokenType::kBoolean, data);
} else if (str == "_") {
- tok = Token(current->line, current->column, TokenType::Underscore);
+ tok = Token(current->span, TokenType::kUnderscore);
} else {
tok = current;
}
return out;
}
-std::vector<Token> Tokenize(std::string source) {
- auto tokenizer = Tokenizer(source);
+std::pair<std::vector<Token>, Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name,
+ const std::string& source) {
+ auto tokenizer = Tokenizer(ctx, source_name, source);
tokenizer.Tokenize();
- auto tokens = Condense(tokenizer.tokens);
+ Token meta_table(Span(), TokenType::kUnknown, ObjectRef());
+ auto tokens = Condense(tokenizer.tokens, &meta_table);
for (auto token : tokens) {
CHECK(token.defined());
}
- return tokens;
+ return {tokens, meta_table};
}
} // namespace parser
#include <tvm/tir/function.h>
#include "../ir/attr_functor.h"
+#include "../parser/meta_ref.h"
#include "../relay/analysis/dependency_graph.h"
#include "doc.h"
#include "meta_data.h"
// determine whether to inline
bool inline_expr = AlwaysInline(expr);
+
if (try_inline) {
inline_expr |= IsUnique(expr);
}
if (it != memo_.end()) return it->second;
Doc printed_expr;
+
if (meta) {
printed_expr = meta_->GetMetaNode(GetRef<ObjectRef>(expr.get()));
} else if (!inline_expr && expr.as<LetNode>()) {
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
- doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine();
+ doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine();
// Memoization is done in AllocVar.
return memo_[expr];
} else if (inline_expr) {
Doc printed_attr;
if (value.as<tvm::tir::AnyNode>()) {
printed_attr << "?";
+ } else if (auto str_obj = value.as<tvm::StringObj>()) {
+ printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
} else if (meta) {
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
} else {
namespace tvm {
-static const char* kSemVer = "v0.0.4";
-
-// TODO(tvm-team): split into files, related: arith/analyzer.h
-//
-// - text_printer.h (common header)
-// - text_printer.cc (prints modules dispatch into relay and tir files)
-// - type_text_printer.cc(specific printing logics for types,
-// can also consider put under type_text_printer)
-// - Implements AsText
-// - relay_text_printer.cc (specific printing logics for relay)
-// - tir_text_printer.cc (specific printing logics for TIR)
+static const char* kSemVer = "0.0.5";
Doc TextPrinter::PrintMod(const IRModule& mod) {
Doc doc;
String PrettyPrint(const ObjectRef& node) {
Doc doc;
- doc << TextPrinter(false, nullptr).PrintFinal(node);
+ doc << TextPrinter(false, nullptr, false).PrintFinal(node);
return doc.str();
}
String AsText(const ObjectRef& node, bool show_meta_data,
runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
Doc doc;
- doc << kSemVer << Doc::NewLine();
+ doc << "#[version = \"" << kSemVer << "\"]" << Doc::NewLine();
runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
if (annotate != nullptr) {
ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
class TextPrinter {
public:
explicit TextPrinter(bool show_meta_data,
- const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate)
+ const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate,
+ bool show_warning = true)
: show_meta_data_(show_meta_data),
+ show_warning_(show_warning),
annotate_(annotate),
relay_text_printer_(show_meta_data, &meta_, annotate),
tir_text_printer_(show_meta_data, &meta_) {}
/*! \brief whether show meta data */
bool show_meta_data_;
+
+ /*! \brief whether show the meta data warning message */
+ bool show_warning_;
+
/*! \brief meta data context */
TextMetaDataContext meta_;
/*! \brief additional comment function */
if (!meta_.empty()) {
doc << Doc::NewLine();
if (show_meta_data_) {
- // append meta data in the end.
- doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection();
- } else {
- doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
+ doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection();
+ } else if (show_warning_) {
+ doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine()
+ << " * If you would like to see the full metadata section you can set the "
+ << Doc::NewLine() << " * option to `True` when invoking `astext`. " << Doc::NewLine()
+ << " */";
}
}
return doc;
p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")";
});
-Match::Match(Expr data, tvm::Array<Clause> clauses, bool complete) {
+Match::Match(Expr data, tvm::Array<Clause> clauses, bool complete, Span span) {
ObjectPtr<MatchNode> n = make_object<MatchNode>();
n->data = std::move(data);
n->clauses = std::move(clauses);
n->complete = complete;
+ n->span = std::move(span);
data_ = std::move(n);
}
using tvm::ReprPrinter;
using namespace tvm::runtime;
-Constant::Constant(runtime::NDArray data) {
+Constant::Constant(runtime::NDArray data, Span span) {
ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
n->data = std::move(data);
+ n->span = std::move(span);
data_ = std::move(n);
}
return TensorType(shape, dtype);
}
-Tuple::Tuple(tvm::Array<relay::Expr> fields) {
+Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
ObjectPtr<TupleNode> n = make_object<TupleNode>();
n->fields = std::move(fields);
+ n->span = std::move(span);
data_ = std::move(n);
}
p->stream << "Tuple(" << node->fields << ")";
});
-Var::Var(Id vid, Type type_annotation) {
+Var::Var(Id vid, Type type_annotation, Span span) {
ObjectPtr<VarNode> n = make_object<VarNode>();
n->vid = std::move(vid);
n->type_annotation = std::move(type_annotation);
+ n->span = std::move(span);
data_ = std::move(n);
}
p->stream << ")";
});
-Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args) {
+Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) {
ObjectPtr<CallNode> n = make_object<CallNode>();
n->op = std::move(op);
n->args = std::move(args);
n->attrs = std::move(attrs);
n->type_args = std::move(type_args);
+ n->span = std::move(span);
data_ = std::move(n);
}
<< node->type_args << ")";
});
-Let::Let(Var var, Expr value, Expr body) {
+Let::Let(Var var, Expr value, Expr body, Span span) {
ObjectPtr<LetNode> n = make_object<LetNode>();
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
+ n->span = std::move(span);
data_ = std::move(n);
}
p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")";
});
-If::If(Expr cond, Expr true_branch, Expr false_branch) {
+If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) {
ObjectPtr<IfNode> n = make_object<IfNode>();
n->cond = std::move(cond);
n->true_branch = std::move(true_branch);
n->false_branch = std::move(false_branch);
+ n->span = std::move(span);
data_ = std::move(n);
}
<< node->false_branch << ")";
});
-TupleGetItem::TupleGetItem(Expr tuple, int index) {
+TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) {
ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
n->tuple = std::move(tuple);
n->index = index;
+ n->span = std::move(span);
data_ = std::move(n);
}
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
-RefCreate::RefCreate(Expr value) {
+RefCreate::RefCreate(Expr value, Span span) {
ObjectPtr<RefCreateNode> n = make_object<RefCreateNode>();
n->value = std::move(value);
+ n->span = std::move(span);
data_ = std::move(n);
}
p->stream << "RefCreateNode(" << node->value << ")";
});
-RefRead::RefRead(Expr ref) {
+RefRead::RefRead(Expr ref, Span span) {
ObjectPtr<RefReadNode> n = make_object<RefReadNode>();
n->ref = std::move(ref);
+ n->span = std::move(span);
data_ = std::move(n);
}
p->stream << "RefReadNode(" << node->ref << ")";
});
-RefWrite::RefWrite(Expr ref, Expr value) {
+RefWrite::RefWrite(Expr ref, Expr value, Span span) {
ObjectPtr<RefWriteNode> n = make_object<RefWriteNode>();
n->ref = std::move(ref);
n->value = std::move(value);
+ n->span = std::move(span);
data_ = std::move(n);
}
if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) {
- return Var(op->vid, type);
+ return Var(op->vid, type, op->span);
}
}
// default case return self.
if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
- return Tuple(fields);
+ return Tuple(fields, op->span);
}
}
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
- return Function(params, body, ret_type, ty_params, op->attrs);
+ return Function(params, body, ret_type, ty_params, op->attrs, op->span);
}
}
if (unchanged) {
return GetRef<Expr>(call_node);
} else {
- return Call(new_op, call_args, call_node->attrs, ty_args);
+ return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span);
}
}
if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
- return Let(var, value, body);
+ return Let(var, value, body, op->span);
}
}
op->false_branch.same_as(false_b)) {
return GetRef<Expr>(op);
} else {
- return If(guard, true_b, false_b);
+ return If(guard, true_b, false_b, op->span);
}
}
-Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
- auto t = this->Mutate(g->tuple);
- if (g->tuple == t) {
- return GetRef<Expr>(g);
+Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) {
+ auto t = this->Mutate(get_item->tuple);
+ if (get_item->tuple == t) {
+ return GetRef<Expr>(get_item);
} else {
- return TupleGetItem(t, g->index);
+ return TupleGetItem(t, get_item->index, get_item->span);
}
}
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
- return RefCreate(value);
+ return RefCreate(value, op->span);
}
}
if (ref.same_as(op->ref)) {
return GetRef<Expr>(op);
} else {
- return RefRead(ref);
+ return RefRead(ref, op->span);
}
}
if (ref.same_as(op->ref) && value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
- return RefWrite(ref, value);
+ return RefWrite(ref, value, op->span);
}
}
}
Expr data = Mutate(m->data);
unchanged &= data.same_as(m->data);
+
if (unchanged) {
return GetRef<Expr>(m);
}
- return Match(data, clauses, m->complete);
+ return Match(data, clauses, m->complete, m->span);
}
Clause ExprMutator::VisitClause(const Clause& c) {
}
void ExprVisitor::VisitExpr_(const VarNode* op) {
+ this->VisitSpan(op->span);
if (op->type_annotation.defined()) {
this->VisitType(op->type_annotation);
}
}
-void ExprVisitor::VisitExpr_(const GlobalVarNode* op) {}
+void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); }
-void ExprVisitor::VisitExpr_(const ConstantNode* op) {}
+void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); }
void ExprVisitor::VisitExpr_(const TupleNode* op) {
+ this->VisitSpan(op->span);
for (auto field : op->fields) {
this->VisitExpr(field);
}
}
void ExprVisitor::VisitExpr_(const FunctionNode* op) {
+ this->VisitSpan(op->span);
for (auto param : op->params) {
this->VisitExpr(param);
}
}
void ExprVisitor::VisitExpr_(const CallNode* op) {
+ this->VisitSpan(op->span);
this->VisitExpr(op->op);
for (auto ty_arg : op->type_args) {
}
void ExprVisitor::VisitExpr_(const LetNode* op) {
+ this->VisitSpan(op->span);
this->VisitExpr(op->value);
this->VisitExpr(op->var);
this->VisitExpr(op->body);
}
void ExprVisitor::VisitExpr_(const IfNode* op) {
+ this->VisitSpan(op->span);
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
-void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); }
+void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
+ this->VisitSpan(op->span);
+ this->VisitExpr(op->tuple);
+}
-void ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitExpr(op->value); }
+void ExprVisitor::VisitExpr_(const RefCreateNode* op) {
+ this->VisitSpan(op->span);
+ this->VisitExpr(op->value);
+}
-void ExprVisitor::VisitExpr_(const RefReadNode* op) { this->VisitExpr(op->ref); }
+void ExprVisitor::VisitExpr_(const RefReadNode* op) {
+ this->VisitSpan(op->span);
+ this->VisitExpr(op->ref);
+}
void ExprVisitor::VisitExpr_(const RefWriteNode* op) {
+ this->VisitSpan(op->span);
this->VisitExpr(op->ref);
this->VisitExpr(op->value);
}
void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
+ // TODO(@jroesch): visit spans
for (const Type& t : op->inputs) {
this->VisitType(t);
}
}
void ExprVisitor::VisitExpr_(const MatchNode* op) {
+ this->VisitSpan(op->span);
this->VisitExpr(op->data);
for (const Clause& c : op->clauses) {
this->VisitClause(c);
}
void ExprVisitor::VisitClause(const Clause& op) {
+ // TODO(@jroesch): visit spans
this->VisitPattern(op->lhs);
this->VisitExpr(op->rhs);
}
void ExprVisitor::VisitType(const Type& t) { return; }
+void ExprVisitor::VisitSpan(const Span& span) { return; }
+
// visitor to implement apply
class ExprApplyVisit : public ExprVisitor {
public:
namespace relay {
Function::Function(tvm::Array<Var> params, Expr body, Type ret_type,
- tvm::Array<TypeVar> type_params, DictAttrs attrs) {
+ tvm::Array<TypeVar> type_params, DictAttrs attrs, Span span) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
CHECK(params.defined());
CHECK(type_params.defined());
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->attrs = std::move(attrs);
+ n->span = std::move(span);
data_ = std::move(n);
}
* most efficient code we need to obtain type information for the
* IR.
*
- * Like computation graphs the IR leaves most type information
- * implicit and relies performing analysis of the program to
- * generate this information.
+ * Similar to previous computation graph based IRs, the Relay IR leaves
+ * type information implicit and computes types by performing program
+ * analysis.
*
- * This pass given an expression `e` will infer a type `t` for
- * the expression simultaneous checking the property `e : t`
- * (i.e we can show e has type t).
+ * Given an expression `e` this pass infers a type `t` for
+ * the expression as well as simultaneously checking the property `e : t`
+ * (i.e., we can show e has type t).
*
- * If we can not infer a type or there are conflicting typing
- * constraints we will trigger an error.
+ * If we can not infer a type or there is a conflicting
+ * constraint it will emit errors.
*/
#include <tvm/ir/error.h>
#include <tvm/ir/type_functor.h>
std::vector<NodeEntry> inputs;
// control deps
std::vector<uint32_t> control_deps;
+
// JSON Loader
void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) {
int bitmask = 0;
.node_repl_history
node_modules
-# Relay parser: they are generated by ANTLR.
-RelayLexer.py
-RelayParser.py
-RelayVisitor.py
-
# Specific files
package-list
MANIFEST
# Span
def test_span():
- span = relay.Span(None, 1, 1)
- assert span.source == None
+ span = relay.Span(None, 1, 2, 3, 4)
+ assert span.source_name == None
assert span.line == 1
- assert span.column == 1
+ assert span.end_line == 2
+ assert span.column == 3
+ assert span.end_column == 4
assert span.same_as(span)
assert span == span
assert isinstance(span, relay.base.Span)
# span is not a node so we can't use graph_equal
# to test the round trip
back = tvm.ir.load_json(tvm.ir.save_json(span))
- assert back.source == span.source
+ assert back.source_name == span.source_name
assert back.line == span.line
+ assert back.end_line == span.end_line
assert back.column == span.column
+ assert back.end_column == span.end_column
def test_constant():
import tvm
from tvm import te
from tvm import relay
+import tvm.relay.testing
import pytest
from numpy import isclose
from typing import Union
from functools import wraps
-raises_parse_error = pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
-SEMVER = "v0.0.4"
+
+
+SEMVER = "#[version = \"0.0.5\"]\n"
BINARY_OPS = {
"*": relay.multiply,
def graph_equal(lhs, rhs):
return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)
+def roundtrip_expr(expr):
+ text = tvm.relay.Expr.astext(expr, show_meta_data=False)
+ x = tvm.parser.parse_expr(text)
+ assert_graph_equal(x, expr)
+# Testing Utilities for expressions.
def roundtrip(expr):
- x = relay.fromtext(expr.astext())
+ x = tvm.parser.fromtext(expr.astext())
assert_graph_equal(x, expr)
-
def parse_text(code):
- expr = relay.fromtext(SEMVER + "\n" + code)
- roundtrip(expr)
+ expr = tvm.parser.parse_expr(code)
+ roundtrip_expr(expr)
return expr
-
def parses_as(code, expr):
# type: (str, relay.Expr) -> bool
parsed = parse_text(code)
result = graph_equal(parsed, expr)
return result
+# Testing Utilities for full modules.
+def parse_module(code):
+ mod = tvm.parser.parse(SEMVER + code)
+ roundtrip(mod)
+ return mod
def assert_parses_as(code, expr):
parsed = parse_text(code)
assert_graph_equal(parsed, expr)
+def assert_parse_module_as(code, mod):
+ parsed = parse_module(code)
+ assert_graph_equal(parsed, mod)
def get_scalar(x):
# type: (relay.Constant) -> (Union[float, int, bool])
def test_negative():
- assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call)
+ # need to handle parsing non-literal operations
+ # assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call)
assert get_scalar(parse_text("--10")) == 10
assert get_scalar(parse_text("---10")) == -10
assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
-
-@pytest.mark.skip
def test_vars():
- # temp vars won't work b/c they start with a digit
- # # temp var
- # temp_var = parse_text("%1")
- # assert isinstance(temp_var, relay.Var)
- # assert temp_var.name == "1"
-
# var
var = parse_text("let %foo = (); %foo")
assert isinstance(var.body, relay.Var)
assert global_var.name_hint == "foo"
# operator id
- op = parse_text("foo")
+ op = parse_text("add")
assert isinstance(op, tvm.ir.Op)
- assert op.name == "foo"
+ assert op.name == "add"
+
+ # operator id with prefix
+ op = parse_text("nn.global_avg_pool2d")
+ assert isinstance(op, tvm.ir.Op)
+ assert op.name == "nn.global_avg_pool2d"
+
+def test_meta_ref():
+ with pytest.raises(tvm.error.DiagnosticError):
+ meta_op = parse_text("meta[type_key][1337]")
+ assert meta_op.attrs.node_type_key == "type_key"
+ assert meta_op.attrs.node_index == 1337
def test_let():
def test_seq():
assert_parses_as(
- "();; ()",
+ "(); ()",
relay.Let(
_,
UNIT,
)
-@raises_parse_error
-def test_graph_wrong_order():
- parse_text("%1 = (); %1")
+def test_graph_single():
+ assert_parses_as("%1 = (); %1", relay.Tuple([]))
-
-@raises_parse_error
def test_let_global_var():
- parse_text("let @x = 1; ()")
+ with pytest.raises(tvm.error.DiagnosticError):
+ parse_text("let @x = 1; ()")
-@raises_parse_error
def test_let_op():
- parse_text("let x = 1; ()")
+ with pytest.raises(tvm.error.DiagnosticError):
+ parse_text("let x = 1; ()")
def test_tuple():
)
)
- # attributes
- assert_parses_as(
- "fn (n=5) { () }",
- relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
- )
+ # Refactor the attribute syntax and printing.
+ #
+ # # attributes
+ # assert_parses_as(
+ # "fn (n=5) { () }",
+ # relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
+ # )
# TODO(@jmp): Crashes if %x isn't annnotated.
def test_defn():
- id_defn = parse_text(
+ id_defn = parse_module(
"""
def @id(%x: int32) -> int32 {
%x
def test_recursive_call():
- id_defn = parse_text(
+ id_defn = parse_module(
"""
def @id(%x: int32) -> int32 {
@id(%x)
)
-@raises_parse_error
def test_ifelse_scope():
- parse_text(
- """
- if (True) {
- let %x = ();
- ()
- } else {
- %x
- }
- """
- )
+ with pytest.raises(tvm.error.DiagnosticError):
+ parse_text(
+ """
+ if (True) {
+ let %x = ();
+ ()
+ } else {
+ %x
+ }
+ """
+ )
def test_call():
)
)
- # TODO(@jmp): re-enable after sequence parsing improvements
# curried function
- # curried_mult = relay.Var("curried_mult")
- # assert_parses_as(
- # """
- # let %curried_mult =
- # fn (%x) {
- # fn (%y) {
- # %x * %y
- # }
- # };
- # %curried_mult(0);
- # %curried_mult(0)(0)
- # """,
- # relay.Let(
- # curried_mult,
- # relay.Function(
- # [X],
- # relay.Function(
- # [Y],
- # relay.multiply(X, Y),
- # None,
- # []
- # ),
- # None,
- # []
- # ),
- # relay.Let(
- # _,
- # relay.Call(curried_mult, [relay.const(0)], None, None),
- # relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
- # )
- # )
- # )
+ curried_mult = relay.Var("curried_mult")
+ assert_parses_as(
+ """
+ let %curried_mult =
+ fn (%x) {
+ fn (%y) {
+ %x * %y
+ }
+ };
+ %curried_mult(0);
+ %curried_mult(0)(0)
+ """,
+ relay.Let(
+ curried_mult,
+ relay.Function(
+ [X],
+ relay.Function(
+ [Y],
+ relay.multiply(X, Y),
+ None,
+ []
+ ),
+ None,
+ []
+ ),
+ relay.Let(
+ _,
+ relay.Call(curried_mult, [relay.const(0)], None, None),
+ relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
+ )
+ )
+ )
# op
assert_parses_as(
[],
[relay.Constructor("Nil", [], glob_typ_var)])
mod[glob_typ_var] = prog
- assert_parses_as(
+ assert_parse_module_as(
"""
type Ayy { Nil }
""",
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(glob_typ_var, [], [])
mod[glob_typ_var] = prog
- assert_parses_as(
+ assert_parse_module_as(
"""
type Ayy { }
""",
relay.Constructor("Nil", [], list_var),
])
mod[list_var] = prog
- assert_parses_as(LIST_DEFN, mod)
+ assert_parse_module_as(LIST_DEFN, mod)
def test_multiple_type_param_defn():
])
mod = tvm.IRModule()
mod[glob_typ_var] = prog
- assert_parses_as(
+ assert_parse_module_as(
"""
type Either[A, B] {
Left(A),
input_var = relay.Var("xs", input_type)
rest_var = relay.Var("rest")
cons_case = relay.Let(
- _,
+ relay.var("", type_annotation=None),
UNIT,
relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
body = relay.Match(input_var,
)
mod[length_var] = length_func
- assert_parses_as(
+ assert_parse_module_as(
"""
%s
def @length[A](%%xs: List[A]) -> int32 {
%s (%%xs) {
- Cons(_, %%rest) => {
- ();;
+ Cons(_, %%rest : List[A]) => {
+ ();
1 + @length(%%rest)
},
Nil => 0,
)
mod[make_singleton_var] = make_singleton_func
- assert_parses_as(
+ assert_parse_module_as(
"""
%s
)
-@raises_parse_error
def test_duplicate_adt_defn():
- parse_text(
- """
- %s
+ with pytest.raises(tvm.error.DiagnosticError):
+ parse_module(
+ """
+ %s
- type List[A] {
- Cons(A, List[A]),
- Nil,
- }
- """ % LIST_DEFN
- )
+ type List[A] {
+ Cons(A, List[A]),
+ Nil,
+ }
+ """ % LIST_DEFN
+ )
-@raises_parse_error
def test_duplicate_adt_cons():
- parse_text(
- """
- type Ayy { Lmao }
- type Haha { Lmao }
- """
- )
+ with pytest.raises(tvm.error.DiagnosticError):
+ parse_text(
+ """
+ type Ayy { Lmao }
+ type Haha { Lmao }
+ """
+ )
-@raises_parse_error
def test_duplicate_adt_cons_defn():
- parse_text(
- """
- type Ayy { Lmao }
- type Lmao { Ayy }
- """
- )
+ with pytest.raises(tvm.error.DiagnosticError):
+ parse_text(
+ """
+ type Ayy { Lmao }
+ type Lmao { Ayy }
+ """
+ )
-@raises_parse_error
def test_duplicate_global_var():
- parse_text(
- """
- def @id[A](%%x: A) -> A { x }
- def @id[A](%%x: A) -> A { x }
- """
- )
+ with pytest.raises(tvm.error.DiagnosticError):
+ parse_text(
+ """
+ def @id[A](%x: A) -> A { x }
+ def @id[A](%x: A) -> A { x }
+ """
+ )
def test_extern_adt_defn():
- # TODO(weberlo): update this test once extern is implemented
mod = tvm.IRModule()
extern_var = relay.GlobalTypeVar("T")
extern_def = relay.TypeData(extern_var, [typ_var], [])
mod[extern_var] = extern_def
- assert_parses_as(
+ assert_parse_module_as(
"""
extern type T[A]
""",
mod
)
+
def test_import_grad():
mod = tvm.IRModule()
mod.import_from_std("gradient.rly")
+def test_resnet():
+ mod, _ = relay.testing.resnet.get_workload()
+ text = mod.astext()
+ parsed_mod = tvm.parser.parse(text)
+ tvm.ir.assert_structural_equal(mod, parsed_mod)
+
+def inline_params(mod, params):
+ main_fn = mod["main"]
+ str_to_var = {}
+ for param in main_fn.params:
+ str_to_var[param.name_hint] = param
+
+ bind_map = {}
+ for param in params:
+ bind_map[str_to_var[param]] = relay.const(params[param])
+
+ body = relay.bind(main_fn.body, bind_map)
+ main_fn = relay.Function(relay.analysis.free_vars(body), body)
+ mod["main_fn"] = main_fn
+ return mod
+
+def test_resnet_inlined_params():
+ mod, params = relay.testing.resnet.get_workload()
+ mod = inline_params(mod, params)
+ text = mod.astext()
+ parsed_mod = tvm.parser.parse(text)
+ tvm.ir.assert_structural_equal(mod, parsed_mod)
+
if __name__ == "__main__":
- test_graph()
- test_comments()
- test_int_literal()
- test_float_literal()
- test_bool_literal()
- test_negative()
- test_bin_op()
- test_parens()
- test_op_assoc()
- test_let()
- test_seq()
- test_tuple()
- test_func()
- test_defn()
- test_recursive_call()
- test_ifelse()
- test_call()
- test_incomplete_type()
- test_builtin_types()
- test_tensor_type()
- test_function_type()
- test_tuple_type()
- test_adt_defn()
- test_empty_adt_defn()
- test_multiple_cons_defn()
- test_multiple_type_param_defn()
- test_match()
- test_adt_cons_expr()
- test_duplicate_adt_defn()
- test_duplicate_adt_cons()
- test_duplicate_adt_cons_defn()
- test_duplicate_global_var()
- test_extern_adt_defn()
- test_import_grad()
+ import sys
+ pytest.main(sys.argv)
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import tvm
-from tvm import te
-from tvm import relay
-import pytest
-from numpy import isclose
-from typing import Union
-from functools import wraps
-raises_parse_error = pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
-
-SEMVER = "v0.0.4"
-
-BINARY_OPS = {
- "*": relay.multiply,
- "/": relay.divide,
- "+": relay.add,
- "-": relay.subtract,
- "<": relay.less,
- ">": relay.greater,
- "<=": relay.less_equal,
- ">=": relay.greater_equal,
- "==": relay.equal,
- "!=": relay.not_equal,
-}
-
-TYPES = {
- "int8",
- "int16",
- "int32",
- "int64",
-
- "uint8",
- "uint16",
- "uint32",
- "uint64",
-
- "float16",
- "float32",
- "float64",
-
- "bool",
-
- "int8x4",
- "uint1x4",
- "float16x4",
-}
-
-LIST_DEFN = """
-type List[A] {
- Cons(A, List[A]),
- Nil,
-}
-"""
-
-def assert_graph_equal(lhs, rhs):
- tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)
-
-def graph_equal(lhs, rhs):
- return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)
-
-
-def roundtrip_expr(expr):
- x = tvm.parser.parse_expr(str(str(expr)))
- assert_graph_equal(x, expr)
-
-def roundtrip(expr):
- x = tvm.parser.fromtext(expr.astext())
- assert_graph_equal(x, expr)
-
-def parse_text(code):
- expr = tvm.parser.parse_expr(code)
- roundtrip_expr(expr)
- return expr
-
-
-def parses_as(code, expr):
- # type: (str, relay.Expr) -> bool
- parsed = parse_text(code)
- result = graph_equal(parsed, expr)
- return result
-
-def parse_module(code):
- mod = tvm.parser.parse(code)
- roundtrip(mod)
- return mod
-
-
-def assert_parses_as(code, expr):
- parsed = parse_text(code)
- assert_graph_equal(parsed, expr)
-
-def assert_parse_module_as(code, mod):
- parsed = parse_module(code)
- assert_graph_equal(parsed, mod)
-
-def get_scalar(x):
- # type: (relay.Constant) -> (Union[float, int, bool])
- return x.data.asnumpy().item()
-
-int32 = relay.scalar_type("int32")
-
-_ = relay.Var("_")
-X = relay.Var("x")
-Y = relay.Var("y")
-X_ANNO = relay.Var("x", int32)
-Y_ANNO = relay.Var("y", int32)
-
-UNIT = relay.Tuple([])
-
-
-def test_comments():
- assert_parses_as(
- """
- // This is a line comment!
- ()
- """,
- UNIT
- )
-
- assert_parses_as(
- """
- /* This is a block comment!
- This is still a block comment!
- */
- ()
- """,
- UNIT
- )
-
- assert_parses_as(
- """
- /* This is a block comment!
- /*Block comment is recursive!*/
- */
- ()
- """,
- UNIT
- )
-
-
-def test_int_literal():
- assert isinstance(parse_text("1"), relay.Constant)
- assert isinstance(parse_text("1").data, tvm.nd.NDArray)
-
- assert get_scalar(parse_text("1")) == 1
- assert get_scalar(parse_text("10")) == 10
- assert get_scalar(parse_text("0")) == 0
- assert get_scalar(parse_text("-100")) == -100
- assert get_scalar(parse_text("-05")) == -5
-
-
-def test_float_literal():
- assert get_scalar(parse_text("1.0f")) == 1.0
- assert isclose(get_scalar(parse_text("1.56667f")), 1.56667)
- assert get_scalar(parse_text("0.0f")) == 0.0
- assert get_scalar(parse_text("-10.0f")) == -10.0
-
- # scientific notation
- assert isclose(get_scalar(parse_text("1e-1f")), 1e-1)
- assert get_scalar(parse_text("1e+1f")) == 1e+1
- assert isclose(get_scalar(parse_text("1E-1f")), 1E-1)
- assert get_scalar(parse_text("1E+1f")) == 1E+1
- assert isclose(get_scalar(parse_text("1.0e-1f")), 1.0e-1)
- assert get_scalar(parse_text("1.0e+1f")) == 1.0e+1
- assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0E-1)
- assert get_scalar(parse_text("1.0E+1f")) == 1.0E+1
-
-
-def test_bool_literal():
- assert get_scalar(parse_text("True")) == True
- assert get_scalar(parse_text("False")) == False
-
-
-def test_negative():
- # need to handle parsing non-literal operations
- # assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call)
- assert get_scalar(parse_text("--10")) == 10
- assert get_scalar(parse_text("---10")) == -10
-
-
-def test_bin_op():
- for bin_op in BINARY_OPS.keys():
- assert_parses_as(
- "1 {} 1".format(bin_op),
- BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
- )
-
-
-def test_parens():
- assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
- assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
-
-
-def test_op_assoc():
- assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
- assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
-
-
-def test_vars():
- # var
- var = parse_text("let %foo = (); %foo")
- assert isinstance(var.body, relay.Var)
- assert var.body.name_hint == "foo"
-
- # global var
- global_var = parse_text("@foo")
- assert isinstance(global_var, relay.GlobalVar)
- assert global_var.name_hint == "foo"
-
- # operator id
- op = parse_text("add")
- assert isinstance(op, tvm.ir.Op)
- assert op.name == "add"
-
-
-def test_let():
- assert_parses_as(
- "let %x = 1; ()",
- relay.Let(
- X,
- relay.const(1),
- UNIT
- )
- )
-
- assert_parses_as(
- """
- let %x = 1;
- let %y = 2;
- ()
- """,
- relay.Let(
- X,
- relay.const(1),
- relay.Let(
- Y,
- relay.const(2),
- UNIT
- )
- )
- )
-
-
-def test_seq():
- assert_parses_as(
- "(); ()",
- relay.Let(
- _,
- UNIT,
- UNIT)
- )
-
- assert_parses_as(
- "let %_ = 1; ()",
- relay.Let(
- X,
- relay.const(1),
- UNIT
- )
- )
-
-
-def test_graph():
- code = "%0 = (); %1 = 1; (%0, %0, %1)"
- assert_parses_as(
- code,
- relay.Tuple([UNIT, UNIT, relay.const(1)])
- )
-
-
-@raises_parse_error
-def test_graph_wrong_order():
- parse_text("%1 = (); %1")
-
-
-@raises_parse_error
-def test_let_global_var():
- parse_text("let @x = 1; ()")
-
-
-@raises_parse_error
-def test_let_op():
- parse_text("let x = 1; ()")
-
-
-def test_tuple():
- assert_parses_as("()", relay.Tuple([]))
-
- assert_parses_as("(0,)", relay.Tuple([relay.const(0)]))
-
- assert_parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
-
- assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
-
-
-def test_func():
- # 0 args
- assert_parses_as(
- "fn () { 0 }",
- relay.Function(
- [],
- relay.const(0),
- None,
- []
- )
- )
-
- # 1 arg
- assert_parses_as(
- "fn (%x) { %x }",
- relay.Function(
- [X],
- X,
- None,
- []
- )
- )
-
- # 2 args
- assert_parses_as(
- "fn (%x, %y) { %x + %y }",
- relay.Function(
- [X, Y],
- relay.add(X, Y),
- None,
- []
- )
- )
-
- # annotations
- assert_parses_as(
- "fn (%x: int32) -> int32 { %x }",
- relay.Function(
- [X_ANNO],
- X_ANNO,
- int32,
- []
- )
- )
-
- # Refactor the attribute syntax and printing.
- #
- # # attributes
- # assert_parses_as(
- # "fn (n=5) { () }",
- # relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
- # )
-
-
-# TODO(@jmp): Crashes if %x isn't annnotated.
-def test_defn():
- id_defn = parse_module(
- """
- def @id(%x: int32) -> int32 {
- %x
- }
- """)
- assert isinstance(id_defn, tvm.IRModule)
-
-
-def test_recursive_call():
- id_defn = parse_module(
- """
- def @id(%x: int32) -> int32 {
- @id(%x)
- }
- """)
- assert isinstance(id_defn, tvm.IRModule)
-
-
-def test_ifelse():
- assert_parses_as(
- """
- if (True) {
- 0
- } else {
- 1
- }
- """,
- relay.If(
- relay.const(True),
- relay.const(0),
- relay.const(1)
- )
- )
-
-
-@raises_parse_error
-def test_ifelse_scope():
- parse_text(
- """
- if (True) {
- let %x = ();
- ()
- } else {
- %x
- }
- """
- )
-
-
-def test_call():
- # select right function to call: simple ident case
- id_func = relay.Var("id")
- assert_parses_as(
- """
- let %id = fn (%x) { %x };
- 10 * %id(10)
- """,
- relay.Let(
- id_func,
- relay.Function([X], X, None, []),
- relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)]))
- )
- )
-
- # 0 args
- constant = relay.Var("constant")
- assert_parses_as(
- """
- let %constant = fn () { 0 };
- %constant()
- """,
- relay.Let(
- constant,
- relay.Function([], relay.const(0), None, []),
- relay.Call(constant, [], None, None)
- )
- )
-
- # 1 arg
- id_var = relay.Var("id")
- assert_parses_as(
- """
- let %id = fn (%x) { %x };
- %id(1)
- """,
- relay.Let(
- id_var,
- relay.Function([X], X, None, []),
- relay.Call(id_var, [relay.const(1)], None, None)
- )
- )
-
- # 2 args
- multiply = relay.Var("multiply")
- assert_parses_as(
- """
- let %multiply = fn (%x, %y) { %x * %y };
- %multiply(0, 0)
- """,
- relay.Let(
- multiply,
- relay.Function(
- [X, Y],
- relay.multiply(X, Y),
- None,
- []
- ),
- relay.Call(multiply, [relay.const(0), relay.const(0)], None, None)
- )
- )
-
- # anonymous function
- assert_parses_as(
- """
- (fn (%x) { %x })(0)
- """,
- relay.Call(
- relay.Function(
- [X],
- X,
- None,
- []
- ),
- [relay.const(0)],
- None,
- None
- )
- )
-
- # curried function
- curried_mult = relay.Var("curried_mult")
- assert_parses_as(
- """
- let %curried_mult =
- fn (%x) {
- fn (%y) {
- %x * %y
- }
- };
- %curried_mult(0);
- %curried_mult(0)(0)
- """,
- relay.Let(
- curried_mult,
- relay.Function(
- [X],
- relay.Function(
- [Y],
- relay.multiply(X, Y),
- None,
- []
- ),
- None,
- []
- ),
- relay.Let(
- _,
- relay.Call(curried_mult, [relay.const(0)], None, None),
- relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
- )
- )
- )
-
- # op
- assert_parses_as(
- "abs(1)",
- relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
- )
-
-# Types
-
-
-def test_incomplete_type():
- assert_parses_as(
- "let %_ : _ = (); ()",
- relay.Let(
- _,
- UNIT,
- UNIT
- )
- )
-
-
-def test_builtin_types():
- for builtin_type in TYPES:
- parse_text("let %_ : {} = (); ()".format(builtin_type))
-
-
-def test_tensor_type():
- assert_parses_as(
- "let %_ : Tensor[(), float32] = (); ()",
- relay.Let(
- relay.Var("_", relay.TensorType((), "float32")),
- UNIT,
- UNIT
- )
- )
-
- assert_parses_as(
- "let %_ : Tensor[(1), float32] = (); ()",
- relay.Let(
- relay.Var("_", relay.TensorType((1,), "float32")),
- UNIT,
- UNIT
- )
- )
-
- assert_parses_as(
- "let %_ : Tensor[(1, 1), float32] = (); ()",
- relay.Let(
- relay.Var("_", relay.TensorType((1, 1), "float32")),
- UNIT,
- UNIT
- )
- )
-
-
-def test_function_type():
- assert_parses_as(
- """
- let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
- """,
- relay.Let(
- relay.Var("_", relay.FuncType([], int32, [], [])),
- relay.Function([], relay.const(0), int32, []),
- UNIT
- )
- )
-
- assert_parses_as(
- """
- let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
- """,
- relay.Let(
- relay.Var("_", relay.FuncType([int32], int32, [], [])),
- relay.Function([relay.Var("x", int32)], relay.const(0), int32, []),
- UNIT
- )
- )
-
- assert_parses_as(
- """
- let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
- """,
- relay.Let(
- relay.Var("_", relay.FuncType([int32, int32], int32, [], [])),
- relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []),
- UNIT
- )
- )
-
-
-def test_tuple_type():
- assert_parses_as(
- """
- let %_: () = (); ()
- """,
- relay.Let(
- relay.Var("_", relay.TupleType([])),
- UNIT,
- UNIT
- )
- )
-
- assert_parses_as(
- """
- let %_: (int32,) = (0,); ()
- """,
- relay.Let(
- relay.Var("_", relay.TupleType([int32])),
- relay.Tuple([relay.const(0)]),
- UNIT
- )
- )
-
- assert_parses_as(
- """
- let %_: (int32, int32) = (0, 1); ()
- """,
- relay.Let(
- relay.Var("_", relay.TupleType([int32, int32])),
- relay.Tuple([relay.const(0), relay.const(1)]),
- UNIT
- )
- )
-
-
-def test_adt_defn():
- mod = tvm.IRModule()
-
- glob_typ_var = relay.GlobalTypeVar("Ayy")
- prog = relay.TypeData(
- glob_typ_var,
- [],
- [relay.Constructor("Nil", [], glob_typ_var)])
- mod[glob_typ_var] = prog
- assert_parse_module_as(
- """
- type Ayy { Nil }
- """,
- mod
- )
-
-
-def test_empty_adt_defn():
- mod = tvm.IRModule()
-
- glob_typ_var = relay.GlobalTypeVar("Ayy")
- prog = relay.TypeData(glob_typ_var, [], [])
- mod[glob_typ_var] = prog
- assert_parse_module_as(
- """
- type Ayy { }
- """,
- mod
- )
-
-
-def test_multiple_cons_defn():
- mod = tvm.IRModule()
-
- list_var = relay.GlobalTypeVar("List")
- typ_var = relay.TypeVar("A")
- prog = relay.TypeData(
- list_var,
- [typ_var],
- [
- relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
- relay.Constructor("Nil", [], list_var),
- ])
- mod[list_var] = prog
- assert_parse_module_as(LIST_DEFN, mod)
-
-
-def test_multiple_type_param_defn():
- glob_typ_var = relay.GlobalTypeVar("Either")
- typ_var_a = relay.TypeVar("A")
- typ_var_b = relay.TypeVar("B")
- prog = relay.TypeData(
- glob_typ_var,
- [typ_var_a, typ_var_b],
- [
- relay.Constructor("Left", [typ_var_a], glob_typ_var),
- relay.Constructor("Right", [typ_var_b], glob_typ_var),
- ])
- mod = tvm.IRModule()
- mod[glob_typ_var] = prog
- assert_parse_module_as(
- """
- type Either[A, B] {
- Left(A),
- Right(B),
- }
- """,
- mod
- )
-
-
-def test_match():
- # pair each match keyword with whether it specifies a complete match or not
- match_keywords = [("match", True), ("match?", False)]
- for (match_keyword, is_complete) in match_keywords:
- mod = tvm.IRModule()
-
- list_var = relay.GlobalTypeVar("List")
- typ_var = relay.TypeVar("A")
- cons_constructor = relay.Constructor(
- "Cons", [typ_var, list_var(typ_var)], list_var)
- nil_constructor = relay.Constructor("Nil", [], list_var)
- list_def = relay.TypeData(
- list_var,
- [typ_var],
- [cons_constructor, nil_constructor])
- mod[list_var] = list_def
-
- length_var = relay.GlobalVar("length")
- typ_var = relay.TypeVar("A")
- input_type = list_var(typ_var)
- input_var = relay.Var("xs", input_type)
- rest_var = relay.Var("rest")
- cons_case = relay.Let(
- relay.var("", type_annotation=None),
- UNIT,
- relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
- body = relay.Match(input_var,
- [relay.Clause(
- relay.PatternConstructor(
- cons_constructor,
- [relay.PatternWildcard(), relay.PatternVar(rest_var)]),
- cons_case),
- relay.Clause(
- relay.PatternConstructor(nil_constructor, []),
- relay.const(0))],
- complete=is_complete
- )
- length_func = relay.Function(
- [input_var],
- body,
- int32,
- [typ_var]
- )
- mod[length_var] = length_func
-
- assert_parse_module_as(
- """
- %s
-
- def @length[A](%%xs: List[A]) -> int32 {
- %s (%%xs) {
- Cons(_, %%rest : List[A]) => {
- ();
- 1 + @length(%%rest)
- },
- Nil => 0,
- }
- }
- """ % (LIST_DEFN, match_keyword),
- mod
- )
-
-
-def test_adt_cons_expr():
- mod = tvm.IRModule()
-
- list_var = relay.GlobalTypeVar("List")
- typ_var = relay.TypeVar("A")
- cons_constructor = relay.Constructor(
- "Cons", [typ_var, list_var(typ_var)], list_var)
- nil_constructor = relay.Constructor("Nil", [], list_var)
- list_def = relay.TypeData(
- list_var,
- [typ_var],
- [cons_constructor, nil_constructor])
- mod[list_var] = list_def
-
- make_singleton_var = relay.GlobalVar("make_singleton")
- input_var = relay.Var("x", int32)
- make_singleton_func = relay.Function(
- [input_var],
- cons_constructor(input_var, nil_constructor()),
- list_var(int32)
- )
- mod[make_singleton_var] = make_singleton_func
-
- assert_parse_module_as(
- """
- %s
-
- def @make_singleton(%%x: int32) -> List[int32] {
- Cons(%%x, Nil)
- }
- """ % LIST_DEFN,
- mod
- )
-
-
-@raises_parse_error
-def test_duplicate_adt_defn():
- parse_module(
- """
- %s
-
- type List[A] {
- Cons(A, List[A]),
- Nil,
- }
- """ % LIST_DEFN
- )
-
-
-@raises_parse_error
-def test_duplicate_adt_cons():
- parse_text(
- """
- type Ayy { Lmao }
- type Haha { Lmao }
- """
- )
-
-
-@raises_parse_error
-def test_duplicate_adt_cons_defn():
- parse_text(
- """
- type Ayy { Lmao }
- type Lmao { Ayy }
- """
- )
-
-
-@raises_parse_error
-def test_duplicate_global_var():
- parse_text(
- """
- def @id[A](%x: A) -> A { x }
- def @id[A](%x: A) -> A { x }
- """
- )
-
-
-def test_extern_adt_defn():
- # TODO(weberlo): update this test once extern is implemented
- mod = tvm.IRModule()
-
- extern_var = relay.GlobalTypeVar("T")
- typ_var = relay.TypeVar("A")
- extern_def = relay.TypeData(extern_var, [typ_var], [])
- mod[extern_var] = extern_def
-
- assert_parse_module_as(
- """
- extern type T[A]
- """,
- mod
- )
-
-@pytest.mark.skip("not yet tested on parser 2.0")
-def test_import_grad():
- mod = tvm.IRModule()
- mod.import_from_std("gradient.rly")
-
-if __name__ == "__main__":
- import sys
- pytest.main(sys.argv)
import tvm
from tvm import te
from tvm import relay
-import tvm.relay.testing
+from tvm.relay import testing
import numpy as np
from tvm.relay import Expr
from tvm.relay.analysis import free_vars
-do_print = [False]
+DEBUG_PRINT = False
-SEMVER = "v0.0.4\n"
+SEMVER = "#[version = \"0.0.5\"]\n"
-def astext(p, unify_free_vars=False):
- txt = p.astext()
- if isinstance(p, Expr) and free_vars(p):
- return txt
- x = relay.fromtext(txt)
- if unify_free_vars:
- tvm.ir.assert_structural_equal(x, p, map_free_vars=True)
+def astext(program, unify_free_vars=False):
+ text = program.astext()
+ print(text)
+ if isinstance(program, Expr):
+ roundtrip_program = tvm.parser.parse_expr(text)
else:
- tvm.ir.assert_structural_equal(x, p)
- return txt
+ roundtrip_program = tvm.parser.fromtext(text)
+
+ tvm.ir.assert_structural_equal(roundtrip_program, program, map_free_vars=True)
+
+ return text
def show(text):
- if do_print[0]:
+ if DEBUG_PRINT:
print("---------------------------")
print(text)
def test_mlp():
- net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
+ net, _ = tvm.relay.testing.mlp.get_workload(batch_size=1)
astext(net)
def test_resnet():
- net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
+ net, _ = tvm.relay.testing.resnet.get_workload(batch_size=1)
astext(net)
def test_mobilenet():
- net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
+ net, _ = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
astext(net)
def test_dqn():
- net, params = tvm.relay.testing.dqn.get_workload(batch_size=1)
+ net, _ = tvm.relay.testing.dqn.get_workload(batch_size=1)
astext(net)
def test_dcgan():
- net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1)
+ net, _ = tvm.relay.testing.dcgan.get_workload(batch_size=1)
astext(net)
def test_lstm():
- net, params = tvm.relay.testing.lstm.get_workload(1, 1)
+ net, _ = tvm.relay.testing.lstm.get_workload(1, 1)
astext(net)
- net, params = tvm.relay.testing.lstm.get_workload(4, 4)
+ net, _ = tvm.relay.testing.lstm.get_workload(4, 4)
astext(net)
def test_inception_v3():
- net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
+ net, _ = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
astext(net)
def test_squeezenet():
for version in ['1.0', '1.1']:
- net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
+ net, _ = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
astext(net)
def test_vgg():
- net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
+ net, _ = tvm.relay.testing.vgg.get_workload(batch_size=1)
astext(net)
def test_densenet():
- net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
+ net, _ = tvm.relay.testing.densenet.get_workload(batch_size=1)
astext(net)
Cons
}
"""
- mod = relay.fromtext(SEMVER + type_def_str + main_def_str)
+ mod = tvm.parser.parse(SEMVER + type_def_str + main_def_str)
mod_str = str(mod)
# ensure constructors are printed correctly in type definitions (with their
# signature) and as exprs (without their signature)
if __name__ == "__main__":
- do_print[0] = True
- test_lstm()
- test_zeros()
- test_meta_data()
- test_let_inlining()
- test_resnet()
- test_mobilenet()
- test_mlp()
- test_dqn()
- test_dcgan()
- test_squeezenet()
- test_inception_v3()
- test_vgg()
- test_densenet()
- test_func()
- test_env()
- test_call_attrs()
- test_let_if_scope()
- test_variable_name()
- test_call_node_order()
- test_unapplied_constructor()
- test_null_attribute()
+    import sys
+    pytest.main(sys.argv)
mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
df = mod["main"]
- df_parsed = relay.parser.fromtext(
+ df_parsed = tvm.parser.parse_expr(
"""
- v0.0.4
+ #[version = "0.0.5"]
fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
%z: Tensor[(1), float32], %w: Tensor[(1), float32])
-> (Tensor[(1), float32],
mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
df = mod["main"]
- df_parsed = relay.parser.fromtext(
+ df_parsed = tvm.parser.parse_expr(
"""
- v0.0.4
+ #[version = "0.0.5"]
fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
%z: Tensor[(1), float32], %w: Tensor[(1), float32])
-> ((Tensor[(1), float32], Tensor[(1), float32]),
import tvm.relay.transform as _transform
def test_eta_expand_global_var():
- mod = relay.fromtext(r"""
- v0.0.4
+ mod = tvm.parser.fromtext(r"""
+ #[version = "0.0.5"]
def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
%x
}
- def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
+ def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] {
@aux
}
""")
seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
- expected = relay.fromtext(r"""
- v0.0.4
+ expected = tvm.parser.fromtext(r"""
+ #[version = "0.0.5"]
def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
%x
}
- def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
+ def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] {
fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
@aux(%x)
}
def test_eta_expand_constructor():
- mod = relay.fromtext(r"""
- v0.0.4
+ mod = tvm.parser.fromtext(r"""
+ #[version = "0.0.5"]
type List[A] {
Cons(A, List[A]),
Nil,
}
- def @main[A]() -> (fn(A, List[A]) -> List[A]) {
+ def @main[A]() -> fn(A, List[A]) -> List[A] {
Cons
}
""")
seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
- expected = relay.fromtext(r"""
- v0.0.4
+ expected = tvm.parser.fromtext(r"""
+ #[version = "0.0.5"]
type List[A] {
Cons(A, List[A]),
Nil,
}
- def @main[A]() -> (fn(A, List[A]) -> List[A]) {
+ def @main[A]() -> fn(A, List[A]) -> List[A] {
fn [A](%x: A, %xs: List[A]) -> List[A] {
Cons(%x, %xs)
}
def test_inf_loop_case():
code = """
-v0.0.4
+#[version = "0.0.5"]
type Arith[A] {
Zero,
Const(A),
}
}
"""
- relay.fromtext(code)
+ tvm.parser.fromtext(code)
# fromtext parse the module, then checked it (which include strictness checking).
if __name__ == "__main__":