Skip to content

Commit

Permalink
[MetaSchedule] Add an API to dump a pruned database (apache#14783)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored May 7, 2023
1 parent f989033 commit 01324ef
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 9 deletions.
8 changes: 7 additions & 1 deletion include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ class TuningRecord : public runtime::ObjectRef {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode);
};

class Database;

/* \brief The abstract interface of database. */
class DatabaseNode : public runtime::Object {
public:
Expand Down Expand Up @@ -258,7 +260,11 @@ class DatabaseNode : public runtime::Object {
*/
virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
const String& workload_name);

/*!
* \brief Prune the database and dump it a given database.
* \param destination The destination database to be dumped to.
*/
void DumpPruned(Database destination);
/*! \brief Return a reference to the owned module equality method instance. */
const ModuleEquality& GetModuleEquality() const {
ICHECK(mod_eq_);
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,18 @@ def query_ir_module(
"""
return _ffi_api.DatabaseQueryIRModule(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member

def dump_pruned(self, destination: "Database") -> None:
"""Dump the pruned database to files of JSONDatabase format.
Parameters
----------
destination : Database
The destination database to be dumped to.
"""
return _ffi_api.DatabaseDumpPruned( # type: ignore # pylint: disable=no-member
self, destination
)

def query(
self,
mod: IRModule,
Expand Down
25 changes: 25 additions & 0 deletions src/meta_schedule/database/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,29 @@ Optional<IRModule> DatabaseNode::QueryIRModule(const IRModule& mod, const Target
}
}

void DatabaseNode::DumpPruned(Database destination) {
std::unordered_map<Workload, TuningRecord, ObjectPtrHash, ObjectPtrEqual> workload2record;
for (const TuningRecord& record : this->GetAllTuningRecords()) {
if (record->IsValid()) {
auto it = workload2record.find(record->workload);
if (it == workload2record.end()) {
workload2record.insert({record->workload, record});
} else if (SortTuningRecordByMeanRunSecs()(record, it->second)) {
it->second = record;
}
}
}
for (auto& kv : workload2record) {
Workload workload = kv.first;
TuningRecord record = kv.second;
workload = destination->CommitWorkload(workload->mod);
destination->CommitTuningRecord(TuningRecord(/*trace=*/record->trace, /*workload=*/workload,
/*run_secs=*/record->run_secs,
/*target=*/record->target,
/*args_info=*/record->args_info));
}
}

std::vector<Database>* ThreadLocalDatabases() {
static thread_local std::vector<Database> tls;
return &tls;
Expand Down Expand Up @@ -297,6 +320,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule")
.set_body_method<Database>(&DatabaseNode::QuerySchedule);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule")
.set_body_method<Database>(&DatabaseNode::QueryIRModule);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseDumpPruned")
.set_body_method<Database>(&DatabaseNode::DumpPruned);
TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase);

} // namespace meta_schedule
Expand Down
4 changes: 0 additions & 4 deletions src/meta_schedule/database/json_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,6 @@ class JSONDatabaseNode : public DatabaseNode {
}
}
}
if (results.size() < static_cast<size_t>(top_k)) {
LOG(WARNING) << "Returned tuning records less than requested(" << results.size() << " of "
<< top_k << " asked).";
}
return results;
}

Expand Down
4 changes: 0 additions & 4 deletions src/meta_schedule/database/memory_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ class MemoryDatabaseNode : public DatabaseNode {
if (results.size() > static_cast<size_t>(top_k)) {
return {results.begin(), results.begin() + top_k};
} else {
if (results.size() < static_cast<size_t>(top_k)) {
LOG(WARNING) << "Returned tuning records less than requested(" << results.size() << " of "
<< top_k << " asked).";
}
return results;
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/meta_schedule/module_equality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ModuleEqualityStructural : public ModuleEquality {
public:
size_t Hash(IRModule mod) const { return tvm::StructuralHash()(mod); }
bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); }
String GetName() const { return "structural"; }
};

class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
Expand Down Expand Up @@ -72,6 +73,7 @@ class ModuleEqualityIgnoreNDArray : public ModuleEquality {
bool Equal(IRModule lhs, IRModule rhs) const {
return SEqualHandlerIgnoreNDArray().Equal(lhs, rhs, false);
}
String GetName() const { return "ignore-ndarray"; }
};

// The NDArray-ignoring variant of structural equal / hash is used for the module equality
Expand All @@ -93,6 +95,7 @@ class ModuleEqualityAnchorBlock : public ModuleEquality {
}
return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs);
}
String GetName() const { return "anchor-block"; }
};

std::unique_ptr<ModuleEquality> ModuleEquality::Create(const std::string& mod_eq_name) {
Expand Down
1 change: 1 addition & 0 deletions src/meta_schedule/module_equality.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ModuleEquality {

virtual size_t Hash(IRModule mod) const = 0;
virtual bool Equal(IRModule lhs, IRModule rhs) const = 0;
virtual String GetName() const = 0;

/*!
* \brief Create a ModuleEquality instance
Expand Down

0 comments on commit 01324ef

Please sign in to comment.