# Source file lists
file(GLOB_RECURSE COMPILER_SRCS
+ src/auto_schedule/*.cc
src/node/*.cc
src/ir/*.cc
src/arith/*.cc
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+""" Namespace for TVM Auto-scheduler. """
+
+from . import compute_dag
+from . import measure
+from . import measure_record
+from . import loop_state
+from . import utils
+from . import workload_registry
+
+# Shortcut
+from .compute_dag import ComputeDAG
+from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
+ auto_schedule, EmptyPolicy
+from .measure import MeasureInput, LocalBuilder, LocalRunner
+from .measure_record import RecordToFile, RecordReader, load_best, \
+ load_records, save_records
+from .workload_registry import register_workload, make_workload_key
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Register FFI APIs from C++ for the namespace tvm.auto_schedule. """
+import tvm._ffi
+
+
+tvm._ffi._init_api("auto_schedule", __name__)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for TVM Auto-scheduler.
+
+The basic schedule search process for TVM Auto-scheduler is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined precise or heuristic rules to generate several
+initial schedules. Based on these initial starting points, we perform `Performance Tuning` which
+uses cost model based evolutionary search to select schedules with the best performance.
+
+Candidate schedules are measured against the specific hardware target.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from .measure import LocalBuilder, LocalRunner
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_schedule.HardwareParams")
+class HardwareParams(Object):
+ """ The parameters of target hardware used to guide the search policy
+
+ TODO(jcf94): This is considered to be merged with the new Target specification:
+ https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844
+
+ Parameters
+ ----------
+ num_cores : int
+ The number of device cores.
+ vector_unit_bytes : int
+ The width of vector units in bytes.
+ cache_line_bytes : int
+ The size of cache line in bytes.
+ """
+ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes):
+ self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
+ vector_unit_bytes, cache_line_bytes)
+
+
+@tvm._ffi.register_object("auto_schedule.SearchTask")
+class SearchTask(Object):
+ """ The computation information and hardware parameters for a specific schedule search task.
+
+ Parameters
+ ----------
+ dag : ComputeDAG
+ The ComputeDAG for the corresponding compute declaration.
+ workload_key : str
+ The workload key for the corresponding compute declaration.
+ target : tvm.target.Target
+ The target device of this search task.
+ target_host : Optional[tvm.target.Target]
+ The target host device of this search task.
+ hardware_params : Optional[HardwareParams]
+ Hardware parameters used in this search task.
+ """
+ def __init__(self, dag, workload_key, target, target_host=None,
+ hardware_params=None):
+ self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
+ workload_key, target, target_host,
+ hardware_params)
+
+
+@tvm._ffi.register_object("auto_schedule.SearchPolicy")
+class SearchPolicy(Object):
+ """ The base class of search policies. """
+
+
+@tvm._ffi.register_object("auto_schedule.EmptyPolicy")
+class EmptyPolicy(SearchPolicy):
+ """ This is an example empty search policy which will always generate
+ the init state of ComputeDAG.
+ """
+ def __init__(self):
+ self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy)
+
+
+@tvm._ffi.register_object("auto_schedule.TuningOptions")
+class TuningOptions(Object):
+ """ This controls the options of performance tuning.
+
+ Parameters
+ ----------
+ num_measure_trials: int = 0
+ The number of measurement trials.
+ The search policy measures `num_measure_trials` schedules in total and returns the best one
+ among them.
+ With `num_measure_trials` == 0, the policy will do the schedule search but won't involve
+ measurement. This can be used to get a runnable schedule quickly without auto-tuning.
+ early_stopping: Optional[int]
+ Stop the tuning early if getting no improvement after n measurements.
+ num_measures_per_round: int = 64
+ The number of schedules to be measured at each search round.
+ The whole schedule search process will try a total number of `num_measure_trials` in several
+ rounds.
+ verbose: int = 1
+ Verbosity level. 0 for silent, 1 to output information during schedule search.
+ builder: Union[ProgramBuilder, str] = 'local'
+ ProgramBuilder which builds the program.
+ runner: Union[ProgramRunner, str] = 'local'
+ ProgramRunner which runs the program and measures time costs.
+ measure_callbacks: Optional[List[MeasureCallback]]
+ Callback functions called after each measurement.
+ Candidates:
+ - auto_schedule.RecordToFile
+ pre_search_callbacks: Optional[List[SearchCallback]]
+ Callback functions called before the search process.
+ Candidates:
+ - auto_schedule.PreloadMeasuredStates
+ - auto_schedule.PreloadCustomSketchRule
+ TODO(jcf94): Add these implementation in later PRs.
+ """
+ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_round=64,
+ verbose=1, builder='local', runner='local', measure_callbacks=None,
+ pre_search_callbacks=None):
+ if isinstance(builder, str):
+ if builder == 'local':
+ builder = LocalBuilder()
+ else:
+ raise ValueError("Invalid builder: " + builder)
+ elif not isinstance(builder, tvm.auto_schedule.measure.ProgramBuilder):
+ raise ValueError("Invalid builder: " + builder +
+ " . TuningOptions expects a ProgramBuilder or string.")
+
+ if isinstance(runner, str):
+ if runner == 'local':
+ runner = LocalRunner()
+ else:
+ raise ValueError("Invalid runner: " + runner)
+ elif not isinstance(runner, tvm.auto_schedule.measure.ProgramRunner):
+ raise ValueError("Invalid runner: " + runner +
+ " . TuningOptions expects a ProgramRunner or string.")
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.TuningOptions, num_measure_trials, early_stopping if early_stopping else -1,
+ num_measures_per_round, verbose, builder, runner, measure_callbacks,
+ pre_search_callbacks)
+
+
+def auto_schedule(task, search_policy='default', tuning_options=None):
+ """ Do auto scheduling for a computation declaration.
+
+ The task parameter can be a `string` as workload_key, or directly
+ passing a `SearchTask` as input.
+
+ Parameters
+ ----------
+ task : SearchTask
+ The SearchTask for the computation declaration.
+ search_policy : Union[SearchPolicy, str] = 'default'
+ The search policy to be used for schedule search.
+ tuning_options : Optional[TuningOptions]
+ Tuning and measurement options.
+
+ Returns
+ -------
+ A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
+ """
+ if not isinstance(task, SearchTask):
+ raise ValueError("Invalid task: " + task +
+ " . `auto_schedule.auto_schedule` expects a SearchTask.")
+
+ if isinstance(search_policy, str):
+ if search_policy == 'default':
+ # TODO(jcf94): This is an example policy for minimum system, will be upgrated to
+ # formal search policy later.
+ search_policy = EmptyPolicy()
+ else:
+ raise ValueError("Invalid search policy: " + search_policy)
+ elif not isinstance(search_policy, SearchPolicy):
+ raise ValueError("Invalid search policy: " + search_policy +
+ " . `auto_schedule.auto_schedule` expects a SearchPolicy or a string.")
+
+ sch, tensors = _ffi_api.AutoSchedule(task, search_policy,
+ tuning_options if tuning_options else TuningOptions())
+ return sch, tensors
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" The TVM Auto-scheduler computational graph and related program analyses. """
+
+import hashlib
+
+import tvm._ffi
+from tvm.runtime import Object
+from tvm.te import PlaceholderOp, ComputeOp
+
+from .loop_state import State, StateObject
+from .utils import get_const_tuple
+from .workload_registry import workload_key_to_tensors
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_schedule.ComputeDAG")
+class ComputeDAG(Object):
+ """
+ The TVM Auto-scheduler computational graph and related program analyses.
+
+ We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ total float operation count, consumer/producer relations of each operation stage, whether an
+ operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ to make decisions during search process.
+ ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ `LoopState` with extra information got from TVM schedule ...).
+
+ Parameters
+ ----------
+ compute : Union[List[Tensor], str]
+ `Tensor`s or workload key for a compute declaration.
+ """
+ def __init__(self, compute):
+ if isinstance(compute, str):
+ compute = workload_key_to_tensors(compute)
+ elif isinstance(compute, list):
+ for item in compute:
+ if not isinstance(item, tvm.te.Tensor):
+ raise ValueError("The input of ComputeDAG should be a list of Tensor")
+ else:
+ raise ValueError("Invalid compute: " + compute +
+ " . ComputeDAG expects a string or list of Tensor")
+ self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
+
+ def get_init_state(self):
+ """ Get the init state of this ComputeDAG.
+
+ Returns
+ -------
+ state : State
+ The initial State without any transform steps.
+ """
+ return State(self.init_state, self)
+
+ def apply_steps_from_state(self, state):
+ """
+ Apply the history transform steps from a State to get a TVM schedule.
+
+ Parameters
+ ----------
+ state : Union[State, StateObject]
+ The state from which we get transform steps.
+
+ Returns
+ -------
+ A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
+ """
+ state_obj = state if isinstance(state, StateObject) else state.state_object
+ return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj)
+
+ def print_python_code_from_state(self, state):
+ """
+ Print transform steps in the history of a State as TVM's python schedule primitive.
+
+ This is used to print transformation steps for debugging.
+ Use `apply_steps_from_state` if you want to get a schedule for code generation.
+
+ Parameters
+ ----------
+ state : Union[State, StateObject]
+ The state from which we get transform steps.
+
+ Returns
+ -------
+ str : Str
+ The Python schedule code.
+ """
+ state_obj = state if isinstance(state, StateObject) else state.state_object
+ return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj)
+
+ def infer_bound_from_state(self, state):
+ """
+ Infer and fill the bound of all iterators of a state.
+
+ The states may lose complete bound information after some transform steps
+ (e.g., compute_at).
+ We can call this function to infer and fill all the bound information.
+ This function calls TVM InferBound pass internally to get the bound.
+ The returned state of this function is guaranteed to have complete iterator extent
+ information.
+
+ Parameters
+ ----------
+ state : Union[State, StateObject]
+ The state from which we get transform steps.
+
+ Returns
+ -------
+ state : State
+ The State with complete bound information.
+ """
+ state_obj = state if isinstance(state, StateObject) else state.state_object
+ return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
+
+ def __hash__(self):
+ # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function
+ # of ComputeDAG
+ str_key = ''
+ for op in self.ops:
+ t = op.output(0)
+ if isinstance(op, PlaceholderOp):
+ str_key += 'placeholder,'
+ str_key += str(get_const_tuple(t.shape)) + ','
+ str_key += t.dtype + ';'
+ elif isinstance(op, ComputeOp):
+ str_key += str(t.op.body) + ','
+ str_key += str(get_const_tuple(t.shape)) + ','
+ str_key += t.dtype + ';'
+ else:
+ raise ValueError("Invalid op: " + op)
+
+ str_key = str_key.encode(encoding='utf-8')
+ return hashlib.md5(str_key).hexdigest()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search.
+
+Each LoopState corresponds to a schedule for its ComputeDAG.
+A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
+construct the loop structure.
+The loop structure keeps a preview of how the schedule will finally look like after lowering the
+current state (e.g. number of iterators, the extent of each iterator, the compute_at locations ...).
+During the schedule search process, the loop structure can provide search policy with necessary
+information on how to manipulate the current state.
+The transform history is a sequence of `TransformStep` which will finally be mapped to TVM schedule
+primitives. The steps can also be used for the serialization of a state.
+
+The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
+We don't use the existing TVM IR but to extend a new structure on it is because:
+1. We want fast incremental change to the loop structures. The search policy needs to get the
+immediate loop structures update rather than after TVM lowering;
+2. We want serializable transform history for replay, backtracking, and mutation;
+3. We may create some macro schedule primitives that represent the combination of several
+TVM schedule primitives.
+
+When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives.
+Since we share a lot of common objects during search, the transformation is implemented in
+copy on write style. All objects are immutable, which is similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.te.tensor import Operation, Tensor
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_schedule.Iterator")
+class Iterator(Object):
+ """ A loop iterator structure. """
+
+
+@tvm._ffi.register_object("auto_schedule.Stage")
+class Stage(Object):
+ """ A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """
+
+
+@tvm._ffi.register_object("auto_schedule.State")
+class StateObject(Object):
+ """ The internal State object """
+ def __eq__(self, other):
+ return _ffi_api.StateEqual(self, other)
+
+
+class State:
+ """
+ A state in the search process. It consists of the current loop structure
+ and a list of transformation steps used to construct it.
+
+ Each State corresponds to a specific schedule for its ComputeDAG.
+
+ Parameters
+ ----------
+ state_object : StateObject
+ The StateObject corresponding to C++ internal State object.
+ dag : ComputeDAG
+ The original ComputeDAG of this State.
+
+ Notes
+ -----
+ This is a wrapper class of StateObject to deal with copy-on-write property
+ """
+ def __init__(self, state_object, dag):
+ self.state_object = state_object
+ self.compute_dag = dag
+
+ self.stage_id_map = {} # A dict maps operation to stage id
+ self._update_stage_id_map()
+
+ @property
+ def stages(self):
+ """
+ Returns
+ -------
+ stages : List[Stage]
+ """
+ return self.state_object.stages
+
+ @property
+ def stage_ops(self):
+ """
+ Returns
+ -------
+ ops: List[Operation]
+ """
+ return [stage.op for stage in self.stages]
+
+ def reorder(self, stage, order):
+ """ Schedule primitive corresponds to te.reorder.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be reordered, can be a Stage order index, Stage operation or stage
+ output tensor.
+ order : List[Iterator]
+ Iterators in the expected order.
+ """
+ stage_id = self._resolve_stage_id(stage)
+
+ self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
+
+ def split(self, stage, iterator, lengths, inner_to_outer=True):
+ """ Schedule primitive corresponds to te.split.
+
+ This API supports multiple split factors. (e.g. with 2 split factors, the original iterator
+ will be split to 3 parts, use `inner_to_outer` to control the split order)
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be split, can be a Stage order index, Stage operation or stage
+ output tensor.
+ iterator : Iterator
+ The iterator to be split.
+ lengths: List[int]
+ The multiple split factors. Can be None to be filled by search policy.
+ inner_to_outer: boolean = True
+ Whether the factor go from inner to outer, or from outer to inner.
+
+ Returns
+ -------
+ res_its : List[Iterator]
+ The splitted new Iterators
+ """
+ stage_id = self._resolve_stage_id(stage)
+
+ self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths,
+ inner_to_outer)
+ return res
+
+ def fuse(self, stage, iters):
+ """ Schedule primitive corresponds to te.fuse.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be fused, can be a Stage order index, Stage operation or stage
+ output tensor.
+ iters : List[Iterator]
+ The iterators to be fused
+
+ Returns
+ -------
+ res_it : Iterator
+ The fused Iterator
+ """
+ stage_id = self._resolve_stage_id(stage)
+
+ self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
+ return res
+
+ def copy(self):
+ """ Do deep copy of this State. """
+ state = State(self.state_object, self.compute_dag)
+ state.stage_id_map = self.stage_id_map.copy()
+ return state
+
+ def _resolve_stage_id(self, stage_id):
+ if isinstance(stage_id, Operation):
+ return self.stage_id_map[stage_id]
+ if isinstance(stage_id, Tensor):
+ return self.stage_id_map[stage_id.op]
+ if isinstance(stage_id, int):
+ return stage_id
+ raise ValueError("Invalid stage: " + stage_id +
+ " . Expect to be a int, Operation or Tensor")
+
+ def _update_stage_id_map(self):
+ for index, stage in enumerate(self.stages):
+ self.stage_id_map[stage.op] = index
+
+ def __getitem__(self, key):
+ if isinstance(key, Tensor):
+ key = key.op
+ if isinstance(key, Operation):
+ return self.stages[self.stage_id_map[key]]
+ raise ValueError("Invalid item: " + key +
+ " . Expect to be a Operation or Tensor")
+
+ def __str__(self):
+ return str(self.state_object)
+
+ def __eq__(self, other):
+ return _ffi_api.StateEqual(self.state_object, other.state_object)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Distributed measurement infrastructure to measure the runtime costs of tensor programs.
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We separate the measurement into two steps: build and run.
+A builder builds the executable binary files and a runner runs the binary files to
+get the measurement results. The flow of data structures is
+
+ `ProgramBuilder` `ProgramRunner`
+`MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
+
+We implement these in python to utilize python's multiprocessing and error handling.
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# We use fork and a global variable to copy arguments between processings.
+# This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("auto_schedule.MeasureCallback")
+class MeasureCallback(Object):
+ """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("auto_schedule.MeasureInput")
+class MeasureInput(Object):
+ """ Store the input of a measurement.
+
+ Parameters
+ ----------
+ task : SearchTask
+ The SearchTask of this measure.
+ state : State
+ The State to be measured.
+ """
+ def __init__(self, task, state):
+ self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("auto_schedule.BuildResult")
+class BuildResult(Object):
+ """ Store the result of a build.
+
+ Parameters
+ ----------
+ filename : Optional[str]
+ The filename of built binary file.
+ args : List[Tensor]
+ The arguments.
+ error_no : int
+ The error code.
+ error_msg : Optional[str]
+ The error message if there is any error.
+ time_cost : float
+ The time cost of build.
+ """
+ def __init__(self, filename, args, error_no, error_msg, time_cost):
+ filename = filename if filename else ""
+ error_msg = error_msg if error_msg else ""
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("auto_schedule.MeasureResult")
+class MeasureResult(Object):
+ """ Store the results of a measurement.
+
+ Parameters
+ ----------
+ costs : List[float]
+ The time costs of execution.
+ error_no : int
+ The error code.
+ error_msg : Optional[str]
+ The error message if there is any error.
+ all_cost : float
+ The time cost of build and run.
+ timestamp : float
+ The time stamps of this measurement.
+ """
+ def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+ error_msg = error_msg if error_msg else ""
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.MeasureResult, costs, error_no,
+ error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("auto_schedule.ProgramBuilder")
+class ProgramBuilder(Object):
+ """ The base class of ProgramBuilders. """
+
+ def build(self, measure_inputs, verbose=1):
+ """ Build programs and return results.
+
+ Parameters
+ ----------
+ measure_inputs : List[MeasureInput]
+ A List of MeasureInput.
+ verbose: int = 1
+ Verbosity level. 0 for silent, 1 to output information during program building.
+
+ Returns
+ -------
+ res : List[BuildResult]
+ """
+ return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("auto_schedule.ProgramRunner")
+class ProgramRunner(Object):
+ """ The base class of ProgramRunners. """
+
+ def run(self, measure_inputs, build_results, verbose=1):
+ """ Run measurement and return results.
+
+ Parameters
+ ----------
+ measure_inputs : List[MeasureInput]
+ A List of MeasureInput.
+ build_results : List[BuildResult]
+ A List of BuildResult to be ran.
+ verbose: int = 1
+ Verbosity level. 0 for silent, 1 to output information during program running.
+
+ Returns
+ -------
+ res : List[MeasureResult]
+ """
+ return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose)
+
+
+@tvm._ffi.register_object("auto_schedule.LocalBuilder")
+class LocalBuilder(ProgramBuilder):
+ """ LocalBuilder use local CPU cores to build programs in parallel.
+
+ Parameters
+ ----------
+ timeout : int = 15
+ The timeout limit (in second) for each build thread.
+ This is used in a wrapper of the multiprocessing.Process.join().
+ n_parallel : int = multiprocessing.cpu_count()
+ Number of threads used to build in parallel.
+ build_func : str = 'default'
+ The name of registered build function.
+ """
+
+ def __init__(self,
+ timeout=15,
+ n_parallel=multiprocessing.cpu_count(),
+ build_func='default'):
+ self.__init_handle_by_constructor__(
+ _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+
+
+@tvm._ffi.register_object("auto_schedule.LocalRunner")
+class LocalRunner(ProgramRunner):
+ """ LocalRunner that uses local CPU/GPU to measures the time cost of programs.
+
+ Parameters
+ ----------
+ timeout : int = 10
+ The timeout limit (in second) for each run.
+ This is used in a wrapper of the multiprocessing.Process.join().
+ number : int = 3
+ The number of times to run the generated code for taking average.
+ We call these runs as one `repeat` of measurement.
+ repeat : int = 1
+ The number of times to repeat the measurement.
+ In total, the generated code will be run (1 + number x repeat) times,
+ where the first "1" is warm up and will be discarded.
+ The returned result contains `repeat` costs,
+ each of which is an average of `number` costs.
+ min_repeat_ms : int = 0
+ The minimum duration of one `repeat` in milliseconds.
+ By default, one `repeat` contains `number` runs. If this parameter is set,
+ the parameters `number` will be dynamically adjusted to meet the
+ minimum duration requirement of one `repeat`.
+ i.e., When the run time of one `repeat` falls below this time, the `number` parameter
+ will be automatically increased.
+ cooldown_interval : float = 0.0
+ The cool down interval between two measurements.
+ """
+
+ def __init__(self,
+ timeout=10,
+ number=3,
+ repeat=1,
+ min_repeat_ms=0,
+ cooldown_interval=0.0):
+ self.__init_handle_by_constructor__(
+ _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)
+
+
+class MeasureErrorNo(object):
+ """ Error type for MeasureResult. """
+ NO_ERROR = 0 # No error
+ INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state
+ # Errors happen when compiling code on host (e.g. tvm.build)
+ COMPILE_HOST = 2
+ COMPILE_DEVICE = 3 # Errors happen when compiling code on device
+ # (e.g. OpenCL JIT on the device)
+ RUNTIME_DEVICE = 4 # Errors happen when run program on device
+ WRONG_ANSWER = 5 # Answer is wrong when compared to a reference output
+ BUILD_TIMEOUT = 6 # Timeout during compilation
+ RUN_TIMEOUT = 7 # Timeout during run
+ UNKNOWN_ERROR = 8 # Unknown error
+
+
+def make_error_msg():
+ """ Get the error message from traceback. """
+ error_msg = str(traceback.format_exc())
+ if len(error_msg) > MAX_ERROR_MSG_LEN:
+ error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \
+ "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:]
+ return error_msg
+
+
+def local_build_worker(index):
+ """
+ Build function of LocalBuilder to be ran in the Builder thread pool.
+
+ Parameters
+ ----------
+ index : int
+ The MeasureInput index to be processed by the current Builder thread.
+
+ Returns
+ -------
+ res : BuildResult
+ The build result of this Builder thread.
+ """
+ global GLOBAL_BUILD_ARGUMENTS
+
+ # We use fork and a global variable to copy arguments between processings.
+ # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
+ if not GLOBAL_BUILD_ARGUMENTS:
+ raise ValueError("GLOBAL_BUILD_ARGUMENTS not found")
+ measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS
+ assert isinstance(build_func, str)
+
+ if build_func == 'default':
+ build_func = tar.tar
+ elif build_func == 'ndk':
+ build_func = ndk.create_shared
+ else:
+ raise ValueError("Invalid build_func" + build_func)
+
+ def timed_func():
+ tic = time.time()
+ inp = measure_inputs[index]
+ task = inp.task
+
+ error_no = MeasureErrorNo.NO_ERROR
+ error_msg = None
+ args = []
+
+ try:
+ sch, args = task.compute_dag.apply_steps_from_state(
+ inp.state)
+ # pylint: disable=broad-except
+ except Exception:
+ error_no = MeasureErrorNo.INSTANTIATION_ERROR
+ error_msg = make_error_msg()
+
+ if error_no == 0:
+ dirname = tempfile.mkdtemp()
+ filename = os.path.join(
+ dirname, "tmp_func." + build_func.output_format)
+
+ try:
+ with transform.PassContext(): # todo(lmzheng): port the unroll pass
+ func = build_module.build(
+ sch, args, target=task.target, target_host=task.target_host)
+ func.export_library(filename, build_func)
+ # pylint: disable=broad-except
+ except Exception:
+ error_no = MeasureErrorNo.COMPILE_HOST
+ error_msg = make_error_msg()
+ else:
+ filename = ""
+
+ if verbose >= 1:
+ if error_no == MeasureErrorNo.NO_ERROR:
+ print(".", end="")
+ else:
+ print(".E", end="") # Build error
+ return filename, args, error_no, error_msg, time.time() - tic
+
+ res = call_func_with_timeout(timeout, timed_func)
+ if isinstance(res, TimeoutError):
+ if verbose >= 1:
+ print(".T", end="") # Build timeout
+ res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
+
+ return res
+
+
+@tvm._ffi.register_func("auto_schedule.local_builder.build")
+def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=1):
+ """
+ Build function of LocalBuilder to build the MeasureInputs to runnable modules.
+
+ Parameters
+ ----------
+ inputs : List[MeasureInput]
+ The MeasureInputs to be built.
+ timeout : int
+ The timeout limit (in second) for each build thread.
+ This is used in a wrapper of the multiprocessing.Process.join().
+ n_parallel : int
+ Number of threads used to build in parallel.
+ build_func : str = 'default'
+ The name of build function to process the built module.
+ verbose: int = 1
+ Verbosity level. 0 for silent, 1 to output information during program building.
+
+ Returns
+ -------
+ res : List[BuildResult]
+ The build results of these MeasureInputs.
+ """
+ # We use fork and a global variable to copy arguments between processings.
+ # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
+ global GLOBAL_BUILD_ARGUMENTS
+
+ GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose)
+
+ pool = NoDaemonPool(n_parallel)
+ tuple_res = pool.map(local_build_worker, range(len(inputs)))
+ pool.terminate()
+ pool.join()
+ del pool
+
+ results = []
+ for res in tuple_res:
+ results.append(BuildResult(*res))
+
+ return results
+
+@tvm._ffi.register_func("auto_schedule.local_runner.run")
+def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval,
+ verbose=1):
+ """
+ Run function of LocalRunner to test the performance of the input BuildResults.
+
+ Parameters
+ ----------
+ inputs : List[MeasureInput]
+ The MeasureInputs to be measured.
+ build_results : List[BuildResult]
+ The BuildResults to be measured.
+ timeout : int
+ The timeout limit (in second) for each run.
+ This is used in a wrapper of the multiprocessing.Process.join().
+ number : int = 3
+ The number of times to run the generated code for taking average.
+ We call these runs as one `repeat` of measurement.
+ repeat : int = 1
+ The number of times to repeat the measurement.
+ In total, the generated code will be run (1 + number x repeat) times,
+ where the first "1" is warm up and will be discarded.
+ The returned result contains `repeat` costs,
+ each of which is an average of `number` costs.
+ min_repeat_ms : int = 0
+ The minimum duration of one `repeat` in milliseconds.
+ By default, one `repeat` contains `number` runs. If this parameter is set,
+ the parameters `number` will be dynamically adjusted to meet the
+ minimum duration requirement of one `repeat`.
+ i.e., When the run time of one `repeat` falls below this time, the `number` parameter
+ will be automatically increased.
+ cooldown_interval : float = 0.0
+ The cool down interval between two measurements.
+ verbose: int = 1
+ Verbosity level. 0 for silent, 1 to output information during program measuring.
+
+ Returns
+ -------
+ res : List[MeasureResult]
+ The measure results of these MeasureInputs.
+ """
+ max_float = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log
+
+ def timed_func(inp, build_res):
+ tic = time.time()
+ error_no = 0
+ error_msg = None
+ try:
+ func = module.load_module(build_res.filename)
+ ctx = ndarray.context(str(inp.task.target), 0)
+ time_f = func.time_evaluator(
+ func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)
+ # pylint: disable=broad-except
+ except Exception:
+ costs = (max_float,)
+ error_no = MeasureErrorNo.COMPILE_DEVICE
+ error_msg = make_error_msg()
+
+ if error_no == 0:
+ try:
+ args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
+ build_res.args]
+ ctx.sync()
+ costs = time_f(*args).results
+ # pylint: disable=broad-except
+ except Exception:
+ costs = (max_float,)
+ error_no = MeasureErrorNo.RUNTIME_DEVICE
+ error_msg = make_error_msg()
+
+ shutil.rmtree(os.path.dirname(build_res.filename))
+ toc = time.time()
+ time.sleep(cooldown_interval)
+
+ if verbose >= 1:
+ if error_no == MeasureErrorNo.NO_ERROR:
+ print("*", end="")
+ else:
+ print("*E", end="") # Run error
+ return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc
+
+ measure_results = []
+ assert len(inputs) == len(build_results), \
+ "Measure input size should be equal to build results"
+ for inp, build_res in zip(inputs, build_results):
+ if build_res.error_no != 0:
+ res = (max_float,), build_res.error_no, build_res.error_msg, build_res.time_cost, \
+ time.time()
+ else:
+ res = call_func_with_timeout(
+ timeout, timed_func, args=(inp, build_res))
+ if isinstance(res, TimeoutError):
+ if verbose >= 1:
+ print("*T", end="") # Run timeout
+ res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, \
+ build_res.time_cost + timeout, time.time()
+ measure_results.append(MeasureResult(*res))
+
+ if verbose >= 1:
+ print("")
+
+ return measure_results
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Serialization and other I/O support for measurement records (tuning logs). """
+
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+from .measure import MeasureCallback, MeasureErrorNo
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_schedule.RecordToFile")
+class RecordToFile(MeasureCallback):
+ """
+ A measurement callback that writes measurement records into a file.
+
+ Parameters
+ ----------
+ filename : str
+ File name for this callback to write log to.
+ """
+ def __init__(self, filename="auto_schedule_tuning.json"):
+ self.__init_handle_by_constructor__(_ffi_api.RecordToFile, filename)
+
+
+@tvm._ffi.register_object("auto_schedule.RecordReader")
+class RecordReader(Object):
+ """
+ Reader of the json log file.
+
+ Parameters
+ ----------
+ filename : str = "auto_schedule_tuning.json"
+ File name for this reader to load log from.
+ """
+ def __init__(self, filename="auto_schedule_tuning.json"):
+ self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename)
+
+ def read_lines(self, max_lines=None, skip_lines=0):
+ """ Read multiple lines from the log file.
+
+ Parameters
+ ----------
+ max_lines : Optional[int]
+ The maximum number of lines. None to read all lines.
+ skip_lines : int = 0
+ Skip the first n lines.
+
+ Returns
+ -------
+ inputs : List[MeasureInput]
+ The MeasureInputs loaded from the log file.
+ results : List[MeasureResult]
+ The MeasureResults loaded from the log file.
+ """
+ inputs, results = _ffi_api.RecordReaderReadLines(self, max_lines if max_lines else -1,
+ skip_lines)
+ return inputs, results
+
+ def __iter__(self):
+ while True:
+ ret = _ffi_api.RecordReaderReadNext(self)
+ if not ret:
+ break
+ yield ret[0], ret[1] # (input, result)
+
+
+def load_records(filename):
+ """
+ Load measurement records from a file.
+
+ Parameters
+ ----------
+ filename : str
+ File name to load log from.
+
+ Returns
+ -------
+ logs : List[MeasureInput, MeasureResult]
+ """
+ return zip(*RecordReader(filename).read_lines())
+
+
+def save_records(filename, inputs, results):
+ """
+ Append measure records to file.
+
+ Parameters
+ ----------
+ filename : str
+ File name to write log to.
+ inputs: List[MeasureInputs]
+ The MeasureInputs to be written.
+ results: List[MeasureResults]
+ The MeasureResults to be written.
+ """
+ _ffi_api.SaveRecords(filename, inputs, results)
+
+def load_best(filename, workload_key=None, target=None):
+ """ Return the best measurement pair form a log file. This may return none results if
+ there is no legal measure pair with the specified workload_key/target found from the log file.
+
+ Parameters
+ ----------
+ filename : str
+ File name to load log from.
+ workload_key : Optional[str]
+ The workload key of the compute declaration.
+ With `None`, this retuns the best measure pair of all workloads.
+ target : Optional[tvm.target.Target]
+ The target device.
+ With `None`, this retuns the best measure pair of all target devices.
+
+ Returns
+ -------
+ input : MeasureInput
+ The best State's MeasureInput from this log fine.
+ result : MeasureResult
+ The best State's MeasureResult from this log fine.
+ """
+ log_reader = RecordReader(filename)
+ best_cost = 1e30
+ best_inp = None
+ best_res = None
+
+ for inp, res in log_reader:
+ if res.error_no != MeasureErrorNo.NO_ERROR:
+ continue
+ if workload_key and inp.task.workload_key != workload_key:
+ continue
+ if target and inp.task.target.id.name != target.id.name:
+ continue
+
+ costs = [v.value for v in res.costs]
+ cost = np.mean(costs)
+ if cost < best_cost:
+ best_cost = cost
+ best_inp = inp
+ best_res = res
+
+ return best_inp, best_res
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Common utilities for auto_schedule. """
+
+from typing import Hashable
+import multiprocessing
+import multiprocessing.pool
+import queue
+import signal
+
+try:
+ import psutil
+except ImportError:
+ raise ImportError("psutil not found, try `pip install psutil` to fix this")
+
+from tvm.tir import expr
+from tvm.tir.transform import Simplify
+from tvm.ir.transform import Sequential
+from ..te import Tensor, placeholder
+
+
+def get_func_name(func):
+ """Get name of a function.
+
+ Parameters
+ ----------
+ func: Function
+ The input function.
+
+ Returns
+ -------
+ name: str
+ The function name.
+ """
+ return func.func_name if hasattr(func, 'func_name') else func.__qualname__
+
+
+def get_const_int(exp):
+ """Verifies expr is integer and get the constant value.
+
+ Parameters
+ ----------
+ exp : Union[tvm.tir.expr, int]
+ The input expression.
+
+ Returns
+ -------
+ out_value : int
+ The output.
+ """
+ if isinstance(exp, int):
+ return exp
+ if not isinstance(exp, expr.IntImm):
+ opt = Sequential([Simplify()])
+ exp = opt(exp)
+ if not isinstance(exp, expr.IntImm):
+ raise ValueError("Expect value to be constant int")
+ return exp.value
+
+
+def get_const_tuple(in_tuple):
+ """Verifies input tuple is IntImm, returns tuple of int.
+
+ Parameters
+ ----------
+ in_tuple : Tuple[tvm.tir.expr]
+ The input.
+
+ Returns
+ -------
+ out_tuple : Tuple[int]
+ The output.
+ """
+ return tuple(get_const_int(x) for x in in_tuple)
+
+
+
+def list_to_tuple(x):
+ """ Convert a list to a tuple recursively. """
+ assert isinstance(x, list)
+ return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x)
+
+
+def serialize_args(args):
+ """
+ Serialize arguments of a function to a hashable and jsonable tuple.
+ Currently this is mainly used for tvm.tensor.Tensor
+ """
+ ret = []
+ for t in args:
+ if isinstance(t, Tensor):
+ t = ('TENSOR', get_const_tuple(t.shape), t.dtype)
+ elif isinstance(t, list):
+ t = list_to_tuple(t)
+
+ assert isinstance(t, Hashable), str(t) + " is not hashable"
+ ret.append(t)
+
+ return tuple(ret)
+
+
+def deserialize_args(args):
+ """The inverse function of :code:`serialize_args`"""
+ ret = []
+ for t in args:
+ if isinstance(t, (tuple, list)) and t[0] == 'TENSOR':
+ ret.append(placeholder(shape=t[1], dtype=t[2]))
+ else:
+ ret.append(t)
+ return ret
+
+
+class NoDaemonProcess(multiprocessing.Process):
+ @property
+ def daemon(self):
+ return False
+
+ @daemon.setter
+ def daemon(self, value):
+ pass
+
+
+class NoDaemonContext(type(multiprocessing.get_context())):
+ Process = NoDaemonProcess
+
+
+class NoDaemonPool(multiprocessing.pool.Pool):
+ """A no daemon pool version of multiprocessing.Pool.
+ This allows us to start new processings inside the worker function"""
+
+ def __init__(self, *args, **kwargs):
+ kwargs['context'] = NoDaemonContext()
+ super().__init__(*args, **kwargs)
+
+ def __reduce__(self):
+ pass
+
+
+def kill_child_processes(parent_pid, sig=signal.SIGTERM):
+ """kill all child processes recursively"""
+ try:
+ parent = psutil.Process(parent_pid)
+ except psutil.NoSuchProcess:
+ return
+ children = parent.children(recursive=True)
+ for process in children:
+ try:
+ process.send_signal(sig)
+ except psutil.NoSuchProcess:
+ return
+
+
+def call_func_with_timeout(timeout, func, args=(), kwargs=None):
+ """Call a function with timeout"""
+ def func_wrapper(que):
+ if kwargs:
+ que.put(func(*args, **kwargs))
+ else:
+ que.put(func(*args))
+
+ que = multiprocessing.Queue(2)
+ process = multiprocessing.Process(target=func_wrapper, args=(que,))
+ process.start()
+ process.join(timeout)
+
+ try:
+ res = que.get(block=False)
+ except queue.Empty:
+ res = TimeoutError()
+
+ # clean queue and process
+ kill_child_processes(process.pid)
+ process.terminate()
+ process.join()
+ que.close()
+ que.join_thread()
+ del process
+ del que
+
+ return res
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Workload registration and serialization.
+
+We use a json string to represent a workload (a computation graph).
+The format of the string is `[func_name, [args...]]`.
+The dag should be the return value of this `func_name(*args)`.
+
+Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags
+and matching them efficiently is not easy. Therefore, we use the above string to encode a compute
+dag.
+These strings are efficient for serialization/matching and won't be too long.
+When we need the dag, we decode the string and call the function, which will return the dag.
+"""
+
+import pickle
+import json
+
+import tvm._ffi
+from .utils import serialize_args, deserialize_args, get_func_name
+
+WORKLOAD_FUNC_REGISTRY = {}
+
+
+def register_workload(func_name, f=None, override=False):
+ """ Register a function that generates a certain workload.
+
+ The input function should take hashable and jsonable arguments
+ (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor.
+
+ Parameters
+ ----------
+ func_name : Union[Function, str]
+ The generation function that returns the compute declaration Tensors or its function name.
+ f : Optional[Function]
+ The generation function to be registered.
+ override : boolean = False
+ Whether override existing entry.
+
+ Examples
+ --------
+ @auto_schedule.register_workload
+ def matmul(N, M, K):
+ A = te.placeholder((N, K), name='A')
+ B = te.placeholder((K, M), name='B')
+ k = te.reduce_axis((0, K), name='k')
+ C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+ return [A, B, C]
+ """
+ global WORKLOAD_FUNC_REGISTRY
+
+ if callable(func_name):
+ f = func_name
+ func_name = get_func_name(f)
+ if not isinstance(func_name, str):
+ raise ValueError("expect string function name")
+
+ def register(myf):
+ """internal register function"""
+ if func_name in WORKLOAD_FUNC_REGISTRY and not override:
+ raise RuntimeError('%s has been registered already' % func_name)
+ WORKLOAD_FUNC_REGISTRY[func_name] = myf
+ return myf
+ if f:
+ return register(f)
+ return register
+
+
+def make_workload_key(func, args):
+ """ Make a workload key by function and arguments.
+
+ Parameters
+ ----------
+ func : Union[Function, str]
+ The function that returns the compute declaration Tensors.
+ Can be the a function or the function name.
+ args : Args
+ The args of the function.
+
+ Returns
+ -------
+ workload_key : Str
+ The workload key of the function.
+ """
+ global WORKLOAD_FUNC_REGISTRY
+
+ if callable(func):
+ func_name = get_func_name(func)
+ elif isinstance(func, str):
+ func_name = func
+ else:
+ raise ValueError("Invalid function: " + str(func) +
+ " . `make_workload_key` expects a callable function or its function name")
+
+ if not func_name in WORKLOAD_FUNC_REGISTRY:
+ raise ValueError("%s is not registered. " % func,
+ "Please register it with @auto_schedule.register_workload")
+
+ args = serialize_args(args)
+
+ return json.dumps((func_name,) + args)
+
+
+def decode_workload_key_to_func_args(workload_key):
+ """ Decode a workload key to the registerd function name and its corresponding args.
+
+ Parameters
+ ----------
+ workload_key : str
+ The input workload key.
+
+ Returns
+ -------
+ name : str
+ The function name of this workload key.
+ args : List[Tensor]
+ The args of the generation function.
+ """
+ global WORKLOAD_FUNC_REGISTRY
+
+ workload = json.loads(workload_key)
+ if not workload[0] in WORKLOAD_FUNC_REGISTRY:
+ raise ValueError("%s is not registered. " % workload[0] +
+ "Please register it with @auto_schedule.register_workload")
+ return workload[0], deserialize_args(workload[1:])
+
+
+@tvm._ffi.register_func("auto_schedule.workload_key_to_tensors")
+def workload_key_to_tensors(workload_key):
+ """ Get the input/output tensors from the workload key.
+
+ This method is usually used to create a ComputeDAG by workload key.
+
+ Parameters
+ ----------
+ workload_key : str
+ The input workload key.
+
+ Returns
+ -------
+ tensors : List[Tensor]
+ The registered compute declaration Tensors.
+ """
+ global WORKLOAD_FUNC_REGISTRY
+
+ name, args = decode_workload_key_to_func_args(workload_key)
+ lookup = WORKLOAD_FUNC_REGISTRY[name]
+ assert callable(lookup)
+ return lookup(*args)
+
+
+def save_workload_func_registry(filename):
+ """ Dump workload function registry to a pickle binary file.
+
+ Parameters
+ ----------
+ filename : str
+ The filename to dump workload function registry to.
+ """
+ global WORKLOAD_FUNC_REGISTRY
+
+ pickle.dump(WORKLOAD_FUNC_REGISTRY, open(filename, 'wb'))
+
+
+def load_workload_func_registry(filename):
+ """ Load workload function registry from a pickle binary file.
+
+ Parameters
+ ----------
+ filename : str
+ The filename to load workload function registry from.
+ """
+ global WORKLOAD_FUNC_REGISTRY
+
+ WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb'))
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/auto_schedule.cc
+ * \brief The user interface of the TVM Auto-scheduler.
+ */
+
+#include "auto_schedule.h"
+
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace auto_schedule {
+
+TVM_REGISTER_NODE_TYPE(TuningOptionsNode);
+
+TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round,
+ int verbose, ProgramBuilder builder, ProgramRunner runner,
+ Optional<Array<MeasureCallback>> measure_callbacks,
+ Optional<Array<SearchCallback>> pre_search_callbacks) {
+ auto node = make_object<TuningOptionsNode>();
+ node->num_measure_trials = num_measure_trials;
+ node->early_stopping = early_stopping;
+ node->num_measures_per_round = num_measures_per_round;
+ node->verbose = verbose;
+ node->builder = std::move(builder);
+ node->runner = std::move(runner);
+ node->measure_callbacks = std::move(measure_callbacks);
+ node->pre_search_callbacks = std::move(pre_search_callbacks);
+ data_ = std::move(node);
+}
+
+std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task, SearchPolicy search_policy,
+ TuningOptions tuning_options) {
+ // Create a ProgramMeasurer to handle the schedule build and performance measure
+ ProgramMeasurer measurer =
+ ProgramMeasurer(tuning_options->builder, tuning_options->runner,
+ tuning_options->measure_callbacks, tuning_options->verbose);
+ // Search for the best schedule
+ State state = search_policy->Search(
+ task, tuning_options->num_measure_trials, tuning_options->early_stopping,
+ tuning_options->num_measures_per_round, tuning_options->verbose, measurer,
+ tuning_options->pre_search_callbacks);
+ return task->compute_dag.ApplySteps(state->transform_steps);
+}
+
+TVM_REGISTER_GLOBAL("auto_schedule.TuningOptions")
+ .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round,
+ int verbose, ProgramBuilder builder, ProgramRunner runner,
+ Optional<Array<MeasureCallback>> measure_callbacks,
+ Optional<Array<SearchCallback>> pre_search_callbacks) {
+ return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose,
+ builder, runner, measure_callbacks, pre_search_callbacks);
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.AutoSchedule")
+ .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuningOptions tuning_options) {
+ te::Schedule sch;
+ Array<te::Tensor> return_tensors;
+ std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tuning_options);
+ return Array<ObjectRef>{sch, return_tensors};
+ });
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/auto_schedule.h
+ * \brief The user interface of the TVM Auto-scheduler. This is the entry point that receives
+ * schedule search requirements from the upper level (the Python API) and returns a
+ * high-performance schedule after the search process finishes.
+ */
+
+#ifndef TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_
+#define TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_
+
+#include <utility>
+
+#include "measure.h"
+#include "search_policy/search_policy.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+/*! \brief Tuning and measurement options. */
+class TuningOptionsNode : public Object {
+ public:
+ /*! \brief Number of total measurement trials. */
+ int num_measure_trials;
+  /*! \brief Stop the tuning early if no improvement is found after n measurements. */
+ int early_stopping;
+ /*! \brief The number of programs to be measured at each search round. */
+ int num_measures_per_round;
+ /*!
+ * \brief Verbosity level.
+ * 0 for silent, 1 to output information during schedule searching.
+ */
+ int verbose;
+  /*! \brief The ProgramBuilder that builds each program. */
+  ProgramBuilder builder;
+  /*! \brief The ProgramRunner that runs each program and measures its time cost. */
+  ProgramRunner runner;
+  /*! \brief MeasureCallback functions to be called after each measurement batch. */
+  Optional<Array<MeasureCallback>> measure_callbacks;
+  /*! \brief SearchCallback functions to be called before the schedule search. */
+ Optional<Array<SearchCallback>> pre_search_callbacks;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("num_measure_trials", &num_measure_trials);
+ v->Visit("early_stopping", &early_stopping);
+ v->Visit("num_measures_per_round", &num_measures_per_round);
+ v->Visit("verbose", &verbose);
+ v->Visit("builder", &builder);
+ v->Visit("runner", &runner);
+ v->Visit("measure_callbacks", &measure_callbacks);
+ v->Visit("pre_search_callbacks", &pre_search_callbacks);
+ }
+
+ static constexpr const char* _type_key = "auto_schedule.TuningOptions";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TuningOptionsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to TuningOptionsNode.
+ * \sa TuningOptionsNode
+ */
+class TuningOptions : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor
+ * \param num_measure_trials Number of total measurement trials.
+   * \param early_stopping Stop the tuning early if no improvement is found after n measurements.
+ * \param num_measures_per_round The number of programs to be measured at each search round.
+ * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule
+ * search.
+   * \param builder The ProgramBuilder that builds each program.
+   * \param runner The ProgramRunner that runs each program and measures its time cost.
+ * \param measure_callbacks MeasureCallback functions to be called after each measure batch.
+ * \param pre_search_callbacks SearchCallback functions to be called before schedule search.
+ */
+ TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose,
+ ProgramBuilder builder, ProgramRunner runner,
+ Optional<Array<MeasureCallback>> measure_callbacks,
+ Optional<Array<SearchCallback>> pre_search_callbacks);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode);
+};
+
+/*!
+ * \brief Auto schedule search for a given compute declaration.
+ * \param task The search task of the compute declaration.
+ * \param search_policy The search policy to be used for schedule search.
+ * \param tuning_options Tuning and measurement options.
+ * \return A `te::Schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
+ * `tvm.build`.
+ */
+TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
+ SearchPolicy search_policy,
+ TuningOptions tuning_options);
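+
+// A minimal usage sketch (illustrative only; it assumes a SearchTask `task` and a SearchPolicy
+// `policy`, e.g. EmptyPolicy, have already been constructed, and all option values below are
+// placeholder numbers, not recommendations):
+//
+//   TuningOptions options(/*num_measure_trials=*/64, /*early_stopping=*/-1,
+//                         /*num_measures_per_round=*/16, /*verbose=*/1,
+//                         LocalBuilder(15, 8, "default"), LocalRunner(10, 3, 1, 100, 0.0),
+//                         NullOpt, NullOpt);
+//   te::Schedule sch;
+//   Array<te::Tensor> args;
+//   std::tie(sch, args) = AutoSchedule(task, policy, options);
+//   // `sch` and `args` can then be passed to tvm.lower / tvm.build.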
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/compute_dag.cc
+ * \brief Compute declaration graph and its related analysis tools.
+ */
+
+#include "compute_dag.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "loop_state.h"
+#include "utils.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+using namespace tvm::tir;
+
+TVM_REGISTER_NODE_TYPE(ComputeDAGNode);
+
+// Topo-sort ops from tensors according to their read-write relations.
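+// This is essentially Kahn's algorithm: first traverse from the output tensors to build the
+// edge set and the in-degree of every op, then repeatedly emit ops whose in-degree has dropped
+// to zero. Ready ops are ordered by a priority derived from their discovery order, which keeps
+// the resulting op order deterministic.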
+Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
+ std::unordered_map<const te::OperationNode*, int> degree;
+ std::unordered_map<const te::OperationNode*, std::vector<const te::OperationNode*>> edge_set;
+ std::unordered_map<const te::OperationNode*, int> priority;
+ std::unordered_set<const te::OperationNode*> visited;
+
+ // traverse to build edge_set and count degree
+ std::vector<const te::OperationNode*> stack;
+ stack.reserve(tensors.size());
+ for (const auto& x : tensors) {
+ stack.push_back(x->op.operator->());
+ }
+
+ int ct = 0;
+ while (!stack.empty()) {
+ const te::OperationNode* op = stack.back();
+ stack.pop_back();
+ if (visited.count(op)) {
+ continue;
+ }
+
+ priority[op] = ct;
+ ct++;
+ visited.insert(op);
+
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ degree[op] = 0;
+ } else if (auto cop = GetRef<te::Operation>(op).as<te::ComputeOpNode>()) {
+ const Array<te::Tensor>& input_tensors = cop->InputTensors();
+ degree[op] = input_tensors.size();
+ for (const auto& ten : input_tensors) {
+ edge_set[ten->op.operator->()].push_back(op);
+ stack.push_back(ten->op.operator->());
+ }
+ } else {
+ LOG(FATAL) << "Unsupported op " << GetRef<te::Operation>(op);
+ }
+ }
+
+ // topo sort
+ Array<te::Operation> ops;
+
+ using Item = std::pair<const te::OperationNode*, int>;
+ auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; };
+ std::priority_queue<Item, std::vector<Item>, decltype(cmp)> queue(cmp);
+ for (const auto& iter : degree) {
+ if (iter.second == 0) {
+ queue.push(Item(iter.first, priority[iter.first]));
+ }
+ }
+
+ ops.reserve(degree.size());
+ while (!queue.empty()) {
+ Item item = queue.top();
+ queue.pop();
+ ops.push_back(GetRef<te::Operation>(item.first));
+ for (const auto& dst : edge_set[item.first]) {
+ degree[dst] -= 1;
+ if (degree[dst] == 0) {
+ queue.push(Item(dst, priority[dst]));
+ }
+ }
+ }
+
+ return ops;
+}
+
+// Estimate number of float operations in an expression
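+// Each arithmetic node (add, mul, ...) counts as one float operation. The per-element count of
+// a compute op's body is multiplied by the number of output elements, and a reduction multiplies
+// its body count by the product of its reduction extents. As an illustrative example, a matmul
+// C(i, j) = sum(A(i, k) * B(k, j)) with i in [0, N), j in [0, M), k in [0, K) performs one
+// multiply and one add per reduction step, giving 2 * N * M * K operations in total.
+// The estimation returns -1 if any loop extent is not a compile-time constant.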
+class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
+ public:
+ double EstimateFlop(const Array<te::Operation>& ops) {
+ double ret = 0;
+ for (const auto& op : ops) {
+ if (auto pop = op.as<te::ComputeOpNode>()) {
+ double num_element = AxisLengthProd(pop->axis);
+ if (num_element == -1) {
+ fail_ = true;
+ break;
+ }
+ double op_per_element = 0;
+ for (const auto& x : pop->body) {
+ op_per_element += VisitExpr(x);
+ }
+ ret += num_element * op_per_element;
+ } else if (op->IsInstance<te::PlaceholderOpNode>()) {
+ {} // do nothing
+ } else {
+ LOG(FATAL) << "Invalid op type " << op;
+ }
+ }
+
+ return fail_ ? -1 : ret;
+ }
+
+ double VisitExpr_(const ReduceNode* op) final {
+ uint64_t num_iter = 1;
+ for (const auto& x : op->axis) {
+ if (auto imm = x->dom->extent.as<IntImmNode>()) {
+ num_iter *= imm->value;
+ } else {
+ fail_ = true;
+ num_iter = -1;
+ }
+ }
+ double body_flop = 0;
+ for (size_t i = 0; i < op->combiner->result.size(); ++i) {
+ body_flop += VisitExpr(op->combiner->result[i]);
+ body_flop += VisitExpr(op->source[i]);
+ }
+ return num_iter * body_flop;
+ }
+
+ double VisitExpr_(const FloatImmNode* op) final { return 0.0; }
+ double VisitExpr_(const IntImmNode* op) final { return 0.0; }
+ double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; }
+
+ double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
+ double VisitExpr_(const VarNode* op) final { return 0.0; }
+
+ double VisitExpr_(const SelectNode* op) final {
+ return VisitExpr(op->condition) +
+ std::max(VisitExpr(op->true_value), VisitExpr(op->false_value));
+ }
+
+#define VisitBinary(Node) \
+ double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); }
+#define VisitUnary(Node) \
+ double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a); }
+
+ VisitBinary(AddNode);
+ VisitBinary(SubNode);
+ VisitBinary(MulNode);
+ VisitBinary(DivNode);
+ VisitBinary(ModNode);
+ VisitBinary(FloorDivNode);
+ VisitBinary(FloorModNode);
+ VisitBinary(MaxNode);
+ VisitBinary(MinNode);
+ VisitBinary(EQNode);
+ VisitBinary(NENode);
+ VisitBinary(LTNode);
+ VisitBinary(LENode);
+ VisitBinary(GTNode);
+ VisitBinary(GENode);
+ VisitBinary(AndNode);
+ VisitBinary(OrNode);
+ VisitUnary(NotNode);
+
+ double VisitExpr_(const CallNode* op) final {
+ double ret = 0.0;
+ for (const auto& x : op->args) {
+ ret += VisitExpr(x);
+ }
+ return ret;
+ }
+
+ double VisitExprDefault_(const Object* op) final {
+ fail_ = true;
+ return -1.0;
+ }
+
+ private:
+ bool fail_{false};
+};
+
+ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
+ auto node = make_object<ComputeDAGNode>();
+ node->tensors = std::move(tensors);
+ node->ops = TopoSortOps(node->tensors);
+ node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
+ node->init_state = State(node->ops);
+ data_ = std::move(node);
+}
+
+// Update the te::stage to tir::IterVar axis mapping
+void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) {
+ if (auto pop = stage->op.as<te::ComputeOpNode>()) {
+ Array<IterVar> axes;
+ for (const auto& axis : pop->axis) {
+ axes.push_back(axis);
+ }
+ for (const auto& axis : pop->reduce_axis) {
+ axes.push_back(axis);
+ }
+ stage_to_axes->Set(stage, std::move(axes));
+ } else if (stage->op->IsInstance<te::PlaceholderOpNode>()) {
+ {} // do nothing on Placeholder
+ } else {
+ LOG(FATAL) << "Invalid op " << stage->op;
+ }
+}
+
+std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
+ const Array<Step>& transform_steps, Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+  // Temporary objects to be used if the input pointers are nullptr
+ Array<te::Stage> temp_stages;
+ StageToAxesMap temp_stage_to_axes;
+ if (stages == nullptr) {
+ stages = &temp_stages;
+ }
+ if (stage_to_axes == nullptr) {
+ stage_to_axes = &temp_stage_to_axes;
+ }
+ Array<te::Operation> ops;
+ for (const auto& op : operator->()->ops) {
+ if (!op->IsInstance<te::PlaceholderOpNode>()) {
+ ops.push_back(op);
+ }
+ }
+ // Create the initial schedule
+  // TODO(jcf94): Currently we have only checked single-output DAGs for the TVM Auto-scheduler;
+  // update this after testing with multiple outputs.
+ te::Schedule schedule = te::create_schedule({ops.back()});
+
+ // init axes
+ for (const auto& x : operator->()->ops) {
+ const te::Stage& stage = schedule[x];
+ stages->push_back(stage);
+ UpdateStageToAxesMap(stage, stage_to_axes);
+ }
+
+ // Apply the history steps to TVM schedule
+ for (const auto& step : transform_steps) {
+ // Call each step's ApplyToSchedule method
+    // Note: Some steps require extra parameters and return different values, so the
+    // ApplyToSchedule calls cannot be merged into a single interface.
+ if (auto ps = step.as<ReorderStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<SplitStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<FuseStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else {
+ LOG(FATAL) << "Invalid Step";
+ }
+ }
+
+ return std::make_pair(schedule, operator->()->tensors);
+}
+
+String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const {
+ Array<te::Stage> stages;
+ StageToAxesMap stage_to_axes;
+ Array<te::Operation> ops;
+ for (const auto& op : operator->()->ops) {
+ if (!op->IsInstance<te::PlaceholderOpNode>()) {
+ ops.push_back(op);
+ }
+ }
+ // Create the initial schedule
+  // TODO(jcf94): Currently we have only checked single-output DAGs for the TVM Auto-scheduler;
+  // update this after testing with multiple outputs.
+ te::Schedule schedule = te::create_schedule({ops.back()});
+
+ // init axes
+ for (const auto& x : operator->()->ops) {
+ const te::Stage& stage = schedule[x];
+ stages.push_back(stage);
+ UpdateStageToAxesMap(stage, &stage_to_axes);
+ }
+
+ std::stringstream ss;
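+  // First emit one line per compute stage that rebinds its loop variables to Python names,
+  // e.g. (illustrative): i, j, k = tuple(C.op.axis) + tuple(C.op.reduce_axis)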
+ for (const auto& stage : stages) {
+ if (stage->op->IsInstance<te::ComputeOpNode>()) {
+ for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+ ss << stage->leaf_iter_vars[i]->var->name_hint;
+ if (i != stage->leaf_iter_vars.size() - 1) {
+ ss << ", ";
+ }
+ }
+ ss << " = "
+ << "tuple(" << stage->op->name << ".op.axis)"
+ << " + "
+ << "tuple(" << stage->op->name << ".op.reduce_axis)\n";
+ }
+ }
+ // Call each step's PrintAsPythonAPI method
+ for (const auto& step : transform_steps) {
+ if (auto ps = step.as<ReorderStepNode>()) {
+ ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
+ } else if (auto ps = step.as<SplitStepNode>()) {
+ ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
+ } else if (auto ps = step.as<FuseStepNode>()) {
+ ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
+ } else {
+ LOG(FATAL) << "Invalid Step";
+ }
+ }
+
+ return ss.str();
+}
+
+State ComputeDAG::InferBound(const State& state) const {
+ CHECK(state->concrete) << "Only concrete state can be processed to get bound info.";
+
+ State ret_state;
+ StateNode* pstate;
+
+ if (state->stages.empty()) {
+    // If the input state is incomplete (i.e., its operation stages are empty),
+    // create a new state from init_state and re-apply its transform steps first
+ ret_state = operator->()->init_state;
+ pstate = ret_state.CopyOnWrite();
+ pstate->transform_steps = state->transform_steps;
+ ret_state.DoSteps(*this);
+ } else {
+ ret_state = state;
+ pstate = ret_state.CopyOnWrite();
+ }
+
+ Array<te::Stage> stages;
+ StageToAxesMap stage_to_axes;
+ te::Schedule sch;
+ Array<te::Tensor> tensors;
+ // Replay steps to tvm::Schedule
+ std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes);
+ sch = sch.normalize();
+ // Get bound information from TVM schedule
+ Map<IterVar, Range> bounds = te::InferBound(sch);
+
+ // Update the state bound information
+ for (size_t i = 0; i < pstate->stages.size(); ++i) {
+ const Stage& stage = pstate->stages[i];
+
+ if (stage->compute_at == ComputeAtKind::kInlined) {
+ continue;
+ }
+
+ Array<Iterator> new_iters;
+ new_iters.reserve(stage->iters.size());
+ // Get bound information from schedule
+ // the StageToAxesMap is used to find the corresponding IterVar in TVM schedule result
+ for (size_t j = 0; j < stage->iters.size(); ++j) {
+ const Iterator& iter = stage->iters[j];
+ const IterVar& axis = stage_to_axes.at(stages[i])[j];
+
+ auto find_res = bounds.find(axis);
+ if (find_res != bounds.end()) {
+ new_iters.push_back(
+ Iterator(iter->name, (*find_res).second, iter->iter_kind, iter->annotation));
+ } else {
+ LOG(FATAL) << "Infer bound fails";
+ }
+ }
+
+ pstate->stages.Set(
+ i, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
+ }
+
+ return ret_state;
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const ComputeDAGNode*>(ref.get());
+ std::stringstream ss;
+
+ for (const auto& op : node->ops) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n";
+ } else if (auto pop = op.as<te::ComputeOpNode>()) {
+ for (size_t k = 0; k < pop->body.size(); ++k) {
+ ss << op->name << "(";
+ for (size_t i = 0; i < pop->axis.size(); i++) {
+ ss << pop->axis[i]->var->name_hint;
+ if (i != pop->axis.size() - 1) {
+ ss << ", ";
+ }
+ }
+ ss << ")";
+ if (pop->body.size() > 1) {
+ ss << ".v" << k;
+ }
+ if (auto preduce = pop->body[k].as<ReduceNode>()) {
+ CHECK_LT(k, preduce->combiner->result.size());
+ PrimExpr combiner = preduce->combiner->result[k];
+ if (combiner->IsInstance<AddNode>()) {
+ ss << " += " << preduce->source[0] << "\n";
+ } else if (combiner->IsInstance<MaxNode>()) {
+ ss << " max= " << preduce->source[0] << "\n";
+ } else if (combiner->IsInstance<MinNode>()) {
+ ss << " min= " << preduce->source[0] << "\n";
+ } else if (combiner->IsInstance<SelectNode>()) {
+ const auto& select = combiner.as<SelectNode>();
+ ss << " select(" << select->condition << ", " << select->true_value << ", "
+ << select->false_value << ")= " << '(' << preduce->source[0] << ','
+ << preduce->source[1] << ")\n";
+ } else {
+ LOG(FATAL) << "Unsupported reduction operator" << combiner;
+ }
+ } else {
+ ss << " = " << pop->body[k] << "\n";
+ }
+ }
+ } else {
+ LOG(FATAL) << "Invalid op";
+ }
+ }
+
+ p->stream << ss.str();
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAG").set_body_typed([](Array<te::Tensor> tensors) {
+ return ComputeDAG(tensors);
+});
+
+TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGApplyStepsFromState")
+ .set_body_typed([](const ComputeDAG& dag, const State& state) {
+ te::Schedule sch;
+ Array<te::Tensor> return_tensors;
+ std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps);
+ return Array<ObjectRef>{sch, return_tensors};
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGPrintPythonCodeFromState")
+ .set_body_typed([](const ComputeDAG& dag, const State& state) {
+ return dag.PrintStepsAsPython(state->transform_steps);
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGInferBoundFromState")
+ .set_body_typed([](const ComputeDAG& dag, const State& state) {
+ return dag.InferBound(state);
+ });
+
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/compute_dag.h
+ * \brief The TVM Auto-scheduler computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute-inlined ...). These analyses help the search policy
+ * make decisions during the search process.
+ * ComputeDAG is also responsible for the interaction between the TVM Auto-scheduler `LoopState`
+ * and the TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, and
+ * providing `LoopState` with extra information obtained from the TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_
+
+#include <tvm/te/schedule.h>
+
+#include <utility>
+
+#include "loop_state.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */
+class ComputeDAGNode : public Object {
+ public:
+ /*!
+ * \brief Input and output tensors.
+ * This is used as the input of `tvm.lower` or `tvm.build`.
+ */
+ Array<te::Tensor> tensors;
+ /*! \brief All related operations in topo order. */
+ Array<te::Operation> ops;
+ /*! \brief Number of total float operations for this ComputeDAG. */
+ double flop_ct;
+ /*! \brief The initial state without any transform steps. */
+ State init_state;
+ // TODO(merrymercy): Add more analyses later.
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("tensors", &tensors);
+ v->Visit("ops", &ops);
+ v->Visit("flop_ct", &flop_ct);
+ v->Visit("init_state", &init_state);
+ }
+
+ static constexpr const char* _type_key = "auto_schedule.ComputeDAG";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeDAGNode.
+ * \sa ComputeDAGNode
+ */
+class ComputeDAG : public ObjectRef {
+ public:
+ /*! \brief The constructor.
+ * \param tensors `te::Tensor`s for a compute declaration.
+ */
+ explicit ComputeDAG(Array<te::Tensor> tensors);
+
+ /*!
+ * \brief Apply the history transform steps from a State to get a TVM schedule.
+ * \param transform_steps Transform steps of a state.
+   * \param stages A pointer to a `te::Stage` Array, which defaults to nullptr.
+   * Pass a valid pointer if this information needs to be used outside this function.
+   * \param stage_to_axes A pointer to a StageToAxesMap, which defaults to nullptr.
+   * Pass a valid pointer if this information needs to be used outside this function.
+   * \return A `te::Schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or `tvm.build`.
+ */
+ std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
+ const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
+ StageToAxesMap* stage_to_axes = nullptr) const;
+
+ /*!
+ * \brief Print transform steps as equivalent python schedule API.
+ * This can be used for debugging.
+ * \param transform_steps Transform steps of a state.
+ * \return The Python schedule code.
+ */
+ String PrintStepsAsPython(const Array<Step>& transform_steps) const;
+
+ /*!
+   * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
+   * A state can lose complete bound information after some transform steps (e.g., compute_at);
+   * call this function to infer and fill in all the bound information.
+   * This function calls the TVM InferBound pass internally to get the bounds.
+   * The returned state is guaranteed to have complete iterator extent information.
+   * \param state The state whose bound information is to be filled.
+   * \return The state after bound inference.
+ */
+ State InferBound(const State& state) const;
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
+};
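+
+// A minimal usage sketch (illustrative only; assumes `tensors` holds the input/output tensors
+// of a te.compute declaration):
+//
+//   ComputeDAG dag(tensors);
+//   State state = dag->init_state;
+//   // ... manipulate `state` through its schedule primitive APIs (reorder/split/fuse) ...
+//   te::Schedule sch;
+//   Array<te::Tensor> args;
+//   std::tie(sch, args) = dag.ApplySteps(state->transform_steps);
+//   // `sch` and `args` can then be passed to tvm.lower / tvm.build.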
+
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/loop_state.cc
+ * \brief A lightweight IR (intermediate representation) for loop structures.
+ * See auto_schedule/loop_state.h for more explanation.
+ */
+
+#include "loop_state.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+
+#include <utility>
+
+#include "transform_step.h"
+#include "utils.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+TVM_REGISTER_OBJECT_TYPE(StepNode);
+TVM_REGISTER_NODE_TYPE(StageNode);
+TVM_REGISTER_NODE_TYPE(StateNode);
+TVM_REGISTER_NODE_TYPE(IteratorNode);
+
+/********** Iterator **********/
+Iterator::Iterator(String name, Range range, IteratorKind iter_kind,
+ IteratorAnnotation annotation) {
+ auto node = make_object<IteratorNode>();
+ node->name = std::move(name);
+ node->range = std::move(range);
+ node->iter_kind = iter_kind;
+ node->annotation = annotation;
+ data_ = std::move(node);
+}
+
+/********** Stage **********/
+Stage::Stage(te::Operation op) {
+ auto node = make_object<StageNode>();
+ if (op->IsInstance<te::ComputeOpNode>()) {
+ node->op_type = StageKind::kCompute;
+ auto* pop = op.as<te::ComputeOpNode>();
+ for (const auto& axis : pop->axis) {
+ node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
+ IteratorKind::kSpatial, IteratorAnnotation::kNone));
+ }
+ for (const auto& axis : pop->reduce_axis) {
+ node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
+ IteratorKind::kReduction, IteratorAnnotation::kNone));
+ }
+ } else if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->op_type = StageKind::kPlaceholder;
+ } else {
+ LOG(FATAL) << "Unsupported operator type" << op->_type_key;
+ }
+
+ node->compute_at = ComputeAtKind::kRoot;
+ node->op = std::move(op);
+ node->attrs.auto_unroll_max_step = 0;
+ node->attrs.storage_offset = 0;
+ data_ = std::move(node);
+}
+
+Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
+ ComputeAtKind compute_at, StageAttributes attrs) {
+ auto node = make_object<StageNode>();
+ node->op = std::move(op);
+ node->op_type = op_type;
+ node->iters = iters;
+ node->compute_at = compute_at;
+ node->attrs = attrs;
+ data_ = std::move(node);
+}
+
+/********** State **********/
+State::State(const Array<te::Operation>& ops) {
+ auto node = make_object<StateNode>();
+ for (const auto& op : ops) {
+ node->stages.push_back(Stage(op));
+ }
+ node->concrete = true;
+ data_ = std::move(node);
+}
+
+/********** Schedule primitives apis for state **********/
+void State::reorder(int stage_id, const Array<Iterator>& order) {
+ const Stage& stage = operator->()->stages[stage_id];
+ CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
+ << "should be specified";
+ Array<Integer> after_ids;
+ GetIndices(stage->iters, order, &after_ids);
+ ReorderStep step = ReorderStep(stage_id, after_ids);
+ CopyOnWrite()->transform_steps.push_back(step);
+ DoReorderStep(step);
+}
+
+Array<Iterator> State::split(int stage_id, const Iterator& it,
+ const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+ const Stage& stage = operator->()->stages[stage_id];
+ SplitStep step =
+ SplitStep(stage_id, GetIndex(stage->iters, it),
+ it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return DoSplitStep(step);
+}
+
+Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
+ const Stage& stage = operator->()->stages[stage_id];
+ Array<Integer> indices;
+ GetIndices(stage->iters, iters, &indices);
+ FuseStep step = FuseStep(stage_id, indices);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return DoFuseStep(step);
+}
+
+/********** Step implementations for state **********/
+void State::DoReorderStep(const ReorderStep& step) {
+ const Stage& stage = operator->()->stages[step->stage_id];
+ Array<Iterator> iters;
+ for (auto x : step->after_ids) {
+ iters.push_back(stage->iters[x]);
+ }
+ StateNode* pstate = CopyOnWrite();
+ pstate->stages.Set(step->stage_id,
+ Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
+}
+
+// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep
+Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id,
+ const Array<Optional<Integer>>& lengths,
+ bool inner_to_outer) {
+ const Stage& stage = operator->()->stages[stage_id];
+ const Iterator& it = stage->iters[iter_id];
+ bool concrete = true;
+
+ Optional<PrimExpr> tosplit_min, tosplit_extent;
+ if (it->range.defined()) {
+ tosplit_min = it->range->min;
+ tosplit_extent = it->range->extent;
+ } else {
+ tosplit_min = NullOpt;
+ tosplit_extent = NullOpt;
+ }
+
+ Array<Iterator> outs;
+ for (size_t i = 0; i < lengths.size(); ++i) {
+ Optional<Integer> l;
+ String name;
+ if (inner_to_outer) {
+ l = lengths[lengths.size() - i - 1];
+ name = it->name + "." + std::to_string(lengths.size() - i);
+ } else {
+ l = lengths[i];
+ name = it->name + "." + std::to_string(i);
+ }
+ Iterator res;
+ if (l && tosplit_min && tosplit_extent) {
+ res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
+ IteratorAnnotation::kNone);
+ tosplit_min = Integer(0);
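+      // The remaining outer extent is ceil(extent / l), computed as (extent + l - 1) / l.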
+ tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
+ } else {
+ res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
+ tosplit_min = NullOpt;
+ tosplit_extent = NullOpt;
+ concrete = false;
+ }
+ outs.push_back(std::move(res));
+ }
+
+ Range range;
+ if (tosplit_min && tosplit_extent) {
+ range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
+ }
+ if (inner_to_outer) {
+ outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone));
+ // Reverse the Iterator array
+ Array<Iterator> temp(outs.rbegin(), outs.rend());
+ outs = std::move(temp);
+ } else {
+ outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind,
+ IteratorAnnotation::kNone));
+ }
+
+ Array<Iterator> new_iters;
+ new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
+ new_iters.insert(new_iters.end(), outs.begin(), outs.end());
+ new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());
+
+ StateNode* pstate = CopyOnWrite();
+ pstate->stages.Set(stage_id,
+ Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
+ pstate->concrete &= concrete;
+
+ return outs;
+}
+
+Array<Iterator> State::DoSplitStep(const SplitStep& step) {
+ return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer);
+}
+
+Iterator State::DoFuseStep(const FuseStep& step) {
+ int stage_id = step->stage_id;
+ const Stage& stage = operator->()->stages[stage_id];
+
+ String new_name;
+ PrimExpr new_extent = 1;
+ IteratorKind new_iter_kind = IteratorKind::kSpecial;
+
+ for (size_t i = 0; i < step->fused_ids.size(); ++i) {
+ if (i > 0) {
+ CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1);
+ }
+
+ const Iterator& it = stage->iters[step->fused_ids[i]];
+ new_name = new_name + it->name + "@";
+
+ if (it->range.defined() && new_extent.defined()) {
+ new_extent = new_extent * it->range->extent;
+ } else {
+ new_extent = PrimExpr();
+ }
+
+ if (i == 0) {
+ new_iter_kind = it->iter_kind;
+ } else {
+ if (new_iter_kind != it->iter_kind) {
+ new_iter_kind = IteratorKind::kMixed;
+ }
+ }
+ }
+
+ Range range;
+ if (new_extent.defined()) {
+ range = Range::FromMinExtent(0, new_extent);
+ }
+ Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone);
+ Array<Iterator> new_iters;
+ new_iters.insert(new_iters.end(), stage->iters.begin(),
+ stage->iters.begin() + step->fused_ids.front());
+ new_iters.push_back(new_it);
+ new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1,
+ stage->iters.end());
+
+ StateNode* pstate = CopyOnWrite();
+ pstate->stages.Set(stage_id,
+ Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
+
+ return new_it;
+}
+
+void State::DoSteps(const ComputeDAG& dag) {
+ CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
+
+ for (const auto& step : operator->()->transform_steps) {
+ if (auto ps = step.as<ReorderStepNode>()) {
+ DoReorderStep(GetRef<ReorderStep>(ps));
+ } else if (auto ps = step.as<SplitStepNode>()) {
+ DoSplitStep(GetRef<SplitStep>(ps));
+ } else if (auto ps = step.as<FuseStepNode>()) {
+ DoFuseStep(GetRef<FuseStep>(ps));
+ } else {
+ LOG(FATAL) << "Invalid step: " << step;
+ }
+ }
+}
+
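+// String names for IteratorAnnotation values; the order must stay in sync with the
+// IteratorAnnotation enum defined in loop_state.h.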
+static const char* IteratorAnnotationString[] = {
+ "for", // kNone = 0
+ "unroll", // kUnroll = 1
+ "vectorize", // kVectorize = 2
+ "parallel", // kParallel = 3
+ "vthread", // kVThread = 4
+ "gpu.blockIdx.x", // kBlockX = 5
+ "gpu.threadIdx.x", // kThreadX = 6
+ "gpu.blockIdx.y", // kBlockY = 7
+ "gpu.threadIdx.y", // kThreadY = 8
+ "tensorize" // kTensorized = 9
+};
+
+// Print stage to ostream
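+// An illustrative example of the printed loop structure for one stage:
+//   parallel i.0 (0,32)
+//     for k (0,512)
+//       for j (0,64)
+//         C = ...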
+void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
+ bool delete_trivial_loop) {
+ const Stage& stage = state->stages[stage_id];
+
+ if (stage->attrs.auto_unroll_max_step != 0) {
+ for (size_t j = 0; j < base_indent; ++j) {
+ *os << " ";
+ }
+ *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n";
+ }
+ if (stage->attrs.storage_offset != 0) {
+ for (size_t j = 0; j < base_indent; ++j) {
+ *os << " ";
+ }
+ *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n";
+ }
+
+ size_t indent = 0;
+ for (size_t i = 0; i < stage->iters.size(); ++i) {
+ const Iterator& iter = stage->iters[i];
+
+ if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) {
+ for (size_t j = 0; j < base_indent + indent; ++j) {
+ *os << " ";
+ }
+ *os << IteratorAnnotationString[static_cast<int>(iter->annotation)] << " ";
+ if (iter->range.defined()) {
+ *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")";
+ } else {
+ *os << iter->name << " (None)";
+ }
+ *os << "\n";
+
+ indent += 2;
+ }
+ }
+
+ for (size_t j = 0; j < base_indent + indent; ++j) {
+ *os << " ";
+ }
+ *os << stage->op->name << " = ...\n";
+}
+
+// Print state to ostream
+void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) {
+ // Gather placeholders
+ Array<String> placeholders;
+ for (const auto& stage : state->stages) {
+ if (stage->op_type == StageKind::kPlaceholder) {
+ placeholders.push_back(stage->op->name);
+ }
+ }
+
+ *os << "Placeholder: ";
+ for (size_t i = 0; i < placeholders.size(); ++i) {
+ *os << placeholders[i];
+ if (i != placeholders.size() - 1) {
+ *os << ", ";
+ }
+ }
+ *os << "\n";
+
+ // Print all stages
+ for (size_t i = 0; i < state->stages.size(); ++i) {
+ const Stage& stage = state->stages[i];
+ if (stage->op_type == StageKind::kPlaceholder) {
+ continue;
+ } else if (stage->op_type == StageKind::kCompute) {
+ if (stage->compute_at == ComputeAtKind::kRoot) {
+ PrintStage(os, i, state, 0, delete_trivial_loop);
+ }
+ } else {
+ LOG(FATAL) << "Invalid op type";
+ }
+ }
+}
+
+String State::ToStr(bool delete_trivial_loop) const {
+ std::ostringstream os;
+ PrintState(&os, (*this), delete_trivial_loop);
+ return os.str();
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<StateNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ PrintState(&p->stream, tvm::Downcast<State>(ref), true);
+ });
+
+/********** State interface API for ffi **********/
+TVM_REGISTER_GLOBAL("auto_schedule.StateReorder")
+ .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
+ state.reorder(stage_id, order);
+ return state;
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.StateSplit")
+ .set_body_typed([](State state, int stage_id, const Iterator& it,
+ const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+ const auto& res = state.split(stage_id, it, lengths, inner_to_outer);
+ return Array<ObjectRef>{state, res};
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.StateFuse")
+ .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
+ const auto& res = state.fuse(stage_id, iters);
+ return Array<ObjectRef>{state, res};
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.StateEqual").set_body_typed([](State state1, State state2) {
+ return std::equal_to<State>()(state1, state2);
+});
+
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/loop_state.h
+ * \brief The definition of the "state" in search.
+ *
+ * Each LoopState corresponds to a schedule for its ComputeDAG.
+ * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
+ * construct the loop structure.
+ * The loop structure keeps a preview of how the schedule will finally look after lowering the
+ * current state (e.g. the number of iterators, the extent of each iterator, the compute_at
+ * locations ...).
+ * During the schedule search process, the loop structure provides the search policy with the
+ * necessary information on how to manipulate the current state.
+ * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM
+ * schedule primitives. The steps can also be used for the serialization of a state.
+ *
+ * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
+ * Instead of using the existing TVM IR, we extend a new lightweight structure on top of it because:
+ * 1. We want fast incremental changes to the loop structures. The search policy needs to see
+ * loop structure updates immediately, rather than only after TVM lowering;
+ * 2. We want serializable transform history for replay, backtracking, and mutation;
+ * 3. We may create some macro schedule primitives that represent the combination of several
+ * TVM schedule primitives.
+ *
+ * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives.
+ * Since we share a lot of common objects during search, the transformation is implemented in
+ * copy-on-write style. All objects are immutable, which is similar to TVM IR.
+ */
+
+#ifndef TVM_AUTO_SCHEDULE_LOOP_STATE_H_
+#define TVM_AUTO_SCHEDULE_LOOP_STATE_H_
+
+#include <tvm/runtime/container.h>
+
+#include <functional>
+
+#include "transform_step.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+using namespace tvm::tir;
+
+class ComputeDAG;
+
+/*! \brief The type of a stage. */
+enum class StageKind : int {
+ /*! \brief A placeholder stage. */
+ kPlaceholder = 0,
+ /*! \brief A compute stage. */
+ kCompute = 1
+};
+
+/*! \brief The type of compute location. */
+enum class ComputeAtKind : int {
+ /*! \brief Compute at root. */
+ kRoot = 0,
+ /*! \brief Compute inlined. */
+ kInlined = 1,
+ /*! \brief Compute at some iterator. */
+ kIter = 2,
+};
+
+/*! \brief The type of an iterator. */
+enum class IteratorKind : int {
+ /*! \brief Spatial iterator. */
+ kSpatial = 0,
+ /*! \brief Reduction iterator. */
+ kReduction = 1,
+ /*! \brief Fused spatial and reduction iterator. */
+ kMixed = 2,
+ /*! \brief Special iterator. (e.g. virtual root iterator) */
+ kSpecial = 3
+};
+
+/*! \brief The type of an iterator's annotation. */
+enum class IteratorAnnotation : int {
+ /*! \brief This iterator has no annotation. */
+ kNone = 0,
+ /*! \brief This iterator has been unrolled. */
+ kUnroll = 1,
+ /*! \brief This iterator has been vectorized. */
+ kVectorize = 2,
+  /*! \brief This iterator has been parallelized. */
+  kParallel = 3,
+  /*! \brief This iterator has been bound to vthread. */
+  kVThread = 4,
+  /*! \brief This iterator has been bound to blockIdx.x. */
+  kBlockX = 5,
+  /*! \brief This iterator has been bound to threadIdx.x. */
+  kThreadX = 6,
+  /*! \brief This iterator has been bound to blockIdx.y. */
+  kBlockY = 7,
+  /*! \brief This iterator has been bound to threadIdx.y. */
+  kThreadY = 8,
+ /*! \brief This iterator has been mapped with a tensorize intrinsic. */
+ kTensorized = 9
+};
+
+/*!
+ * \brief A for loop iterator
+ * Similar to tvm::IterVar in `include/tvm/tir/expr.h`
+ */
+class IteratorNode : public Object {
+ public:
+ /*! \brief The name of this iterator. */
+ String name;
+ /*! \brief The range of this iterator. */
+ Range range;
+ /*! \brief The iterator type of this iterator. */
+ IteratorKind iter_kind;
+ /*! \brief The annotation type of this iterator. */
+ IteratorAnnotation annotation;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("range", &range);
+ }
+
+ static constexpr const char* _type_key = "auto_schedule.Iterator";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IteratorNode.
+ * \sa IteratorNode
+ */
+class Iterator : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param name The name of this iterator.
+ * \param range The range of this iterator.
+ * \param iter_kind The iterator type of this iterator.
+ * \param annotation The annotation type of this iterator.
+ */
+ Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
+};
+
+/*! \brief Stage-level attributes. */
+struct StageAttributes {
+ /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */
+ int auto_unroll_max_step;
+ /*! \brief The storage offset for the schedule primitive `storage_align`. */
+ int storage_offset;
+};
+
+/*!
+ * \brief An op stage in the compute declaration.
+ * Similar to te::Stage in `include/tvm/te/schedule.h`.
+ */
+class StageNode : public Object {
+ public:
+ /*! \brief The operator of this stage */
+ te::Operation op;
+ /*! \brief The type of this stage. */
+ StageKind op_type;
+ /*! \brief The iterators in this stage. */
+ Array<Iterator> iters;
+ /*! \brief The compute location of this stage. */
+ ComputeAtKind compute_at;
+ /*! \brief Other stage-level attributes. */
+ StageAttributes attrs;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("op", &op);
+ v->Visit("iters", &iters);
+ }
+
+ static constexpr const char* _type_key = "auto_schedule.Stage";
+ TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
+};
+
+/*!
+ * \brief Managed reference to StageNode.
+ * \sa StageNode
+ */
+class Stage : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param op A `te::Operation`.
+ */
+ explicit Stage(te::Operation op);
+ /*!
+ * \brief The constructor.
+ * \param op A `te::Operation`.
+ * \param op_type The stage type of this op.
+ * \param iters The iterators of this op.
+ * \param compute_at The compute at type of this op.
+ * \param attrs Other stage-level attributes.
+ */
+ Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, ComputeAtKind compute_at,
+ StageAttributes attrs);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode);
+};
+
+/*!
+ * \brief A state in the search process.
+ * It consists of the current loop structure and a list of transformation steps used to construct
+ * it.
+ * Each State corresponds to a specific schedule for its ComputeDAG.
+ */
+class StateNode : public Object {
+ public:
+ /*! \brief Current stages and loop structures. */
+ Array<Stage> stages;
+ /*! \brief History transformation steps. */
+ Array<Step> transform_steps;
+ /*!
+   * \brief Indicate whether this state has unfilled tile sizes. A state is concrete when all
+   * of its tile sizes are filled. Only a concrete state can be applied to a TVM schedule.
+ */
+ bool concrete;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("stages", &stages);
+ v->Visit("transform_steps", &transform_steps);
+ v->Visit("concrete", &concrete);
+ }
+
+ static constexpr const char* _type_key = "auto_schedule.State";
+ TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
+
+ private:
+ /*!
+   * \brief The up-to-date ComputeDAG of this state, used by steps that may change the stage
+   * structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep, which will be added later).
+   * The default value is an empty ObjectRef, which means no modification to the original DAG.
+ */
+ ObjectRef current_compute_dag;
+};
+
+/*!
+ * \brief Managed reference to StateNode.
+ * \sa StateNode
+ */
+class State : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param ops `te::Operation`s for a compute declaration.
+ */
+ explicit State(const Array<te::Operation>& ops);
+
+ /*!
+   * \brief Print the state to a human-readable string.
+   * \param delete_trivial_loop True to skip trivial loops
+   * (those with an undefined range or extent == 1); defaults to true.
+ * \return The human readable state structure.
+ */
+ String ToStr(bool delete_trivial_loop = true) const;
+
+ /*!
+   * \brief Re-apply all the recorded transform steps, starting from the initial state, with a
+   * runtime dynamic dispatcher.
+   * \param dag The original ComputeDAG of this state.
+   * \note This is different from the class member `current_compute_dag`, because some transform
+   * steps may change the op stage structure of the ComputeDAG.
+ */
+ void DoSteps(const ComputeDAG& dag);
+
+ /* Step APIs for State. */
+
+ /*!
+ * \brief Schedule primitive corresponds to te.reorder.
+ * \param stage_id The index of the stage to be reordered.
+ * \param order The expected iterator order.
+ */
+ void reorder(int stage_id, const Array<Iterator>& order);
+ /*!
+ * \brief Schedule primitive corresponds to te.split.
+ * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param lengths The multiple split factors. Can be None, to be filled in by the search policy.
+   * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner.
+ * \return The iterator results after split.
+ */
+ Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
+ bool inner_to_outer = true);
+ /*!
+ * \brief Schedule primitive corresponds to te.fuse.
+ * \param stage_id The index of the stage to be fused.
+ * \param iters The iterators to be fused.
+ * \return The iterator result after fuse.
+ */
+ Iterator fuse(int stage_id, const Array<Iterator>& iters);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
+
+ private:
+  /* Do transform steps
+   * Note: The following functions only change the loop state but do not touch transform_steps.
+   * They are separated out so that they can be called directly when replaying history steps. */
+
+ /*!
+ * \brief Apply reorder step to current state.
+ * \param step A ReorderStep.
+ */
+ void DoReorderStep(const ReorderStep& step);
+ /*!
+ * \brief Apply split step to current state.
+ * \param step A SplitStep.
+ * \return The iterator results after split.
+ */
+ Array<Iterator> DoSplitStep(const SplitStep& step);
+ /*!
+ * \brief Apply fuse step to current state.
+ * \param step A FuseStep.
+ * \return The iterator result after fuse.
+ */
+ Iterator DoFuseStep(const FuseStep& step);
+
+ /*!
+   * \brief Common function for DoSplitStep and DoFollowSplitStep (will be added later).
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param lengths The multiple split factors.
+ * \param inner_to_outer The split direction.
+ * \return The iterator results after split.
+ */
+ Array<Iterator> DoSplitStepCommon(int stage_id, int iter_id,
+ const Array<Optional<Integer>>& lengths, bool inner_to_outer);
+};
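+
+// A minimal usage sketch of the schedule primitive APIs (illustrative only; the stage index
+// and split factor below are placeholders that depend on the actual ComputeDAG):
+//
+//   State state = dag->init_state;
+//   Array<Iterator> outs = state.split(2, state->stages[2]->iters[0], {Integer(16)});
+//   Iterator fused = state.fuse(2, {outs[0], outs[1]});
+//   // Each call records a transform step and immediately updates the loop structure.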
+
+} // namespace auto_schedule
+} // namespace tvm
+
+// Hash and equal function for State
+namespace std {
+
+/*! \brief The hash function for auto_schedule::State. */
+template <>
+struct hash<::tvm::auto_schedule::State> {
+ std::size_t operator()(const ::tvm::auto_schedule::State& state) const {
+ return tvm::runtime::ObjectHash()(state.ToStr());
+ }
+};
+
+/*!
+ * \brief The equal_to function for auto_schedule::State.
+ * We use the schedule result (in its string format) of a state to check whether two states are
+ * `equal`. Two states are equal when: 1. their transform steps are exactly the same; or 2. even
+ * with different steps, they still result in the same schedule. E.g., to split an axis with
+ * extent 512 into 3 parts [8, 16, 4], we can split from inner to outer by factors [16, 4], or
+ * get the same result by splitting from outer to inner by factors [8, 16].
+ */
+template <>
+struct equal_to<::tvm::auto_schedule::State> {
+ bool operator()(const ::tvm::auto_schedule::State& lhs,
+ const ::tvm::auto_schedule::State& rhs) const {
+ return lhs.ToStr() == rhs.ToStr();
+ }
+};
+
+} // namespace std
+
+#endif // TVM_AUTO_SCHEDULE_LOOP_STATE_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs.
+ */
+
+#include "measure.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+
+#include "utils.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+TVM_REGISTER_NODE_TYPE(MeasureInputNode);
+TVM_REGISTER_NODE_TYPE(BuildResultNode);
+TVM_REGISTER_NODE_TYPE(MeasureResultNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+
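+// String names for MeasureErrorNO values; the order must stay in sync with the MeasureErrorNO
+// enum defined in measure.h.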
+static const char* ErrorNoToStr[] = {
+ "NoError",
+ "InstantiationError",
+ "CompileHostError",
+ "CompileDeviceError",
+ "RuntimeDeviceError",
+ "WrongAnswerError",
+ "BuildTimeoutError",
+ "RunTimeoutError",
+ "UnknownError",
+};
+
+/********** Measure input and result **********/
+MeasureInput::MeasureInput(SearchTask task, State state) {
+ auto node = make_object<MeasureInputNode>();
+ node->task = std::move(task);
+ node->state = std::move(state);
+ data_ = std::move(node);
+}
+
+MeasureInput MeasureInputNode::copy() const {
+ auto node = make_object<MeasureInputNode>();
+ node->task = task;
+ node->state = state;
+ return MeasureInput(node);
+}
+
+BuildResult::BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg,
+ double time_cost) {
+ auto node = make_object<BuildResultNode>();
+ node->filename = std::move(filename);
+ node->args = std::move(args);
+ node->error_no = error_no;
+ node->error_msg = std::move(error_msg);
+ node->time_cost = time_cost;
+ data_ = std::move(node);
+}
+
+MeasureResult::MeasureResult(Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
+ double timestamp) {
+ auto node = make_object<MeasureResultNode>();
+ node->costs = std::move(costs);
+ node->error_no = error_no;
+ node->error_msg = std::move(error_msg);
+ node->all_cost = all_cost;
+ node->timestamp = timestamp;
+ data_ = std::move(node);
+}
+
+MeasureResult MeasureResultNode::copy() const {
+ auto node = make_object<MeasureResultNode>();
+ node->costs = costs;
+ node->error_no = error_no;
+ node->error_msg = error_msg;
+ node->all_cost = all_cost;
+ node->timestamp = timestamp;
+ return MeasureResult(node);
+}
+
+/********** LocalBuilder **********/
+LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& build_func) {
+ auto node = make_object<LocalBuilderNode>();
+ node->timeout = timeout;
+ node->n_parallel = n_parallel;
+ node->build_func = build_func;
+ data_ = std::move(node);
+}
+
+Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs, int verbose) {
+ if (const auto* f = runtime::Registry::Get("auto_schedule.local_builder.build")) {
+ Array<BuildResult> results = (*f)(inputs, timeout, n_parallel, build_func, verbose);
+ return results;
+ }
+ LOG(FATAL) << "auto_schedule.local_builder.build is not registered. "
+ << "This is a function registered in Python, "
+ << "make sure the TVM Python runtime has been loaded successfully.";
+ throw;
+}
+
+/********** LocalRunner **********/
+LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms,
+ double cooldown_interval) {
+ ObjectPtr<LocalRunnerNode> node = make_object<LocalRunnerNode>();
+ node->timeout = timeout;
+ node->number = number;
+ node->repeat = repeat;
+ node->min_repeat_ms = min_repeat_ms;
+ node->cooldown_interval = cooldown_interval;
+ data_ = std::move(node);
+}
+
+Array<MeasureResult> LocalRunnerNode::Run(const Array<MeasureInput>& inputs,
+ const Array<BuildResult>& build_results, int verbose) {
+ if (const auto* f = runtime::Registry::Get("auto_schedule.local_runner.run")) {
+ Array<MeasureResult> results = (*f)(inputs, build_results, timeout, number, repeat,
+ min_repeat_ms, cooldown_interval, verbose);
+ return results;
+ }
+ LOG(FATAL) << "auto_schedule.local_runner.run is not registered. "
+ << "This is a function registered in Python, "
+ << "make sure the TVM Python runtime has been loaded successfully.";
+ throw;
+}
+
+/********** ProgramMeasurer **********/
+ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
+ Optional<Array<MeasureCallback>> callbacks, int verbose,
+ int max_continous_error) {
+ auto node = make_object<ProgramMeasurerNode>();
+ node->builder = std::move(builder);
+ node->runner = std::move(runner);
+ node->callbacks = std::move(callbacks);
+ node->verbose = verbose;
+ node->max_continous_error = max_continous_error < 0
+ ? ProgramMeasurerNode::DEFAULT_MAX_CONTINOUS_ERROR
+ : max_continous_error;
+ data_ = std::move(node);
+}
+
+void ProgramMeasurerNode::Reset() {
+ ct = error_ct = 0;
+ best_flops.clear();
+ best_ct.clear();
+ best_state.clear();
+}
+
+void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& policy,
+ const Array<MeasureInput>& inputs, Array<MeasureResult>* results,
+ int batch_size) {
+ results->clear();
+ results->reserve(inputs.size());
+
+ if (batch_size == -1) {
+ // set default batch size
+ batch_size = builder->n_parallel * 2;
+ }
+
+ StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This may take a while)"
+ << std::endl;
+
+ for (size_t i = 0; i < inputs.size(); i += batch_size) {
+ Array<MeasureInput> input_batch(inputs.begin() + i,
+ inputs.begin() + std::min(i + batch_size, inputs.size()));
+ Array<MeasureResult> result_batch;
+
+ // build and run
+ SilentMeasure(task, input_batch, &result_batch);
+
+ // update current best state according to the new measure result
+ for (size_t j = 0; j < input_batch.size(); ++j) {
+ double flops;
+ if (result_batch[j]->error_no == 0) {
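+        // Achieved FLOPS = the DAG's static float-op count / mean measured run time in seconds.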
+ flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs);
+ error_ct = 0;
+ } else {
+ flops = 0.0;
+ error_ct++;
+ }
+
+ const String& workload_key = input_batch[j]->task->workload_key;
+ if (flops > best_flops[workload_key]) {
+ best_flops[workload_key] = flops;
+ best_state[workload_key] = input_batch[j]->state;
+ best_ct[workload_key] = ct;
+ }
+
+ ct++;
+ StdCout(verbose) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n"
+ << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / "
+ << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n"
+ << Chars('=', 50) << "\n"
+ << input_batch[j]->state << "\n";
+ }
+
+ // Call callback functions
+ if (callbacks) {
+ for (const auto& callback : callbacks.value()) {
+ callback->Callback(policy, input_batch, result_batch);
+ }
+ }
+
+ // Store result batch
+ for (auto& res : result_batch) {
+ results->push_back(res);
+ }
+
+ if (error_ct > max_continous_error) {
+ LOG(FATAL) << "Too many errors happened during tuning";
+ }
+ }
+}
+
+void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
+ Array<MeasureResult>* results) {
+ results->clear();
+ results->reserve(inputs.size());
+
+ // Call builder and runner
+ Array<BuildResult> build_res_batch = builder->Build(inputs, verbose);
+ Array<MeasureResult> result_batch = runner->Run(inputs, build_res_batch, verbose);
+
+ // Store result batch
+ for (auto& res : result_batch) {
+ results->push_back(res);
+ }
+}
+
+/********** Printing functions **********/
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<MeasureInputNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ p->stream << "MeasureInput()";
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<MeasureResultNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const MeasureResultNode*>(ref.get());
+ if (node->error_no == static_cast<int>(MeasureErrorNO::kNoError)) {
+ p->stream << "MeasureResult(cost:[";
+ auto old_config = p->stream.precision(4);
+ for (size_t i = 0; i < node->costs.size(); ++i) {
+ auto pf = node->costs[i].as<FloatImmNode>();
+ CHECK(pf != nullptr);
+ p->stream << pf->value;
+ if (i != node->costs.size() - 1) {
+ p->stream << ",";
+ }
+ }
+ p->stream.precision(old_config);
+ p->stream << "], ";
+ p->stream << "error_no:" << 0 << ", "
+ << "all_cost:" << node->all_cost << ", "
+ << "Tstamp:" << node->timestamp << ")";
+ } else {
+ p->stream << "MeasureResult("
+ << "error_type:" << ErrorNoToStr[node->error_no] << ", "
+ << "error_msg:" << node->error_msg << ", "
+ << "all_cost:" << node->all_cost << ", "
+ << "Tstamp:" << node->timestamp << ")";
+ }
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<BuildResultNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const BuildResultNode*>(ref.get());
+ p->stream << "BuildResult(" << node->filename << ", " << node->error_no << ", "
+ << node->time_cost << ")";
+ });
+
+/********** Measure interface API for ffi **********/
+TVM_REGISTER_GLOBAL("auto_schedule.MeasureInput").set_body_typed([](SearchTask task, State state) {
+ return MeasureInput(task, state);
+});
+
+TVM_REGISTER_GLOBAL("auto_schedule.BuildResult")
+ .set_body_typed([](String filename, Array<te::Tensor> args, int error_no, String error_msg,
+ double time_cost) {
+ return BuildResult(filename, args, error_no, error_msg, time_cost);
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.MeasureResult")
+ .set_body_typed([](Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
+ double timestamp) {
+ return MeasureResult(costs, error_no, error_msg, all_cost, timestamp);
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.ProgramBuilderBuild")
+ .set_body_typed([](const ProgramBuilder& builder, const Array<MeasureInput>& inputs,
+ int verbose) { return builder->Build(inputs, verbose); });
+
+TVM_REGISTER_GLOBAL("auto_schedule.ProgramRunnerRun")
+ .set_body_typed([](const ProgramRunner& runner, const Array<MeasureInput>& inputs,
+ const Array<BuildResult>& build_results,
+ int verbose) { return runner->Run(inputs, build_results, verbose); });
+
+TVM_REGISTER_GLOBAL("auto_schedule.LocalBuilder")
+ .set_body_typed([](int timeout, int n_parallel, const String& build_func) {
+ return LocalBuilder(timeout, n_parallel, build_func);
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.LocalRunner")
+ .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms,
+ double cooldown_interval) {
+ return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval);
+ });
+
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/measure.h
+ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs.
+ * These functions are responsible for building the TVM module, uploading it to remote devices,
+ * recording the running time costs, and checking the correctness of the output.
+ *
+ * We separate the measurement into two steps: build and run.
+ * A builder builds the executable binary files and a runner runs the binary files to get the
+ * measurement results. The flow of data structures is
+ *
+ * `ProgramBuilder` `ProgramRunner`
+ * `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
+ *
+ * We implement these in Python to utilize Python's multiprocessing and error handling.
+ */
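+
+/*
+ * A minimal usage sketch of the classes declared below (illustrative only; assumes a
+ * SearchTask `task` and a State `state` have already been created, and the build_func
+ * name "default" is an assumption):
+ *
+ *   ProgramBuilder builder = LocalBuilder(15, 4, "default");
+ *   ProgramRunner runner = LocalRunner(10, 3, 1, 100, 0.0);
+ *   Array<MeasureInput> inputs{MeasureInput(task, state)};
+ *   Array<BuildResult> build_results = builder->Build(inputs, 1);
+ *   Array<MeasureResult> results = runner->Run(inputs, build_results, 1);
+ */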
+
+#ifndef TVM_AUTO_SCHEDULE_MEASURE_H_
+#define TVM_AUTO_SCHEDULE_MEASURE_H_
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "loop_state.h"
+#include "search_task.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+class SearchPolicy;
+class MeasureInput;
+class MeasureResult;
+
+/*! \brief The error code of one measurement */
+enum class MeasureErrorNO : int {
+ /*! \brief No error. */
+ kNoError = 0,
+  /*! \brief Errors happen when applying transform steps from the initial state. */
+  kInstantiationError = 1,
+  /*! \brief Errors happen when compiling code on the host. (when building the module) */
+  kCompileHostError = 2,
+  /*! \brief Errors happen when compiling code on the device. (when loading the module) */
+  kCompileDeviceError = 3,
+  /*! \brief Errors happen when running the program on the device. */
+  kRuntimeDeviceError = 4,
+  /*! \brief The answer is wrong when compared to a reference output. */
+ kWrongAnswerError = 5,
+ /*! \brief Timeout during compilation. */
+ kBuildTimeoutError = 6,
+ /*! \brief Timeout during run. */
+ kRunTimeoutError = 7,
+ /*! \brief Unknown error. */
+  kUnknownError = 8,
+};
+
+// Inputs and results of one measurement
+
+/*! \brief Store the input of a measurement */
+class MeasureInputNode : public Object {
+ public:
+ /*! \brief The search task. */
+ SearchTask task;
+ /*! \brief The program state to be measured. */
+ State state;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("task", &task);
+ v->Visit("state", &state);
+ }
+
+ /*! \brief Do shallow copy. */
+ MeasureInput copy() const;
+
+ static constexpr const char* _type_key = "auto_schedule.MeasureInput";
+ TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MeasureInputNode.
+ * \sa MeasureInputNode
+ */
+class MeasureInput : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+   * \param task The SearchTask of this measurement.
+ * \param state The State to be measured.
+ */
+ MeasureInput(SearchTask task, State state);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode);
+};
+
+/*! \brief Store the result of a build. */
+class BuildResultNode : public Object {
+ public:
+  /*! \brief The filename of the built binary file. */
+ String filename;
+ /*! \brief The arguments. */
+ Array<te::Tensor> args;
+ /*! \brief The error code. (0 means no error, see MeasureErrorNO) */
+ int error_no;
+ /*! \brief The error message if there is any error. */
+ String error_msg;
+  /*! \brief The time cost of the build. */
+ double time_cost;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("filename", &filename);
+ v->Visit("args", &args);
+ v->Visit("error_no", &error_no);
+ v->Visit("error_msg", &error_msg);
+ v->Visit("time_cost", &time_cost);
+ }
+
+ static constexpr const char* _type_key = "auto_schedule.BuildResult";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object);
+};
+
+/*!
+ * \brief Managed reference to BuildResultNode.
+ * \sa BuildResultNode
+ */
+class BuildResult : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+   * \param filename The filename of the built binary file.
+ * \param args The arguments.
+ * \param error_no The error code.
+ * \param error_msg The error message if there is any error.
+   * \param time_cost The time cost of the build.
+ */
+ BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg,
+ double time_cost);
+ TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode);
+};
+
+/*! \brief Store the results of a measurement. */
+class MeasureResultNode : public Object {
+ public:
+ /*! \brief The time costs of execution. */
+ Array<PrimExpr> costs;
+ /*! \brief The error code. (0 means no error, see MeasureErrorNO) */
+ int error_no;
+ /*! \brief The error message if there is any error. */
+ String error_msg;
+ /*! \brief The time cost of build and run. */
+ double all_cost;
+ /*! \brief The time stamps of this measurement. */
+ double timestamp;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("costs", &costs);
+ v->Visit("error_no", &error_no);
+ v->Visit("error_msg", &error_msg);
+ v->Visit("all_cost", &all_cost);
+ v->Visit("timestamp", ×tamp);
+ }
+
+ /*! \brief Do shallow copy. */
+ MeasureResult copy() const;
+
+ static constexpr const char* _type_key = "auto_schedule.MeasureResult";
+ TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MeasureResultNode.
+ * \sa MeasureResultNode
+ */
+class MeasureResult : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param costs The time costs of execution.
+ * \param error_no The error code.
+ * \param error_msg The error message if there is any error.
+ * \param all_cost The time cost of build and run.
+ * \param timestamp The time stamps of this measurement.
+ */
+ MeasureResult(Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
+ double timestamp);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode);
+};
+
+/*! \brief Base class of measurement callbacks */
+class MeasureCallbackNode : public Object {
+ public:
+ /*!
+ * \brief Callback function that will be called on measurement input/result pairs
+ * after measurement.
+ * \param policy The current search policy.
+ * \param inputs An Array of MeasureInput.
+ * \param results An Array of MeasureResult.
+ */
+ virtual void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
+ const Array<MeasureResult>& results) = 0;
+ static constexpr const char* _type_key = "auto_schedule.MeasureCallback";
+ TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MeasureCallbackNode.
+ * \sa MeasureCallbackNode
+ */
+class MeasureCallback : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
+};
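+
+// A hedged sketch of a user-defined callback (names are illustrative, not part of this
+// patch): subclass MeasureCallbackNode and override Callback() to observe each batch.
+//
+//   class PrintErrorCallbackNode : public MeasureCallbackNode {
+//    public:
+//     void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
+//                   const Array<MeasureResult>& results) final {
+//       for (size_t i = 0; i < results.size(); ++i) {
+//         LOG(INFO) << "error_no = " << results[i]->error_no;
+//       }
+//     }
+//     static constexpr const char* _type_key = "auto_schedule.PrintErrorCallback";
+//     TVM_DECLARE_FINAL_OBJECT_INFO(PrintErrorCallbackNode, MeasureCallbackNode);
+//   };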
+
+// The base class of ProgramBuilders and ProgramRunners.
+
+/*! \brief ProgramBuilder that builds the programs */
+class ProgramBuilderNode : public Object {
+ public:
+ /*! \brief The number of tasks to run in parallel */
+ int n_parallel;
+ /*! \brief Timeout of a build */
+ int timeout;
+
+ /*!
+ * \brief Build programs and return results.
+ * \param inputs An Array of MeasureInput.
+ * \param verbose Verbosity level. 0 for silent, 1 to output information during program
+ * building.
+   * \return An Array of BuildResult.
+ */
+ virtual Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) = 0;
+
+ static constexpr const char* _type_key = "auto_schedule.ProgramBuilder";
+ TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ProgramBuilderNode.
+ * \sa ProgramBuilderNode
+ */
+class ProgramBuilder : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramBuilder, ObjectRef, ProgramBuilderNode);
+};
+
+/*! \brief ProgramRunner that runs the built programs and measures the time cost. */
+class ProgramRunnerNode : public Object {
+ public:
+ /*! \brief Timeout of a run. */
+ int timeout;
+
+ /*!
+ * \brief Run measurement and return results.
+ * \param inputs An Array of MeasureInput.
+ * \param build_results An Array of BuildResult.
+ * \param verbose Verbosity level. 0 for silent, 1 to output information during program
+ * running.
+ * \return An Array of MeasureResult.
+ */
+ virtual Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+ const Array<BuildResult>& build_results, int verbose) = 0;
+
+ static constexpr const char* _type_key = "auto_schedule.ProgramRunner";
+ TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ProgramRunnerNode.
+ * \sa ProgramRunnerNode
+ */
+class ProgramRunner : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramRunner, ObjectRef, ProgramRunnerNode);
+};
+
+// Implementation of various builders and runners
+
+/*! \brief LocalBuilder uses local CPU cores to build programs in parallel. */
+class LocalBuilderNode : public ProgramBuilderNode {
+ public:
+ /*! \brief Build function. */
+ String build_func;
+
+ Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) final;
+
+ static constexpr const char* _type_key = "auto_schedule.LocalBuilder";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode);
+};
+
+/*!
+ * \brief Managed reference to LocalBuilderNode.
+ * \sa LocalBuilderNode
+ */
+class LocalBuilder : public ProgramBuilder {
+ public:
+ /*!
+ * \brief The constructor.
+   * \param timeout The timeout limit (in seconds) for each build thread.
+ * This will be used in a wrapper of the multiprocessing.Process.join().
+ * \param n_parallel Number of threads used to build in parallel.
+ * \param build_func The name of registered build function.
+ */
+ LocalBuilder(int timeout, int n_parallel, const String& build_func);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode);
+};
+
+/*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs. */
+class LocalRunnerNode : public ProgramRunnerNode {
+ public:
+ /*! \brief Number of measure times. */
+ int number;
+ /*! \brief Number of repeat times in each measure. */
+ int repeat;
+ /*! \brief The minimum duration of one repeat in milliseconds. */
+ int min_repeat_ms;
+ /*! \brief The cool down interval between two measurements. */
+ double cooldown_interval;
+
+ Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+ const Array<BuildResult>& build_results, int verbose) final;
+
+ static constexpr const char* _type_key = "auto_schedule.LocalRunner";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode);
+};
+
+/*!
+ * \brief Managed reference to LocalRunnerNode.
+ * \sa LocalRunnerNode
+ */
+class LocalRunner : public ProgramRunner {
+ public:
+ /*!
+ * \brief The constructor. See the corresponding class in python/tvm/auto_schedule/measure.py
+   * for more detailed parameter explanation.
+   * \param timeout The timeout limit (in seconds) for each run.
+ * This is used in a wrapper of the multiprocessing.Process.join().
+ * \param number Number of measure times.
+ * \param repeat Number of repeat times in each measure.
+ * \param min_repeat_ms The minimum duration of one repeat in milliseconds.
+ * \param cooldown_interval The cool down interval between two measurements.
+ */
+ LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, ProgramRunner, LocalRunnerNode);
+};
+
+/*!
+ * \brief Measurer that measures the time costs of TVM programs.
+ * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API. */
+class ProgramMeasurerNode : public Object {
+ public:
+ /*! \brief Measured programs counter. */
+ int ct;
+ /*! \brief Continuous error counter. */
+ int error_ct;
+ /*! \brief Workload key to best flops map. */
+ std::unordered_map<std::string, double> best_flops;
+ /*! \brief Workload key to best state map. */
+ std::unordered_map<std::string, State> best_state;
+ /*! \brief Workload key to best state's count index map. */
+ std::unordered_map<std::string, int> best_ct;
+ /*! \brief The ProgramBuilder to build each program. */
+ ProgramBuilder builder;
+ /*! \brief The ProgramRunner to measure each program. */
+ ProgramRunner runner;
+ /*! \brief MeasureCallback to be called after each measure batch. */
+ Optional<Array<MeasureCallback>> callbacks;
+ /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */
+ int verbose;
+  /*! \brief The maximum number of continuous errors allowed. */
+ int max_continous_error;
+
+  /*! \brief Reset bookkeeping variables. */
+ void Reset();
+
+ /*!
+ * \brief Do measurement.
+ * \param task The current SearchTask.
+ * \param policy The current SearchPolicy.
+ * \param inputs The MeasureInputs.
+ * \param results A pointer to a MeasureResult Array, this is used as output.
+ * \param batch_size Number of programs to be measured in one batch.
+ */
+ void Measure(const SearchTask& task, const SearchPolicy& policy,
+ const Array<MeasureInput>& inputs, Array<MeasureResult>* results,
+ int batch_size = -1);
+ /*!
+ * \brief Do measurement silently.
+ * This API will not print the measure results to screen.
+ * \param task The current SearchTask.
+ * \param inputs The MeasureInputs.
+ * \param results A pointer to a MeasureResult Array, this is used as output.
+ */
+ void SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
+ Array<MeasureResult>* results);
+
+ /*! \brief The default max continuous error setting. */
+ static const int DEFAULT_MAX_CONTINOUS_ERROR = 150;
+
+ static constexpr const char* _type_key = "auto_schedule.ProgramMeasurer";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ProgramMeasurerNode.
+ * \sa ProgramMeasurerNode
+ */
+class ProgramMeasurer : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param builder The ProgramBuilder to build each program.
+ * \param runner The ProgramRunner to measure each program.
+ * \param callbacks MeasureCallback to be called after each measure batch.
+ * \param verbose Verbosity level. 0 for silent, 1 to output information during program
+ * measuring.
+   * \param max_continous_error The maximum number of continuous errors allowed.
+ */
+ ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
+ Optional<Array<MeasureCallback>> callbacks, int verbose,
+ int max_continous_error = -1);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode);
+};
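+
+// A hedged end-to-end sketch (illustrative only; assumes `task`, `policy`, and `inputs`
+// already exist, and that measure_record.h is included for RecordToFile):
+//
+//   ProgramMeasurer measurer(LocalBuilder(15, 4, "default"),
+//                            LocalRunner(10, 3, 1, 100, 0.0),
+//                            Array<MeasureCallback>{RecordToFile("records.json")}, 1);
+//   Array<MeasureResult> results;
+//   measurer->Measure(task, policy, inputs, &results);
+//   double best_gflops = measurer->best_flops[task->workload_key] / 1e9;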
+
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_MEASURE_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/measure_record.cc
+ * \brief JSON serialization format for dumping and loading tuning records.
+ */
+
+#include "measure_record.h"
+
+#include <dmlc/json.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "loop_state.h"
+#include "transform_step.h"
+#include "utils.h"
+
+// JSON serialization handler for MeasureInput, MeasureResult
+// (and recursively for SearchTask, State, Step, ...)
+namespace dmlc {
+namespace json {
+
+inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) {
+ std::vector<int> out;
+ for (const auto& x : data) {
+ CHECK(x.defined());
+ out.push_back(x);
+ }
+ return out;
+}
+
+inline std::vector<int> IntArrayToVector(
+ const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) {
+ std::vector<int> out;
+ for (const auto& x : data) {
+ CHECK(x);
+ out.push_back(x.value());
+ }
+ return out;
+}
+
+template <>
+struct Handler<::tvm::Array<::tvm::auto_schedule::Stage>> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const ::tvm::Array<::tvm::auto_schedule::Stage>& data) {
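+    // Note (editorial assumption): stages are intentionally not serialized here, since a
+    // State can be reconstructed by replaying its transform_steps; only an empty array
+    // is written as a placeholder.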
+ writer->BeginArray(false);
+ writer->EndArray();
+ }
+ inline static void Read(dmlc::JSONReader* reader,
+ ::tvm::Array<::tvm::auto_schedule::Stage>* data) {
+ bool s;
+ reader->BeginArray();
+ s = reader->NextArrayItem();
+ CHECK(!s);
+ }
+};
+
+template <>
+struct Handler<::tvm::Array<::tvm::auto_schedule::Step>> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const ::tvm::Array<::tvm::auto_schedule::Step>& data) {
+ writer->BeginArray(false);
+ for (size_t i = 0; i < data.size(); ++i) {
+ writer->WriteArraySeperator();
+ writer->BeginArray(false);
+ if (auto ps = data[i].as<::tvm::auto_schedule::ReorderStepNode>()) {
+ writer->WriteArrayItem(std::string("RE"));
+ writer->WriteArrayItem(ps->stage_id);
+ writer->WriteArrayItem(IntArrayToVector(ps->after_ids));
+ } else if (auto ps = data[i].as<::tvm::auto_schedule::SplitStepNode>()) {
+ writer->WriteArrayItem(std::string("SP"));
+ writer->WriteArrayItem(ps->stage_id);
+ writer->WriteArrayItem(ps->iter_id);
+ writer->WriteArrayItem(ps->extent ? ::tvm::auto_schedule::GetIntImm(ps->extent.value())
+ : 0);
+ writer->WriteArrayItem(IntArrayToVector(ps->lengths));
+ writer->WriteArrayItem(static_cast<int>(ps->inner_to_outer));
+ } else if (auto ps = data[i].as<::tvm::auto_schedule::FuseStepNode>()) {
+ writer->WriteArrayItem(std::string("FU"));
+ writer->WriteArrayItem(ps->stage_id);
+ writer->WriteArrayItem(IntArrayToVector(ps->fused_ids));
+ } else {
+ LOG(FATAL) << "Invalid step: " << data[i];
+ }
+ writer->EndArray();
+ }
+ writer->EndArray();
+ }
+
+ inline static void Read(dmlc::JSONReader* reader,
+ ::tvm::Array<::tvm::auto_schedule::Step>* data) {
+ std::vector<int> int_list;
+ bool s, inner_to_outer;
+ std::string name, scope_name, pragma_type, ti_func_name;
+ int stage_id, iter_id, extent;
+
+ reader->BeginArray();
+ data->clear();
+ while (reader->NextArrayItem()) {
+ reader->BeginArray();
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&name);
+ if (name == "RE") {
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&int_list);
+ ::tvm::Array<::tvm::Integer> after_ids;
+ for (const auto& i : int_list) {
+ after_ids.push_back(i);
+ }
+ data->push_back(::tvm::auto_schedule::ReorderStep(stage_id, after_ids));
+ } else if (name == "SP") {
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&iter_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&extent);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&int_list);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&inner_to_outer);
+ ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
+ for (const auto& i : int_list) {
+ lengths.push_back(::tvm::Integer(i));
+ }
+ data->push_back(::tvm::auto_schedule::SplitStep(
+ stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, lengths, inner_to_outer));
+ } else if (name == "FU") {
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&int_list);
+ ::tvm::Array<::tvm::Integer> fused_ids;
+ for (const auto& i : int_list) {
+ fused_ids.push_back(i);
+ }
+ data->push_back(::tvm::auto_schedule::FuseStep(stage_id, fused_ids));
+ } else {
+ LOG(FATAL) << "Invalid step format";
+ }
+ s = reader->NextArrayItem();
+ CHECK(!s);
+ }
+ }
+};
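+
+// For reference, each transform step serializes as a small JSON array. For example, a
+// SplitStep with stage_id = 2, iter_id = 0, extent = 512, lengths = [8, 16], and
+// inner_to_outer = true is written as: ["SP", 2, 0, 512, [8, 16], 1]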
+
+template <>
+struct Handler<::tvm::auto_schedule::StateNode> {
+ inline static void Write(dmlc::JSONWriter* writer, const ::tvm::auto_schedule::StateNode& data) {
+ writer->BeginArray(false);
+ writer->WriteArrayItem(data.stages);
+ writer->WriteArrayItem(data.transform_steps);
+ writer->EndArray();
+ }
+ inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::StateNode* data) {
+ reader->BeginArray();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&data->stages);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&data->transform_steps);
+ s = reader->NextArrayItem();
+ CHECK(!s);
+ }
+};
+
+template <>
+struct Handler<::tvm::auto_schedule::SearchTaskNode> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const ::tvm::auto_schedule::SearchTaskNode& data) {
+ writer->BeginArray(false);
+ writer->WriteArrayItem(std::string(data.workload_key));
+ writer->WriteArrayItem(data.target->str());
+ writer->EndArray();
+ }
+  inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::SearchTaskNode* data) {
+    // The same string buffer is reused to read the workload key and then the target.
+    std::string str_value;
+    bool s;
+
+    reader->BeginArray();
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&str_value);
+    data->workload_key = std::move(str_value);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&str_value);
+    data->target = ::tvm::Target::Create(str_value);
+ s = reader->NextArrayItem();
+ CHECK(!s);
+ }
+};
+
+template <>
+struct Handler<::tvm::auto_schedule::MeasureInputNode> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const ::tvm::auto_schedule::MeasureInputNode& data) {
+ writer->BeginArray(false);
+ writer->WriteArrayItem(*data.task.operator->());
+ writer->WriteArrayItem(*data.state.operator->());
+ writer->EndArray();
+ }
+ inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::MeasureInputNode* data) {
+ bool s;
+ auto task_node = ::tvm::make_object<::tvm::auto_schedule::SearchTaskNode>();
+ auto state_node = ::tvm::make_object<::tvm::auto_schedule::StateNode>();
+ state_node->concrete = true;
+
+ reader->BeginArray();
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(task_node.get());
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(state_node.get());
+ s = reader->NextArrayItem();
+ CHECK(!s);
+
+ data->task = ::tvm::auto_schedule::SearchTask(task_node);
+ data->state = ::tvm::auto_schedule::State(state_node);
+ }
+};
+
+template <>
+struct Handler<::tvm::auto_schedule::MeasureResultNode> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const ::tvm::auto_schedule::MeasureResultNode& data) {
+ writer->BeginArray(false);
+ writer->WriteArraySeperator();
+ writer->BeginArray(false);
+ for (const auto& x : data.costs) {
+ auto pf = x.as<::tvm::tir::FloatImmNode>();
+ CHECK(pf != nullptr) << "Cost can only contain float values";
+ writer->WriteArrayItem(pf->value);
+ }
+ writer->EndArray();
+ writer->WriteArrayItem(data.error_no);
+ writer->WriteArrayItem(data.all_cost);
+    writer->WriteArrayItem(static_cast<int>(data.timestamp));
+ writer->EndArray();
+ }
+ inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::MeasureResultNode* data) {
+ bool s;
+ std::vector<double> tmp;
+
+ reader->BeginArray();
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&tmp);
+ data->costs.clear();
+ for (const auto& i : tmp) {
+ data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i));
+ }
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&data->error_no);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&data->all_cost);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&data->timestamp);
+ s = reader->NextArrayItem();
+ CHECK(!s);
+ }
+};
+
+} // namespace json
+} // namespace dmlc
+
+namespace tvm {
+namespace auto_schedule {
+
+TVM_REGISTER_OBJECT_TYPE(RecordToFileNode);
+TVM_REGISTER_OBJECT_TYPE(RecordReaderNode);
+
+const std::string AUTO_SCHEDULE_LOG_VERSION = "v0.2"; // NOLINT(*)
+
+RecordToFile::RecordToFile(String filename) {
+ auto node = make_object<RecordToFileNode>();
+ node->filename = std::move(filename);
+ data_ = std::move(node);
+}
+
+void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
+ const Array<MeasureResult>& results) {
+ dmlc::JSONWriter writer(os);
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ writer.BeginObject(false);
+ writer.WriteObjectKeyValue("i", *inputs[i].operator->());
+ writer.WriteObjectKeyValue("r", *results[i].operator->());
+ writer.WriteObjectKeyValue("v", AUTO_SCHEDULE_LOG_VERSION);
+ writer.EndObject();
+ *os << "\n";
+ }
+}
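+
+// Each record is one JSON line. Schematically (values are illustrative; the "i" field is
+// the MeasureInput [SearchTask, State], "r" is the MeasureResult [costs, error_no,
+// all_cost, timestamp], and "v" is the log version):
+//
+//   {"i": [["<workload_key>", "llvm"], [[], [["SP", 2, 0, 512, [8, 16], 1]]]],
+//    "r": [[0.0015, 0.0014], 0, 2.5, 1590000000], "v": "v0.2"}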
+
+void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res,
+ std::string* log_version) {
+ std::istringstream ss(str);
+ dmlc::JSONReader reader(&ss);
+ std::string key;
+
+ reader.BeginObject();
+ while (reader.NextObjectItem(&key)) {
+ if (key == "i") {
+ reader.Read(inp);
+ } else if (key == "r") {
+ reader.Read(res);
+ } else if (key == "v") {
+ reader.Read(log_version);
+ } else {
+ LOG(FATAL) << "Invalid key in json log: " << key;
+ }
+ }
+}
+
+void RecordToFileNode::Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
+ const Array<MeasureResult>& results) {
+ std::ofstream ofs(filename, std::ofstream::app);
+ WriteMeasureRecords(&ofs, inputs, results);
+}
+
+RecordReader::RecordReader(String filename) {
+ auto node = make_object<RecordReaderNode>();
+ node->filename = filename;
+ node->infile.open(filename, std::ifstream::in);
+ data_ = std::move(node);
+}
+
+RecordReaderNode::~RecordReaderNode() { infile.close(); }
+
+bool RecordReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) {
+ std::string log_version;
+
+ while (std::getline(infile, cur_line_)) {
+ if (cur_line_[0] == '#' || cur_line_[0] == ' ') {
+      // skip comment lines beginning with '#' or ' '
+ continue;
+ }
+ ReadMeasureRecord(cur_line_, inp, res, &log_version);
+ return true;
+ }
+
+ return false;
+}
+
+std::pair<Array<MeasureInput>, Array<MeasureResult>> RecordReaderNode::ReadLines(int max_size,
+ int skip_size) {
+ auto inp = make_object<MeasureInputNode>();
+ auto res = make_object<MeasureResultNode>();
+ Array<MeasureInput> inputs;
+ Array<MeasureResult> results;
+
+ while (ReadNext(inp.get(), res.get())) {
+ if (skip_size > 0) {
+ skip_size--;
+ continue;
+ }
+
+ inputs.push_back(inp->copy());
+ results.push_back(res->copy());
+
+ if (max_size > 0 && static_cast<int>(inputs.size()) >= max_size) {
+ break;
+ }
+ }
+
+ return std::make_pair(inputs, results);
+}
+
+TVM_REGISTER_GLOBAL("auto_schedule.RecordToFile").set_body_typed([](const String& filename) {
+ return RecordToFile(filename);
+});
+
+TVM_REGISTER_GLOBAL("auto_schedule.RecordReader").set_body_typed([](const String& filename) {
+ return RecordReader(filename);
+});
+
+TVM_REGISTER_GLOBAL("auto_schedule.RecordReaderReadLines")
+ .set_body_typed([](RecordReader reader, int size, int skip_size) {
+ const auto& res = reader->ReadLines(size, skip_size);
+ return Array<ObjectRef>{res.first, res.second};
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.RecordReaderReadNext").set_body_typed([](RecordReader reader) {
+ auto inp = make_object<MeasureInputNode>();
+ auto res = make_object<MeasureResultNode>();
+ if (reader->ReadNext(inp.get(), res.get())) {
+ return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)};
+ } else {
+ return Array<ObjectRef>();
+ }
+});
+
+TVM_REGISTER_GLOBAL("auto_schedule.SaveRecords")
+ .set_body_typed([](String filename, Array<MeasureInput> in, Array<MeasureResult> res) {
+ std::ofstream ofs(filename, std::ofstream::app);
+ WriteMeasureRecords(&ofs, in, res);
+ });
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/measure_record.h
+ * \brief JSON serialization format for dumping and loading tuning records.
+ */
+
+#ifndef TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_
+#define TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_
+
+#include <fstream>
+#include <string>
+#include <utility>
+
+#include "measure.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+/*! \brief Callback for logging the input and results of measurements to file */
+class RecordToFileNode : public MeasureCallbackNode {
+ public:
+ /*! \brief File name for this callback to write log to. */
+ String filename;
+
+ void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
+ const Array<MeasureResult>& results) final;
+
+ static constexpr const char* _type_key = "auto_schedule.RecordToFile";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RecordToFileNode, MeasureCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to RecordToFileNode.
+ * \sa RecordToFileNode
+ */
+class RecordToFile : public MeasureCallback {
+ public:
+ /*!
+ * \brief The constructor.
+   * \param filename File name for this callback to write the log to.
+ */
+ explicit RecordToFile(String filename);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordToFile, MeasureCallback, RecordToFileNode);
+};
+
+/*! \brief Log reader to load step logs from a file. */
+class RecordReaderNode : public Object {
+ public:
+ /*! \brief File name for this reader to load log from. */
+ String filename;
+ /*! \brief The reading file stream. */
+ std::ifstream infile;
+
+ ~RecordReaderNode();
+
+ /*!
+ * \brief Read next line in the log file.
+ * \param inp A pointer to a MeasureInputNode, this is used as output.
+ * \param res A pointer to a MeasureResultNode, this is used as output.
+ * \return Whether the read is successful. */
+ bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res);
+
+ /*!
+ * \brief Read multiple lines from the log file.
+ * \param max_size The maximum number of lines. -1 means read all lines.
+ * \param skip_size Skip the first n lines.
+ * \return The MeasureInputs and MeasureResults loaded from the log file.
+ */
+ std::pair<Array<MeasureInput>, Array<MeasureResult>> ReadLines(int max_size = -1,
+ int skip_size = 0);
+
+ static constexpr const char* _type_key = "auto_schedule.RecordReader";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object);
+
+ private:
+ /*! \brief A string object to store the next line. */
+ std::string cur_line_;
+};
+
+/*!
+ * \brief Managed reference to RecordReaderNode.
+ * \sa RecordReaderNode
+ */
+class RecordReader : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+   * \param filename File name for this reader to load the log from.
+ */
+ explicit RecordReader(String filename);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordReader, ObjectRef, RecordReaderNode);
+};
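+
+// A hedged usage sketch (the file name is illustrative): load up to 100 records from a
+// log file, skipping its first 10 lines.
+//
+//   RecordReader reader("records.json");
+//   auto loaded = reader->ReadLines(/*max_size=*/100, /*skip_size=*/10);
+//   const Array<MeasureInput>& inputs = loaded.first;
+//   const Array<MeasureResult>& results = loaded.second;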
+
+/*!
+ * \brief Write measure records to an output stream.
+ * \param os A pointer to an output stream.
+ * \param inputs The MeasureInputs to be written.
+ * \param results The MeasureResults to be written.
+ */
+void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
+ const Array<MeasureResult>& results);
+
+/*!
+ * \brief Read one measure record from a string.
+ * \param str The record string to be extracted.
+ * \param inp A pointer to a MeasureInputNode, this is used as output.
+ * \param res A pointer to a MeasureResultNode, this is used as output.
+ * \param log_version A pointer to a log version string.
+ */
+void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res,
+ std::string* log_version);
+
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/search_policy/empty_policy.cc
+ * \brief A brief example of a search policy.
+ */
+
+#include "empty_policy.h"
+
+#include <tvm/runtime/registry.h>
+
+#include "../measure.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+TVM_REGISTER_NODE_TYPE(EmptyPolicyNode);
+
+State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early_stopping,
+ int num_measures_per_round, int verbose, ProgramMeasurer measurer,
+ Optional<Array<SearchCallback>> pre_search_callbacks) {
+ cur_task = task;
+
+  // Run pre_search_callbacks before the search process.
+  // This interface is usually used to set up some initial status.
+ RunCallbacks(pre_search_callbacks);
+
+  // Basic design principle: call `SearchOneRound()` several times to get candidate states,
+  // measure them, and return the best one.
+  // Measurement is disabled if num_measure_trials <= 1.
+ if (num_measure_trials <= 1) {
+ const auto& res = SearchOneRound();
+ CHECK_GT(res.size(), 0);
+
+ return res[0];
+ } else {
+ Array<MeasureInput> inputs;
+ Array<MeasureResult> results;
+
+ measurer->Reset();
+ int ct = 0;
+ // In each round, we call SearchOneRound to get several candidate states,
+ // then use ProgramMeasurer to test their performance
+ while (ct < num_measure_trials) {
+ const auto& res = SearchOneRound();
+ ct += res.size();
+ // Build MeasureInputs for measuring
+ inputs.clear();
+ for (const auto& state : res) {
+      // The class member measured_states_set_ provided by SearchPolicy can be used to filter
+      // out already measured states
+ inputs.push_back(MeasureInput(cur_task, state));
+ }
+    // ProgramMeasurer will record the state with the best performance during the measure process
+ measurer->Measure(cur_task, GetRef<SearchPolicy>(this), inputs, &results);
+ }
+
+    // Return the state with the best measured performance
+ return measurer->best_state[cur_task->workload_key];
+ }
+}
+
+// As an example policy, EmptyPolicy always returns an initial state
+Array<State> EmptyPolicyNode::SearchOneRound() {
+ Array<State> res;
+
+  // 1. We first run `Program sampling` to generate several initial schedules
+ res.push_back(cur_task->compute_dag->init_state);
+
+  // 2. Then `Performance Tuning`: use a cost model and evolutionary search to find the schedule
+  // with the best performance
+ // Note: This example policy does not include this part
+
+  // 3. The returned candidate schedules will be measured on hardware
+ return res;
+}
+
+TVM_REGISTER_GLOBAL("auto_schedule.EmptyPolicy").set_body_typed([]() {
+ return EmptyPolicy(make_object<EmptyPolicyNode>());
+});
+
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/search_policy/empty_policy.h
+ * \brief A brief example of a search policy that always returns the initial naive schedule
+ * (state).
+ */
+
+#ifndef TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_
+#define TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_
+
+#include "../loop_state.h"
+#include "search_policy.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+/*!
+ * \brief A brief example of a search policy that always returns the initial naive schedule
+ * (state). The formal search policies will follow the same design.
+ * The key function of this class is `Search()`; check `empty_policy.cc` for more
+ * details.
+ */
+class EmptyPolicyNode : public SearchPolicyNode {
+ public:
+ State Search(SearchTask task, int num_measure_trials, int early_stopping,
+ int num_measures_per_round, int verbose, ProgramMeasurer measurer,
+ Optional<Array<SearchCallback>> pre_search_callbacks) final;
+
+ static constexpr const char* _type_key = "auto_schedule.EmptyPolicy";
+ TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode);
+
+ private:
+ /*!
+   * \brief Use a helper function to generate several candidate states in each search round.
+   * \returns The generated states
+ */
+ Array<State> SearchOneRound();
+};
+
+/*!
+ * \brief Managed reference to EmptyPolicyNode.
+ * \sa EmptyPolicyNode
+ */
+class EmptyPolicy : public SearchPolicy {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EmptyPolicy, SearchPolicy, EmptyPolicyNode);
+};
+
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/search_policy/search_policy.cc
+ * \brief The base class of search policies.
+ */
+
+#include "search_policy.h"
+
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace auto_schedule {
+
+TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
+
+void SearchPolicyNode::RunCallbacks(const Optional<Array<SearchCallback>>& callbacks) {
+ if (callbacks) {
+ for (const auto& callback : callbacks.value()) {
+ callback->Callback(this);
+ }
+ }
+}
+
+TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicyRunCallbacks")
+ .set_body_typed([](SearchPolicy policy, Optional<Array<SearchCallback>> callbacks) {
+ policy->RunCallbacks(callbacks);
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicySetTask")
+ .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->cur_task = task; });
+
+TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicySetVerbose")
+ .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; });
+
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/search_policy/search_policy.h
+ * \brief The base class of search policies, including the abstract definition of search policy and
+ * other supporting data structures.
+ *
+ * The basic schedule search process for TVM Auto-scheduler is designed to be:
+ * `Program sampling` -> `Performance Tuning`.
+ *
+ * In `Program sampling`, we use some predefined precise or heuristic rules to generate several
+ * initial schedules. Based on these initial starting points, we perform `Performance Tuning` which
+ * uses cost model based evolutionary search to select schedules with the best performance.
+ *
+ * Candidate schedules are measured against the specific hardware target.
+ *
+ * \note Adding a new search policy.
+ * By design, there is no need for users to implement their own search policy; the formal search
+ * policy (to be added later) should be enough to cover most use cases. Meanwhile, a custom rule
+ * mechanism will be provided to enable user-defined template search to serve the same
+ * functionality as the current AutoTVM template.
+ *
+ * This guide is for advanced users who have special requirements.
+ * 1. The only function that must be implemented is Search(), which takes a task as input and
+ * returns the best states found.
+ * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask.
+ * This structure also contains some information about the target device. (e.g. knowing the width
+ * of the device's vector unit, we can limit the maximum vectorization size during schedule search)
+ * 3. SearchCallback provides more flexibility to perform extra tasks before/after the search
+ * process.
+ * 4. ProgramMeasurer provides a simple but useful API to check the performance of states obtained
+ * during the search process.
+ */
+
+#ifndef TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_
+#define TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_
+
+#include <tvm/node/node.h>
+
+#include <unordered_set>
+#include <vector>
+
+#include "../search_task.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+class ProgramMeasurer;
+class SearchPolicyNode;
+
+/*!
+ * \brief Callback function to be called by the search process.
+ * This interface allows performing extra initialization before the schedule search, or extra
+ * checks during/after the schedule search.
+ */
+class SearchCallbackNode : public Object {
+ public:
+ /*!
+ * \brief Run the registered callback function.
+ * \param policy A pointer to a SearchPolicyNode.
+ */
+ virtual void Callback(SearchPolicyNode* policy) = 0;
+
+ static constexpr const char* _type_key = "auto_schedule.SearchCallback";
+ TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SearchCallbackNode.
+ * \sa SearchCallbackNode
+ */
+class SearchCallback : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
+};
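+
+// A hedged sketch of a custom callback (names are illustrative, not part of this patch):
+// subclass SearchCallbackNode and override Callback() to adjust the policy before search.
+//
+//   class InitStatusCallbackNode : public SearchCallbackNode {
+//    public:
+//     void Callback(SearchPolicyNode* policy) final {
+//       policy->verbose = 1;  // example: tweak the policy before searching
+//     }
+//     static constexpr const char* _type_key = "auto_schedule.InitStatusCallback";
+//     TVM_DECLARE_FINAL_OBJECT_INFO(InitStatusCallbackNode, SearchCallbackNode);
+//   };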
+
+/*!
+ * \brief The base class of search policies.
+ */
+class SearchPolicyNode : public Object {
+ public:
+ /*! \brief The current search task. */
+ SearchTask cur_task;
+ /*!
+ * \brief Verbose level to control the screen output during schedule search.
+ * 0 for silent, 1 to output state & measure information during search process.
+ */
+ int verbose;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("cur_task", &cur_task);
+ v->Visit("verbose", &verbose);
+ }
+
+ /*!
+   * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
+   * found during the search process.
+   * \param task The SearchTask for the computation declaration.
+   * \param num_measure_trials The total number of schedules to be tried during this search.
+   * \param early_stopping Stop the search early if no better schedule is found.
+   * \param num_measures_per_round The maximum number of measurements in one search round.
+   * \param verbose Verbose level. 0 for silent, 1 to output information during schedule
+   * search.
+   * \param measurer A ProgramMeasurer which packs a ProgramBuilder & ProgramRunner inside.
+   * \param pre_search_callbacks SearchCallbacks to be called before the schedule search.
+   * \return The best state found.
+ */
+ virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
+ int num_measures_per_round, int verbose, ProgramMeasurer measurer,
+ Optional<Array<SearchCallback>> pre_search_callbacks) = 0;
+
+ /*!
+ * \brief Call SearchCallback with the current SearchPolicyNode
+ * \param callbacks SearchCallback to be called.
+ */
+ void RunCallbacks(const Optional<Array<SearchCallback>>& callbacks);
+
+ static constexpr const char* _type_key = "auto_schedule.SearchPolicy";
+ TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object);
+
+ protected:
+ /*!
+   * \brief The set of already measured states.
+   * During the schedule search process, we may generate `equal states` through different search
+   * branches. (Equal states: 1. the transform steps are exactly the same; 2. even with different
+   * steps, two states may still result in the same schedule. e.g. To split an axis with extent 512
+   * into 3 parts [8, 16, 4], we can split from inner to outer by factors [16, 4], or get the same
+   * result by splitting from outer to inner by factors [8, 16].)
+   * We store the string format of a state for the redundancy check. This is used to make sure a
+   * measured state will never be measured again.
+ */
+ std::unordered_set<String> measured_states_set_;
+ /*! \brief The array of already measured states. This can be used in evolutionary search. */
+ std::vector<State> measured_states_vector_;
+ /*! \brief The throughputs of already measured states */
+ std::vector<float> measured_states_throughputs_;
+};
+
+/*!
+ * \brief Managed reference to SearchPolicyNode.
+ * \sa SearchPolicyNode
+ */
+class SearchPolicy : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode);
+};
+
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/search_task.cc
+ * \brief Meta information and hardware parameters for a search task.
+ */
+
+#include "search_task.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/threading_backend.h>
+
+#include <utility>
+
+namespace tvm {
+namespace auto_schedule {
+
+TVM_REGISTER_NODE_TYPE(HardwareParamsNode);
+TVM_REGISTER_NODE_TYPE(SearchTaskNode);
+
+HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes) {
+ auto node = make_object<HardwareParamsNode>();
+ node->num_cores = num_cores;
+ node->vector_unit_bytes = vector_unit_bytes;
+ node->cache_line_bytes = cache_line_bytes;
+ data_ = std::move(node);
+}
+
+HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
+ const Target& target_host) {
+ if (target->id->name == "llvm") {
+ return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64);
+ } else {
+ LOG(FATAL) << "No default hardware parameters for target: " << target;
+ }
+ return HardwareParams();
+}
+
+SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
+ Target target_host, Optional<HardwareParams> hardware_params) {
+ auto node = make_object<SearchTaskNode>();
+ node->compute_dag = std::move(compute_dag);
+ node->workload_key = std::move(workload_key);
+ node->target = std::move(target);
+ node->target_host = std::move(target_host);
+ if (hardware_params) {
+ node->hardware_params = hardware_params.value();
+ } else {
+ node->hardware_params =
+ HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host);
+ }
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("auto_schedule.HardwareParams")
+ .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes) {
+ return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes);
+ });
+
+TVM_REGISTER_GLOBAL("auto_schedule.SearchTask")
+ .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
+ Target target_host, Optional<HardwareParams> hardware_params) {
+ return SearchTask(compute_dag, workload_key, target, target_host, hardware_params);
+ });
+
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/search_task.h
+ * \brief Meta information and hardware parameters for a search task.
+ */
+
+#ifndef TVM_AUTO_SCHEDULE_SEARCH_TASK_H_
+#define TVM_AUTO_SCHEDULE_SEARCH_TASK_H_
+
+#include <tvm/target/target.h>
+
+#include "compute_dag.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+class HardwareParams;
+
+/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */
+class HardwareParamsNode : public Object {
+ public:
+ /*! \brief The number of cores. */
+ int num_cores;
+ /*! \brief The width of vector units in bytes. */
+ int vector_unit_bytes;
+ /*! \brief The size of cache line in bytes. */
+ int cache_line_bytes;
+
+  // GPU-related parameters obtained from the device query API
+
+ /*! \brief The max shared memory per block. */
+ int max_shared_memory_per_block{INT32_MAX};
+ /*! \brief The max register memory per block. */
+ int max_registers_per_block{INT32_MAX};
+ /*! \brief The max threads per block. */
+ int max_threads_per_block{INT32_MAX};
+ /*! \brief The max vthread extent. */
+ int max_vthread_extent{INT32_MAX};
+  /*! \brief The number of threads in a warp. */
+ int warp_size{INT32_MAX};
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("num_cores", &num_cores);
+ v->Visit("vector_unit_bytes", &vector_unit_bytes);
+ v->Visit("cache_line_bytes", &cache_line_bytes);
+ v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block);
+ v->Visit("max_registers_per_block", &max_registers_per_block);
+ v->Visit("max_threads_per_block", &max_threads_per_block);
+ v->Visit("max_vthread_extent", &max_vthread_extent);
+ v->Visit("warp_size", &warp_size);
+ }
+
+ /*!
+ * \brief Get the default hardware params.
+ * \param target A `tvm.target`.
+ * \param target_host A `tvm.target` for the host device.
+ * \return A HardwareParams object.
+ */
+ static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host);
+
+ static constexpr const char* _type_key = "auto_schedule.HardwareParams";
+ TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to HardwareParamsNode.
+ * \sa HardwareParamsNode
+ */
+class HardwareParams : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param num_cores The number of cores.
+ * \param vector_unit_bytes The width of vector units in bytes.
+ * \param cache_line_bytes The size of cache line in bytes.
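+ * \note Illustrative example (not a default taken from this codebase):
+ * HardwareParams(4, 64, 64) would describe a 4-core CPU with 64-byte vector
+ * units and 64-byte cache lines.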
+ */
+ HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode);
+};
+
+/*!
+ * \brief The computation information and hardware parameters for a specific schedule search task.
+ */
+class SearchTaskNode : public Object {
+ public:
+ /*! \brief The ComputeDAG for the compute declaration. */
+ ComputeDAG compute_dag;
+ /*! \brief The workload key for the compute declaration. */
+ String workload_key;
+ /*! \brief The target device of this search task. */
+ Target target;
+ /*! \brief The target host device of this search task. */
+ Target target_host;
+ /*! \brief Hardware parameters used in this search task. */
+ HardwareParams hardware_params;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("compute_dag", &compute_dag);
+ v->Visit("workload_key", &workload_key);
+ v->Visit("target", &target);
+ v->Visit("target_host", &target_host);
+ v->Visit("hardware_params", &hardware_params);
+ }
+
+ static constexpr const char* _type_key = "auto_schedule.SearchTask";
+ TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SearchTaskNode.
+ * \sa SearchTaskNode
+ */
+class SearchTask : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param compute_dag The ComputeDAG for the compute declaration.
+ * \param workload_key The workload key for the compute declaration.
+ * \param target The target device of this search task.
+ * \param target_host The target host device of this search task.
+ * \param hardware_params Hardware parameters used in this search task.
+ */
+ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
+ Optional<HardwareParams> hardware_params);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
+};
+
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_SEARCH_TASK_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/transform_step.cc
+ * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
+ * step.
+ */
+
+#include "transform_step.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+
+#include <utility>
+
+#include "loop_state.h"
+#include "utils.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+/********** Reorder **********/
+ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
+ auto node = make_object<ReorderStepNode>();
+ node->stage_id = stage_id;
+ for (const auto& x : after_ids) {
+ CHECK(x->IsInstance<IntImmNode>());
+ }
+ node->after_ids = after_ids;
+ data_ = std::move(node);
+}
+
+void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ auto stage = (*stages)[stage_id];
+ const Array<IterVar>& axes = stage_to_axes->at(stage);
+ CHECK_EQ(after_ids.size(), axes.size());
+
+ Array<IterVar> new_axes;
+ new_axes.reserve(axes.size());
+ for (auto i : after_ids) {
+ new_axes.push_back(axes[i]);
+ }
+ stage.reorder(new_axes);
+
+ stage_to_axes->Set(stage, std::move(new_axes));
+ stages->Set(stage_id, std::move(stage));
+}
+
+String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ const auto& stage = (*stages)[stage_id];
+ std::stringstream ss;
+
+ ss << "s[" << CleanName(stage->op->name) << "].reorder(";
+ for (size_t i = 0; i < after_ids.size(); ++i) {
+ ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint);
+ if (i != after_ids.size() - 1) {
+ ss << ", ";
+ }
+ }
+ ss << ")\n";
+
+ ApplyToSchedule(stages, stage_to_axes);
+ return ss.str();
+}
+
+/********** Split **********/
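+// Common helper for SplitStep. With `inner_to_outer = true`, `lengths` are the extents of the
+// new inner iterators: e.g. splitting an iterator of extent 512 with lengths [4, 8] yields
+// extents (16, 4, 8), ordered from outer to inner. With `inner_to_outer = false`, `lengths`
+// are nparts counted from the outer side: lengths [32, 8] on extent 512 yield (32, 8, 2).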
+Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ int stage_id, int iter_id,
+ const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+ auto stage = (*stages)[stage_id];
+ const Array<IterVar>& axes = stage_to_axes->at(stage);
+
+ Array<IterVar> outs;
+ if (inner_to_outer) {
+ IterVar outer = axes[iter_id], inner;
+ for (int i = static_cast<int>(lengths.size()) - 1; i >= 0; i--) {
+ IterVar to_split = outer;
+ stage.split(to_split, lengths[i].value(), &outer, &inner);
+ outs.push_back(inner);
+ }
+ outs.push_back(outer);
+ } else {
+ IterVar outer, inner = axes[iter_id];
+ for (size_t i = 0; i < lengths.size(); i++) {
+ IterVar to_split = inner;
+ stage.split_by_nparts(to_split, lengths[i].value(), &outer, &inner);
+ outs.push_back(outer);
+ }
+ outs.push_back(inner);
+ }
+
+ Array<IterVar> new_axes;
+ new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id);
+ if (inner_to_outer) {
+ for (auto x = outs.rbegin(); x != outs.rend(); ++x) {
+ new_axes.push_back((*x));
+ }
+ } else {
+ for (const auto& x : outs) {
+ new_axes.push_back(x);
+ }
+ }
+ new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end());
+
+ stage_to_axes->Set(stage, std::move(new_axes));
+ stages->Set(stage_id, std::move(stage));
+ return outs;
+}
+
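+// Prints Python code of the form `i_o, i_i = s[C].split(i, factor=16)` (names illustrative),
+// one line per split level, and applies the split to the schedule as a side effect.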
+String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, int stage_id,
+ int iter_id, const Array<Optional<Integer>>& lengths,
+ bool inner_to_outer) {
+ const auto& stage = (*stages)[stage_id];
+ auto to_split = stage_to_axes->at(stage)[iter_id];
+ const auto& func_name = CleanName(stage->op->name);
+ const auto& outs =
+ ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
+ CHECK_EQ(outs.size(), lengths.size() + 1);
+
+ std::stringstream ss;
+ int size = static_cast<int>(lengths.size());
+ if (inner_to_outer) {
+ for (int i = size - 1; i >= 0; i--) {
+ ss << CleanName(outs[size - i]->var->name_hint) << ", "
+ << CleanName(outs[size - i - 1]->var->name_hint) << " = s[" << func_name << "].split("
+ << CleanName(to_split->var->name_hint) << ", factor=" << lengths[i] << ")\n";
+ to_split = outs[size - i];
+ }
+ } else {
+ for (int i = 0; i < size; i++) {
+ ss << CleanName(outs[i]->var->name_hint) << ", " << CleanName(outs[i + 1]->var->name_hint)
+ << " = s[" << func_name << "].split(" << CleanName(to_split->var->name_hint)
+ << ", nparts=" << lengths[i] << ")\n";
+ to_split = outs[i + 1];
+ }
+ }
+
+ return ss.str();
+}
+
+SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
+ const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+ auto node = make_object<SplitStepNode>();
+ node->stage_id = stage_id;
+ // The extent can be an irreducible expression in some special cases,
+ // so it is only stored here when it is a constant integer.
+ if (extent && extent.value()->IsInstance<IntImmNode>()) {
+ node->extent = tvm::Downcast<Integer>(extent.value());
+ }
+ node->iter_id = iter_id;
+ node->lengths = lengths;
+ node->inner_to_outer = inner_to_outer;
+ data_ = std::move(node);
+}
+
+Array<IterVar> SplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
+}
+
+String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
+}
+
+/********** Fuse **********/
+FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) {
+ auto node = make_object<FuseStepNode>();
+ node->stage_id = stage_id;
+ for (const auto& x : fused_ids) {
+ CHECK(x->IsInstance<IntImmNode>());
+ }
+ node->fused_ids = fused_ids;
+ data_ = std::move(node);
+}
+
+IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ auto stage = (*stages)[stage_id];
+ const Array<IterVar>& axes = stage_to_axes->at(stage);
+
+ Array<IterVar> to_fuse;
+ for (const auto& i : fused_ids) {
+ to_fuse.push_back(axes[i]);
+ }
+ IterVar fused_axis;
+ stage.fuse(to_fuse, &fused_axis);
+
+ Array<IterVar> new_axes;
+ new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
+ new_axes.push_back(fused_axis);
+ new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+
+ stage_to_axes->Set(stage, std::move(new_axes));
+ stages->Set(stage_id, std::move(stage));
+ return fused_axis;
+}
+
+String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ const auto& stage = (*stages)[stage_id];
+ std::stringstream to_fuse;
+
+ for (size_t i = 0; i < fused_ids.size(); ++i) {
+ to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
+ if (i != fused_ids.size() - 1) {
+ to_fuse << ", ";
+ }
+ }
+
+ std::stringstream ss;
+ const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+
+ ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
+ << to_fuse.str() << ")\n";
+
+ return ss.str();
+}
+
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/transform_step.h
+ * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
+ * step. The implementation of each step consists of 2 parts:
+ * - transform_step.cc: How each step interacts with TE and TE's schedule primitives
+ * - loop_state.cc: How each step updates LoopState
+ *
+ * \note To add a new transform step, take the fuse step as an example:
+ * 1. Define the classes `FuseStepNode` and `FuseStep` in `transform_step.h`, and implement the
+ * constructor `FuseStep::FuseStep(...)` in `transform_step.cc`.
+ * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`.
+ * - These two functions lower the step with TVM's te schedule API.
+ * 3. Implement `State::fuse` and `State::DoFuseStep`.
+ * - These two functions incrementally update all data structures in State in
+ * CopyOnWrite style.
+ * 4. Add your step to `ComputeDAG::ApplySteps` and make sure it works.
+ * 5. Add log record serialization support in `struct Handler<Array<::tvm::auto_schedule::Step>>`
+ * in `record.cc`.
+ * 6. Add the corresponding Python API to `loop_state.py` and the necessary unit tests.
+ * A minimal class skeleton following this pattern is sketched below, after this comment.
+ */
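+
+/*
+ * A minimal sketch of step 1 above for a hypothetical step named `MyStep` (the name and the
+ * exact return types are illustrative; see the concrete step classes below for the real pattern):
+ *
+ *   class MyStepNode : public StepNode {
+ *    public:
+ *     void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+ *     String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+ *     static constexpr const char* _type_key = "auto_schedule.MyStep";
+ *     TVM_DECLARE_FINAL_OBJECT_INFO(MyStepNode, Object);
+ *   };
+ */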
+
+#ifndef TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_
+#define TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_
+
+#include <dmlc/common.h>
+#include <tvm/node/node.h>
+#include <tvm/te/schedule.h>
+
+#include "utils.h"
+
+namespace tvm {
+namespace auto_schedule {
+
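+/*! \brief Map a te::Stage to the Array of its current iteration axes, kept in loop order. */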
+typedef Map<tvm::te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageToAxesMap;
+
+/*!
+ * \brief The base class of transformation steps. Each step has its corresponding tvm.te
+ * schedule primitives.
+ */
+class StepNode : public Object {
+ public:
+ /*! \brief The index of the stage. */
+ int stage_id;
+
+ static constexpr const char* _type_key = "auto_schedule.Step";
+ TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to StepNode.
+ * \sa StepNode
+ */
+class Step : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
+};
+
+/*! \brief Reorder step that corresponds to te::Stage::reorder */
+class ReorderStepNode : public StepNode {
+ public:
+ /*!
+ * \brief The iterator ids after reorder.
+ * This array should specify the order of all iterators.
+ */
+ Array<Integer> after_ids;
+
+ /*!
+ * \brief Apply the current step to the TVM schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ */
+ void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ /*!
+ * \brief Print the current step as equivalent Python schedule API code.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ static constexpr const char* _type_key = "auto_schedule.ReorderStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ReorderStepNode.
+ * \sa ReorderStepNode
+ */
+class ReorderStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be reordered.
+ * \param after_ids The expected indexes of the iterators after reorder.
+ */
+ ReorderStep(int stage_id, const Array<Integer>& after_ids);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode);
+};
+
+/*!
+ * \brief Split step that corresponds to te::Stage::split, with additional
+ * support for multiple levels of split factors
+ */
+class SplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to split. */
+ int iter_id;
+ /*! \brief The extent length of the axis to split. */
+ Optional<Integer> extent;
+ /*! \brief The split factors. */
+ Array<Optional<Integer>> lengths;
+ /*!
+ * \brief If true, the `lengths` denote the lengths of the new iterators
+ * from the innermost level to the outermost level
+ */
+ bool inner_to_outer;
+
+ /*!
+ * \brief Apply the current step to the TVM schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return The iterator results after split.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const;
+
+ /*!
+ * \brief Print the current step as equivalent Python schedule API code.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ static constexpr const char* _type_key = "auto_schedule.SplitStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SplitStepNode.
+ * \sa SplitStepNode
+ */
+class SplitStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param extent The extent length of the axis to split.
+ * \param lengths The multiple split factors. Entries can be None, to be filled in later by
+ * the search policy.
+ * \param inner_to_outer The split direction.
+ */
+ SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
+ const Array<Optional<Integer>>& lengths, bool inner_to_outer);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
+};
+
+/*! \brief Fuse step that corresponds to te::Stage::fuse */
+class FuseStepNode : public StepNode {
+ public:
+ /*! \brief The ids of iterators to fuse. */
+ Array<Integer> fused_ids;
+
+ /*!
+ * \brief Apply the current step to the TVM schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return The iterator result after fuse.
+ */
+ tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ /*!
+ * \brief Print the current step as equivalent Python schedule API code.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ static constexpr const char* _type_key = "auto_schedule.FuseStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FuseStepNode.
+ * \sa FuseStepNode
+ */
+class FuseStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be fused.
+ * \param fused_ids The indices of the iterators to be fused.
+ */
+ FuseStep(int stage_id, const Array<Integer>& fused_ids);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
+};
+
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/utils.cc
+ * \brief Common utilities.
+ */
+
+#include "utils.h"
+
+namespace tvm {
+namespace auto_schedule {
+
+NullStream& NullStream::Global() {
+ static NullStream stream;
+ return stream;
+}
+
+} // namespace auto_schedule
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_schedule/utils.h
+ * \brief Common utilities.
+ */
+
+#ifndef TVM_AUTO_SCHEDULE_UTILS_H_
+#define TVM_AUTO_SCHEDULE_UTILS_H_
+
+#include <dmlc/common.h>
+#include <tvm/tir/expr.h>
+
+#include <algorithm>
+#include <deque>
+#include <exception>
+#include <future>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+namespace std {
+
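+// These specializations let std::pair and std::tuple be used directly as keys in
+// std::unordered_map / std::unordered_set.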
+/*! \brief Hash function for std::pair */
+template <typename T1, typename T2>
+struct hash<std::pair<T1, T2>> {
+ std::size_t operator()(const std::pair<T1, T2>& k) const {
+ return ::dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
+ }
+};
+
+/*! \brief Hash function for std::tuple */
+template <typename T1, typename T2, typename T3>
+struct hash<std::tuple<T1, T2, T3>> {
+ std::size_t operator()(const std::tuple<T1, T2, T3>& k) const {
+ return ::dmlc::HashCombine(
+ ::dmlc::HashCombine(std::hash<T1>()(std::get<0>(k)), std::hash<T2>()(std::get<1>(k))),
+ std::hash<T3>()(std::get<2>(k)));
+ }
+};
+
+} // namespace std
+
+namespace tvm {
+namespace auto_schedule {
+
+/********** Utilities for Array, std::string **********/
+/*! \brief Get the first appearance index of elements in an Array */
+template <typename T>
+inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) {
+ for (const auto& v : to_locate) {
+ auto it = std::find(array.begin(), array.end(), v);
+ if (it != array.end()) {
+ indices->push_back(it - array.begin());
+ } else {
+ LOG(FATAL) << "Cannot find the item";
+ }
+ }
+}
+
+/*! \brief Get the first appearance index of an element in an Array */
+template <typename T>
+inline int GetIndex(const Array<T>& array, const T& to_locate) {
+ for (size_t i = 0; i < array.size(); ++i) {
+ if (array[i] == to_locate) {
+ return i;
+ }
+ }
+ LOG(FATAL) << "Cannot find the item";
+ return -1;
+}
+
+/*! \brief Replace a sub-string with another sub-string in a string */
+inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
+ auto pos = base->find(from);
+ while (pos != std::string::npos) {
+ base->replace(pos, from.size(), to);
+ pos = base->find(from, pos + to.size());
+ }
+}
+
+/********** Utilities for TVM Containers / ByteArray **********/
+/*! \brief Compute mean of a FloatImm array */
+inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
+ double sum = 0;
+ if (float_array.empty()) {
+ return 0.0;
+ }
+
+ for (const auto& x : float_array) {
+ auto floatimm = x.as<tir::FloatImmNode>();
+ CHECK(floatimm != nullptr);
+ sum += floatimm->value;
+ }
+ return sum / float_array.size();
+}
+
+/********** Other Utilities **********/
+/*! \brief Get an int value from an Expr */
+inline int64_t GetIntImm(const PrimExpr& expr) {
+ auto pint = expr.as<IntImmNode>();
+ CHECK(pint != nullptr);
+ return pint->value;
+}
+
+/*! \brief Compute the product of the lengths of axes */
+inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
+ int64_t ret = 1;
+ for (const auto& x : axes) {
+ if (const IntImmNode* imm = x->dom->extent.as<IntImmNode>()) {
+ ret *= imm->value;
+ } else {
+ return -1;
+ }
+ }
+ return ret;
+}
+
+/*!
+ * \brief Clean the name of an iterator to make it valid in python code.
+ * \param str The original name.
+ * \return The cleaned name.
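+ * \note For example, "blockIdx.x" becomes "blockIdx_x" and "i.outer" becomes "i_o".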
+ */
+inline std::string CleanName(const std::string& str) {
+ std::string ret = str;
+ StrReplace(&ret, ".", "_");
+ StrReplace(&ret, "@", "_");
+ StrReplace(&ret, "outer", "o");
+ StrReplace(&ret, "inner", "i");
+ return ret;
+}
+
+/*! \brief An empty output stream */
+class NullStream : public std::ostream {
+ public:
+ NullStream() : std::ostream(nullptr) {}
+ NullStream(const NullStream&) : std::ostream(nullptr) {}
+ static NullStream& Global();
+};
+
+template <class T>
+NullStream& operator<<(NullStream& os, const T& value) {
+ return os;
+}
+
+/*! \brief Get std::cout with verbosity control */
+inline std::ostream& StdCout(int verbose, int setting = 1) {
+ return verbose >= setting ? std::cout : NullStream::Global();
+}
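+// Illustrative usage: `StdCout(verbose) << "msg";` writes to std::cout only when
+// `verbose >= 1`; otherwise the message goes to NullStream and is discarded.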
+
+/*! \brief Get a string that repeats a character for a given number of times */
+inline std::string Chars(const char& c, int times) {
+ std::stringstream ret;
+ for (int i = 0; i < times; ++i) {
+ ret << c;
+ }
+ return ret.str();
+}
+
+/*! \brief Print a title */
+inline void PrintTitle(const std::string& title, int verbose) {
+ StdCout(verbose) << Chars('-', 60) << "\n"
+ << Chars('-', 25) << " [ " << title << " ]\n"
+ << Chars('-', 60) << std::endl;
+}
+
+} // namespace auto_schedule
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULE_UTILS_H_
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Common functions for auto_schedule test cases"""
+
+import threading
+
+from tvm import te, auto_schedule
+import topi
+
+
+@auto_schedule.register_workload
+def matmul_auto_schedule_test(N, M, K):
+ A = te.placeholder((N, K), name='A')
+ B = te.placeholder((K, M), name='B')
+ k = te.reduce_axis((0, K), name='k')
+ C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+ return [A, B, C]
+
+
+@auto_schedule.register_workload("matmul_auto_schedule_test_rename_1")
+def matmul_auto_schedule_test_rename_0(N, M, K):
+ A = te.placeholder((N, K), name='A')
+ B = te.placeholder((K, M), name='B')
+ k = te.reduce_axis((0, K), name='k')
+ C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+ return [A, B, C]
+
+
+def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1):
+ data = te.placeholder((N, CI, H, W), name='Data')
+ kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel')
+ bias = te.placeholder((CO, 1, 1), name='Bias')
+ bn_scale = te.placeholder((CO, 1, 1), name='Bn_scale')
+ bn_offset = te.placeholder((CO, 1, 1), name='Bn_offset')
+
+ OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
+ OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
+
+ conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation)
+ conv = te.compute((N, CO, OH, OW),
+ lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0],
+ name='Bias_add')
+ conv = te.compute((N, CO, OH, OW),
+ lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0],
+ name='Bn_mul')
+ conv = te.compute((N, CO, OH, OW),
+ lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0],
+ name='Bn_add')
+ out = topi.nn.relu(conv)
+
+ return [data, kernel, bias, bn_offset, bn_scale, out]
+
+
+def get_tiled_matmul():
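+ """Build a 512x512x512 matmul ComputeDAG and return it with a manually tiled state."""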
+ A, B, C = matmul_auto_schedule_test(512, 512, 512)
+ dag = auto_schedule.ComputeDAG([A, B, C])
+
+ s0 = dag.get_init_state()
+ its0 = s0.split(C, s0[C].iters[0], [4, 8, 8])
+ its1 = s0.split(C, s0[C].iters[4], [8, 4, 4])
+ s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3],
+ s0[C].iters[8]])
+
+ return dag, s0
+
+
+class PropagatingThread(threading.Thread):
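+ """A Thread whose join() re-raises any exception raised by the target function,
+ since threading.Thread normally swallows exceptions from run()."""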
+ def run(self):
+ self.exc = None
+ try:
+ self.ret = self._target(*self._args, **self._kwargs)
+ except BaseException as e:
+ self.exc = e
+
+ def join(self):
+ super(PropagatingThread, self).join()
+ if self.exc:
+ raise self.exc
+ return self.ret
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test ComputeDAG (replay, infer bound)"""
+
+import tvm
+from tvm import auto_schedule, te
+
+from test_auto_schedule_common import get_tiled_matmul
+
+
+def test_apply_steps():
+ dag, s = get_tiled_matmul()
+ dag.print_python_code_from_state(s)
+ sch, tensors = dag.apply_steps_from_state(s)
+ stmt = tvm.lower(sch, tensors, simple_mode=True)
+
+
+def test_infer_bound():
+ dag, s = get_tiled_matmul()
+ s = dag.infer_bound_from_state(s)
+
+
+def test_estimate_flop():
+ dag, s = get_tiled_matmul()
+ assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5
+
+
+if __name__ == "__main__":
+ test_apply_steps()
+ test_infer_bound()
+ test_estimate_flop()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test loop state and schedule primitives"""
+
+import numpy as np
+
+import tvm
+from tvm import auto_schedule, te
+import topi
+
+from test_auto_schedule_common import matmul_auto_schedule_test, conv2d_nchw_bn_relu
+
+
+def test_split_fuse_reorder():
+ A, B, C = matmul_auto_schedule_test(512, 512, 512)
+ dag = auto_schedule.ComputeDAG([A, B, C])
+ s0 = dag.get_init_state()
+ i, j, k = s0[C].iters
+
+ assert i.range.extent == 512
+
+ io, ii = s0.split(C, i, [16])
+ assert s0[C].iters[0] == io
+ assert s0[C].iters[1] == ii
+ assert io.range.extent == 32
+ assert ii.range.extent == 16
+
+ jo, ji = s0.split(C, j, [8])
+ assert jo.range.extent == 64
+ assert ji.range.extent == 8
+
+ s0.reorder(C, [io, jo, k, ji, ii])
+ assert s0[C].iters[2].range.extent == 512
+
+ fused_it = s0.fuse(C, [io, jo])
+ assert fused_it.range.extent == 2048
+
+ s1 = dag.get_init_state()
+ i, j, _ = s1[C].iters
+ i1, i2, i3 = s1.split(C, i, [8, 2])
+ j1, j2, j3 = s1.split(C, j, [32, 8], False)
+ assert s1[C].iters[0].range.extent == 32
+ assert s1[C].iters[1].range.extent == 8
+ assert s1[C].iters[2].range.extent == 2
+ assert s1[C].iters[3].range.extent == 32
+ assert s1[C].iters[4].range.extent == 8
+ assert s1[C].iters[5].range.extent == 2
+
+
+if __name__ == "__main__":
+ test_split_fuse_reorder()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Test measurement and log serialization. """
+
+import tvm
+from tvm import auto_schedule
+import tempfile
+
+from test_auto_schedule_common import get_tiled_matmul
+
+
+def test_record():
+ dag, s = get_tiled_matmul()
+
+ if not tvm.runtime.enabled("llvm"):
+ return
+ target = tvm.target.create("llvm")
+ task = auto_schedule.SearchTask(dag, "test", target)
+
+ inp = auto_schedule.measure.MeasureInput(task, s)
+ res = auto_schedule.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+
+ with tempfile.NamedTemporaryFile() as fp:
+ auto_schedule.save_records(fp.name, [inp], [res])
+
+ log_reader = auto_schedule.RecordReader(fp.name)
+ inputs, results = log_reader.read_lines()
+ assert len(inputs) == 1
+
+ s1 = dag.infer_bound_from_state(s)
+ s2 = dag.infer_bound_from_state(inputs[0].state)
+
+ assert s1 == s2
+ assert not (s1 == dag.get_init_state())
+
+
+def test_measure_local_builder_runner():
+ dag, s0 = get_tiled_matmul()
+
+ if not tvm.runtime.enabled("llvm"):
+ return
+ tgt = tvm.target.create("llvm")
+ task = auto_schedule.SearchTask(dag, "test", tgt)
+
+ minp = auto_schedule.MeasureInput(task, s0)
+ local_builder = auto_schedule.LocalBuilder()
+ local_runner = auto_schedule.LocalRunner()
+
+ bress = local_builder.build([minp])
+ assert bress[0].error_no == 0
+ mress = local_runner.run([minp], bress)
+ assert mress[0].error_no == 0
+
+
+if __name__ == "__main__":
+ test_record()
+ test_measure_local_builder_runner()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test search policy"""
+
+import random
+import numpy as np
+import tempfile
+
+import tvm
+from tvm import auto_schedule
+
+from test_auto_schedule_common import matmul_auto_schedule_test, PropagatingThread
+
+
+def search_common(workload=matmul_auto_schedule_test, target="llvm",
+ search_policy=auto_schedule.EmptyPolicy(), seed=random.randint(1, 1 << 30),
+ runner='local', cost_model=None, num_measure_trials=2, params=None,
+ pre_search_callbacks=None):
+ print("Test %s schedule search with the given search policy" % target)
+
+ random.seed(seed)
+ N = 128
+ workload_key = auto_schedule.make_workload_key(workload, (N, N, N))
+ dag = auto_schedule.ComputeDAG(workload_key)
+ target = tvm.target.create(target)
+ task = auto_schedule.SearchTask(dag, workload_key, target)
+
+ with tempfile.NamedTemporaryFile() as fp:
+ log_file = fp.name
+
+ tuning_options = auto_schedule.TuningOptions(num_measure_trials=num_measure_trials, runner=runner,
+ verbose=0,
+ measure_callbacks=[auto_schedule.RecordToFile(log_file)],
+ pre_search_callbacks=pre_search_callbacks)
+ sch, args = auto_schedule.auto_schedule(task, search_policy, tuning_options)
+ inp, res = auto_schedule.load_best(log_file, workload_key, target)
+
+ print("==== Python Code ====")
+ print(dag.print_python_code_from_state(inp.state))
+
+ try:
+ print("==== Lowered Stmt ====")
+ print(tvm.lower(sch, args, simple_mode=True))
+ mod = tvm.build(sch, args, target)
+
+ ctx = tvm.context(str(target), 0)
+ dtype = dag.tensors[0].dtype
+ a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
+ b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
+ c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
+ mod(a, b, c)
+ tvm.testing.assert_allclose(c.asnumpy(), np.dot(
+ a.asnumpy(), b.asnumpy()), rtol=1e-5)
+ print("==== Verification passed ====")
+ except Exception as e:
+ raise Exception("Error encountered with seed: %d" % seed) from e
+ print()
+
+
+def test_workload_registry_search_basic():
+ if not tvm.runtime.enabled("llvm"):
+ return
+ # wrap the search in a new thread to avoid the conflict
+ # between python's multiprocessing and tvm's thread pool
+ t = PropagatingThread(target=search_common, kwargs={'seed': 944563397})
+ t.start()
+ t.join()
+ t = PropagatingThread(target=search_common,
+ kwargs={'seed': 944563397, 'workload': "matmul_auto_schedule_test"})
+ t.start()
+ t.join()
+ t = PropagatingThread(target=search_common,
+ kwargs={'seed': 944563397, 'workload': "matmul_auto_schedule_test_rename_1"})
+ t.start()
+ t.join()
+
+
+if __name__ == "__main__":
+ test_workload_registry_search_basic()