diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index bea9cc3e371b..744e4cfd54c3 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -178,6 +178,8 @@ class TuningRecord : public runtime::ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); }; +class Database; + /* \brief The abstract interface of database. */ class DatabaseNode : public runtime::Object { public: @@ -258,7 +260,11 @@ class DatabaseNode : public runtime::Object { */ virtual Optional QueryIRModule(const IRModule& mod, const Target& target, const String& workload_name); - + /*! + * \brief Prune the database and dump it a given database. + * \param destination The destination database to be dumped to. + */ + void DumpPruned(Database destination); /*! \brief Return a reference to the owned module equality method instance. */ const ModuleEquality& GetModuleEquality() const { ICHECK(mod_eq_); diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 4775a93de33f..601571089592 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -313,6 +313,18 @@ def query_ir_module( """ return _ffi_api.DatabaseQueryIRModule(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member + def dump_pruned(self, destination: "Database") -> None: + """Dump the pruned database to files of JSONDatabase format. + + Parameters + ---------- + destination : Database + The destination database to be dumped to. + """ + return _ffi_api.DatabaseDumpPruned( # type: ignore # pylint: disable=no-member + self, destination + ) + def query( self, mod: IRModule, diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 649429f9bc13..f549b850067a 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -210,6 +210,29 @@ Optional DatabaseNode::QueryIRModule(const IRModule& mod, const Target } } +void DatabaseNode::DumpPruned(Database destination) { + std::unordered_map workload2record; + for (const TuningRecord& record : this->GetAllTuningRecords()) { + if (record->IsValid()) { + auto it = workload2record.find(record->workload); + if (it == workload2record.end()) { + workload2record.insert({record->workload, record}); + } else if (SortTuningRecordByMeanRunSecs()(record, it->second)) { + it->second = record; + } + } + } + for (auto& kv : workload2record) { + Workload workload = kv.first; + TuningRecord record = kv.second; + workload = destination->CommitWorkload(workload->mod); + destination->CommitTuningRecord(TuningRecord(/*trace=*/record->trace, /*workload=*/workload, + /*run_secs=*/record->run_secs, + /*target=*/record->target, + /*args_info=*/record->args_info)); + } +} + std::vector* ThreadLocalDatabases() { static thread_local std::vector tls; return &tls; @@ -297,6 +320,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule") .set_body_method(&DatabaseNode::QuerySchedule); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule") .set_body_method(&DatabaseNode::QueryIRModule); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseDumpPruned") + .set_body_method(&DatabaseNode::DumpPruned); TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 3e08cec95de3..53f680f0a666 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -139,10 +139,6 @@ class JSONDatabaseNode : public DatabaseNode { } } } - if (results.size() < static_cast(top_k)) { - LOG(WARNING) << "Returned tuning records less than requested(" << results.size() << " of " - << top_k << " asked)."; - } return results; } diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index b003606c9cc0..3d418206b031 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -80,10 +80,6 @@ class MemoryDatabaseNode : public DatabaseNode { if (results.size() > static_cast(top_k)) { return {results.begin(), results.begin() + top_k}; } else { - if (results.size() < static_cast(top_k)) { - LOG(WARNING) << "Returned tuning records less than requested(" << results.size() << " of " - << top_k << " asked)."; - } return results; } } diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index 0997aab9b6a6..76f7f8941897 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -34,6 +34,7 @@ class ModuleEqualityStructural : public ModuleEquality { public: size_t Hash(IRModule mod) const { return tvm::StructuralHash()(mod); } bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); } + String GetName() const { return "structural"; } }; class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault { @@ -72,6 +73,7 @@ class ModuleEqualityIgnoreNDArray : public ModuleEquality { bool Equal(IRModule lhs, IRModule rhs) const { return SEqualHandlerIgnoreNDArray().Equal(lhs, rhs, false); } + String GetName() const { return "ignore-ndarray"; } }; // The NDArray-ignoring variant of structural equal / hash is used for the module equality @@ -93,6 +95,7 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { } return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs); } + String GetName() const { return "anchor-block"; } }; std::unique_ptr ModuleEquality::Create(const std::string& mod_eq_name) { diff --git a/src/meta_schedule/module_equality.h b/src/meta_schedule/module_equality.h index ba5877471e2c..7aa3944a4048 100644 --- a/src/meta_schedule/module_equality.h +++ b/src/meta_schedule/module_equality.h @@ -34,6 +34,7 @@ class ModuleEquality { virtual size_t Hash(IRModule mod) const = 0; virtual bool Equal(IRModule lhs, IRModule rhs) const = 0; + virtual String GetName() const = 0; /*! * \brief Create a ModuleEquality instance