Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,38 @@ bool common_params_speculative::has_stage_type(common_speculative_type stage_typ
});
}

void common_params_speculative::remove_stage_type(common_speculative_type stage_type) {
stages.erase(std::remove_if(stages.begin(), stages.end(), [stage_type](const common_speculative_stage_params & stage) {
return stage.type == stage_type;
}), stages.end());

if (type == stage_type) {
const auto resolved = get_resolved_stages();
type = resolved.empty() ? COMMON_SPECULATIVE_TYPE_NONE : resolved.front().type;
}
}

bool common_params_speculative::has_composite_stage_chain() const {
return get_resolved_stages().size() > 1;
}

bool common_params_speculative::needs_dft_model() const {
return has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT) ||
(has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && has_dft());
}

void common_params_speculative::clear_dft() {
if (model_dft != nullptr) {
llama_free_model(model_dft);
model_dft = nullptr;
}

model.clear();
params.clear();
mparams_dft.path.clear();
cparams_dft = llama_context_default_params();
}

int32_t common_params_speculative::get_max_stage_n_max() const {
const auto resolved = get_resolved_stages();
if (resolved.empty()) {
Expand Down
3 changes: 3 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,10 @@ struct common_params_speculative {
common_params_speculative with_stage_overrides(const common_speculative_stage_params & stage) const;
bool has_stage_chain() const;
bool has_stage_type(common_speculative_type stage_type) const;
void remove_stage_type(common_speculative_type stage_type);
bool has_composite_stage_chain() const;
bool needs_dft_model() const;
void clear_dft();
int32_t get_max_stage_n_max() const;
int32_t get_min_usable_stage_n_min() const;

Expand Down
Loading