--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/cost_model.h
+ * \brief Cost models that estimate the performance of programs
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COST_MODEL_H_
+#define TVM_AUTO_SCHEDULER_COST_MODEL_H_
+
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/auto_scheduler/measure.h>
+#include <tvm/node/node.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+using runtime::PackedFunc;
+using runtime::TypedPackedFunc;
+
+/*! \brief The base class for cost model */
+class CostModelNode : public Object {
+ public:
+ /*!
+ * \brief Update the cost model according to new measurement results (training data).
+ * \param inputs The measure inputs
+ * \param results The measure results
+ */
+ virtual void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) = 0;
+
+ /*!
+ * \brief Predict the scores of states
+ * \param task The search task of states
+ * \param states The input states
+ * \param scores The predicted scores for all states
+ */
+ virtual void Predict(const SearchTask& task, const std::vector<State>& states,
+ std::vector<float>* scores) = 0;
+
+ /*!
+ * \brief Predict the scores of all stages in states. This is the breakdown version of `Predict`
+ * \param task The search task
+ * \param states The input states
+ * \param state_scores The predicted scores for all states
+ * \param stage_scores The predicted scores for all stages in all stages
+ */
+ virtual void PredictStages(const SearchTask& task, const std::vector<State>& states,
+ std::vector<float>* state_scores,
+ std::vector<std::vector<float>>* stage_scores) {
+ LOG(FATAL) << "Not implemented";
+ }
+
+ static constexpr const char* _type_key = "auto_scheduler.CostModel";
+ TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object);
+};
+
+/*!
+ * \brief Managed reference to CostModelNode.
+ * \sa CostModelNode
+ */
+class CostModel : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode);
+};
+
+/*! \brief The cost model returning random value for all predictions */
+class RandomModelNode : public CostModelNode {
+ public:
+ /*! \brief Pointer to a random number generator function */
+ const TypedPackedFunc<void(size_t, void*)>* random_number_func;
+
+ void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
+
+ void Predict(const SearchTask& task, const std::vector<State>& states,
+ std::vector<float>* scores) final;
+
+ static constexpr const char* _type_key = "auto_scheduler.RandomModel";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode);
+};
+
+/*!
+ * \brief Managed reference to RandomModelNode.
+ * \sa RandomModelNode
+ */
+class RandomModel : public CostModel {
+ public:
+ RandomModel();
+ explicit RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : CostModel(n) {}
+
+ RandomModelNode* operator->() const { return static_cast<RandomModelNode*>(data_.get()); }
+
+ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel);
+ using ContainerType = RandomModelNode;
+};
+
+/*! \brief A wrapper for cost model defined by python code
+ * This class will call functions defined in the python */
+class PythonBasedModelNode : public CostModelNode {
+ public:
+ /*! \brief Pointer to the update funcion in python */
+ PackedFunc update_func;
+ /*! \brief Pointer to the predict funcion in python */
+ PackedFunc predict_func;
+ /*! \brief Pointer to the predict funcion in python */
+ PackedFunc predict_stage_func;
+
+ void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
+
+ void Predict(const SearchTask& task, const std::vector<State>& states,
+ std::vector<float>* scores) final;
+
+ void PredictStages(const SearchTask& task, const std::vector<State>& states,
+ std::vector<float>* state_scores,
+ std::vector<std::vector<float>>* stage_scores) final;
+
+ static constexpr const char* _type_key = "auto_scheduler.PythonBasedModel";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode);
+};
+
+/*!
+ * \brief Managed reference to PythonBasedModelNode.
+ * \sa PythonBasedModelNode
+ */
+class PythonBasedModel : public CostModel {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param update_func The pointer to the update function defined in python
+ * \param predict_func The pointer to the prediction function defined in python
+ * \param predict_stage_func The pointer to the prediction function defined in python
+ */
+ PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, PythonBasedModelNode);
+};
+
+} // namespace auto_scheduler
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULER_COST_MODEL_H_
from . import workload_registry
# Shortcut
-from .compute_dag import ComputeDAG
from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
auto_schedule, EmptyPolicy
+from .compute_dag import ComputeDAG
+from .cost_model import RandomModel
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, \
LocalRPCMeasureContext
from .measure_record import RecordToFile, RecordReader, load_best, \
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+""" Cost model that estimates the performance of programs """
+
+from .cost_model import RandomModel
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Cost model that estimates the performance of programs """
+import ctypes
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+from .. import _ffi_api
+
+
+@tvm._ffi.register_object("auto_scheduler.CostModel")
+class CostModel(Object):
+ """The base class for cost model"""
+
+@tvm._ffi.register_object("auto_scheduler.RandomModel")
+class RandomModel(CostModel):
+ """A model returns random estimation for all inputs"""
+ def __init__(self):
+ self.__init_handle_by_constructor__(_ffi_api.RandomModel)
+
+ def update(self, inputs, results):
+ """Update the cost model according to new measurement results (training data).
+
+ Parameters
+ ----------
+ inputs : List[MeasureInput]
+ The measurement inputs
+ results : List[MeasureResult]
+ The measurement results
+ """
+ _ffi_api.CostModelUpdate(self, inputs, results)
+
+ def predict(self, search_task, states):
+ """Predict the scores of states
+
+ Parameters
+ ----------
+ search_task : SearchTask
+ The search task of states
+ statse : List[State]
+ The input states
+
+ Returns
+ -------
+ scores: List[float]
+ The predicted scores for all states
+ """
+ return [x.value for x in _ffi_api.CostModelPredict(self, search_task, states)]
+
+
+@tvm._ffi.register_func("auto_scheduler.cost_model.random_fill_float")
+def random_fill_float(size, return_ptr):
+ """Fills a c++ float array with random numbers in [0, 1]
+
+ Parameters
+ ----------
+ size: int
+ The size of the array
+ return_ptr:
+ A pointer to a c++ float array
+ """
+ if size == 0:
+ return
+ return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
+ array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(size,))
+ array_wrapper[:] = np.random.uniform(0, 1, (size,))
+
+
+@tvm._ffi.register_object("auto_scheduler.PythonBasedModel")
+class PythonBasedModel(CostModel):
+ """Base class for cost models implemented in python"""
+ def __init__(self):
+ def update_func(inputs, results):
+ self.update(inputs, results)
+
+ def predict_func(task, states, return_ptr):
+ return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
+ array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(len(states),))
+ array_wrapper[:] = self.predict(task, states)
+
+ def predict_stage_func(task, states, return_ptr):
+ ret = self.predict_stages(task, states)
+ return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
+ array_wrapper = np.ctypeslib.as_array(return_ptr, shape=ret.shape)
+ array_wrapper[:] = ret
+
+ self.__init_handle_by_constructor__(_ffi_api.PythonBasedModel, update_func,
+ predict_func, predict_stage_func)
+
+ def update(self, inputs, results):
+ """Update the cost model according to new measurement results (training data).
+
+ Parameters
+ ----------
+ inputs : List[MeasureInput]
+ The measurement inputs
+ results : List[MeasureResult]
+ The measurement results
+ """
+ raise NotImplementedError
+
+ def predict(self, task, states):
+ """Predict the scores of states
+
+ Parameters
+ ----------
+ search_task : SearchTask
+ The search task of states
+ statse : List[State]
+ The input states
+
+ Returns
+ -------
+ scores: List[float]
+ The predicted scores for all states
+ """
+ raise NotImplementedError
+
+ def predict_stages(self, task, states):
+ """Predict the scores of all stages in states. This is the breakdown version of `predict`.
+
+ Parameters
+ ----------
+ search_task : SearchTask
+ The search task of states
+ statse : List[State]
+ The input states
+
+ Returns
+ -------
+ scores: List[float]
+ The predicted scores for all stages in all states in the packed format
+ """
+ raise NotImplementedError
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler/cost_model.cc
+ * \brief Cost models that estimate the performance of programs
+ */
+
+#include <tvm/auto_scheduler/cost_model.h>
+
+namespace tvm {
+namespace auto_scheduler {
+
+TVM_REGISTER_OBJECT_TYPE(CostModelNode);
+TVM_REGISTER_OBJECT_TYPE(RandomModelNode);
+TVM_REGISTER_OBJECT_TYPE(PythonBasedModelNode);
+
+RandomModel::RandomModel() {
+ ObjectPtr<RandomModelNode> node = make_object<RandomModelNode>();
+ const auto* f = runtime::Registry::Get("auto_scheduler.cost_model.random_fill_float");
+ CHECK(f != nullptr);
+ node->random_number_func = reinterpret_cast<const TypedPackedFunc<void(size_t, void*)>*>(f);
+ data_ = std::move(node);
+}
+
+void RandomModelNode::Update(const Array<MeasureInput>& inputs,
+ const Array<MeasureResult>& results) {}
+
+void RandomModelNode::Predict(const SearchTask& task, const std::vector<State>& states,
+ std::vector<float>* scores) {
+ scores->resize(states.size());
+ (*random_number_func)(states.size(), static_cast<void*>(scores->data()));
+}
+
+PythonBasedModel::PythonBasedModel(PackedFunc update_func, PackedFunc predict_func,
+ PackedFunc predict_stage_func) {
+ auto node = make_object<PythonBasedModelNode>();
+ node->update_func = std::move(update_func);
+ node->predict_func = std::move(predict_func);
+ node->predict_stage_func = std::move(predict_stage_func);
+ data_ = std::move(node);
+}
+
+void PythonBasedModelNode::Update(const Array<MeasureInput>& inputs,
+ const Array<MeasureResult>& results) {
+ update_func(inputs, results);
+}
+
+void PythonBasedModelNode::Predict(const SearchTask& task, const std::vector<State>& states,
+ std::vector<float>* scores) {
+ scores->resize(states.size());
+ predict_func(task, Array<State>(states.begin(), states.end()),
+ static_cast<void*>(scores->data()));
+}
+
+void PythonBasedModelNode::PredictStages(const SearchTask& task, const std::vector<State>& states,
+ std::vector<float>* state_scores,
+ std::vector<std::vector<float>>* stage_scores) {
+ size_t n_states = states.size();
+ size_t n_stages = task->compute_dag->init_state->stages.size();
+ std::vector<float> flatten_scores;
+ // Allocate sufficient spaces.
+ flatten_scores.resize(n_states * n_stages * 2);
+ predict_stage_func(task, Array<State>(states.begin(), states.end()),
+ static_cast<void*>(flatten_scores.data()));
+
+ // Unpack flatten scores.
+ state_scores->clear();
+ stage_scores->clear();
+
+ // Score of each states.
+ for (size_t i = 0; i < n_states; ++i) {
+ state_scores->push_back(flatten_scores[i]);
+ }
+
+ // Score of each stage in each states.
+ size_t idx = n_states;
+ for (size_t i = 0; i < n_states; ++i) {
+ CHECK_LE(idx, flatten_scores.size());
+
+ // Number of scored stages of this state.
+ int s_length = static_cast<int>(flatten_scores[idx++]);
+
+ if (s_length > 0) {
+ std::vector<float> scores;
+ int offset = 0;
+
+ if ((*state_scores)[i] > -INFINITY) {
+ // If the score is valid. Copy scored stages and assign 0 to placeholder
+ // and inlined stages. If the score is 0, meaning this state failed to
+ // be lowered. Just bypass to update offset.
+ for (const Stage& stage : states[i]->stages) {
+ if (stage->op_type == StageKind::kPlaceholder) {
+ scores.push_back(0);
+ continue;
+ }
+ if (stage->compute_at == ComputeAtKind::kInlined) {
+ scores.push_back(0);
+ continue;
+ }
+ scores.push_back(flatten_scores[idx + offset]);
+ offset++;
+ }
+ CHECK_EQ(offset, s_length);
+ stage_scores->push_back(std::move(scores));
+ }
+ idx += s_length;
+ } else {
+ // Cost model does not provide any stage score details.
+ stage_scores->push_back({});
+ }
+ }
+}
+
+TVM_REGISTER_GLOBAL("auto_scheduler.RandomModel").set_body_typed([]() { return RandomModel(); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedModel")
+ .set_body_typed([](PackedFunc update_func, PackedFunc predict_func,
+ PackedFunc predict_stage_func) {
+ return PythonBasedModel(update_func, predict_func, predict_stage_func);
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.CostModelUpdate")
+ .set_body_typed([](CostModel model, Array<MeasureInput> inputs, Array<MeasureResult> results) {
+ model->Update(inputs, results);
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.CostModelPredict")
+ .set_body_typed([](CostModel model, SearchTask task, Array<State> states) {
+ std::vector<float> scores;
+ model->Predict(task, std::vector<State>(states.begin(), states.end()), &scores);
+ Array<FloatImm> ret;
+ for (auto x : scores) {
+ ret.push_back(FloatImm(DataType::Float(32), x));
+ }
+ return ret;
+ });
+
+} // namespace auto_scheduler
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test cost models"""
+
+import tvm
+from tvm import auto_scheduler
+
+from test_auto_scheduler_common import matmul_auto_scheduler_test
+
+
+def test_random_model():
+ if not tvm.runtime.enabled("llvm"):
+ return
+ N = 128
+ workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
+ dag = auto_scheduler.ComputeDAG(workload_key)
+ target = tvm.target.create('llvm')
+ task = auto_scheduler.SearchTask(dag, workload_key, target)
+
+ model = auto_scheduler.RandomModel()
+ model.update([], [])
+ scores = model.predict(task, [dag.init_state, dag.init_state])
+ assert len(scores) == 2
+
+
+if __name__ == "__main__":
+ test_random_model()