Description
Describe the bug
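The return value of `model.do_classification` does not match what the call site unpacks: the call site expects the third element to be a `(cost, prompt_tokens, cont_tokens)` tuple, but the return statements pass `cost` as a single value.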
Call site:
acc, unknown, (cost, prompt_tokens, cont_tokens), cache = model.do_classification(dataset, task_message=task_dic[args.fairness.dataset], example_prefix=False, dry_run=args.dry_run)

Return statements in do_classification:
except Exception as e:
    print(e)
    if len(cache) == 0:
        return None, None, 0, [], []
    else:
        return acc / len(cache), unknown, cost, cache, predictions
return acc / len(dataset), unknown, cost, cache, predictions
To Reproduce
config.yaml as follows:
model_config:
  model: "/home/nfs03/dongjc/model/Llama-2-7b-chat-hf"
  type: CHAT
  conv_template: "llama-2"
  model_loader: HF
  torch_dtype: FLOAT16
  quant_file: null
  tokenizer_name: "/home/nfs03/dongjc/model/Llama-2-7b-chat-hf"
  trust_remote_code: true
  use_auth_token: true
key: null
dry_run: false
hydra:
  job:
    chdir: false
  callbacks:
    save_job_info:
      _target_: hydra.experimental.callbacks.PickleJobInfoCallback
dt-run ++dry_run=True +fairness=zero_shot_br_0.0.yaml
Expected behavior
DecodingTrust/src/dt/perspectives/fairness/fairness_evaluation.py", line 29, in main
acc, unknown, (cost, prompt_tokens, cont_tokens), cache = model.do_classification(dataset, task_message=task_dic[args.fairness.dataset], example_prefix=False, dry_run=args.dry_run)
TypeError: cannot unpack non-iterable int object
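
For anyone hitting the same traceback, here is a minimal sketch of why the unpacking fails and one possible way to guard the call site. It is not code from DecodingTrust: `reproduce_error`, `normalize_cost`, and `unpack_classification_result` are hypothetical helpers, the sample tuples are made-up stand-ins for what `do_classification` might return, and filling the missing token counts with `0` is only a placeholder assumption.

```python
from typing import Any, Optional, Sequence, Tuple


def reproduce_error() -> str:
    """Unpack a stand-in result whose third element is a plain int."""
    result = (0.5, 0, 0, [])  # made-up values; cost is an int, not a 3-tuple
    try:
        acc, unknown, (cost, prompt_tokens, cont_tokens), cache = result
    except TypeError as exc:
        return str(exc)
    return ""


def normalize_cost(cost: Any) -> Tuple[Any, Any, Any]:
    """Return cost as a (cost, prompt_tokens, cont_tokens) triple.

    Assumption: when only a single value is available, the token
    counts are unknown and are filled with 0.
    """
    if isinstance(cost, (tuple, list)) and len(cost) == 3:
        return (cost[0], cost[1], cost[2])
    return (cost, 0, 0)


def unpack_classification_result(
    result: Sequence[Any],
) -> Tuple[Any, Any, Tuple[Any, Any, Any], Any, Optional[Any]]:
    """Accept a 4- or 5-element result; predictions is None if absent."""
    acc, unknown, cost, cache, *rest = result
    predictions = rest[0] if rest else None
    return acc, unknown, normalize_cost(cost), cache, predictions


if __name__ == "__main__":
    print(reproduce_error())
    print(unpack_classification_result((0.5, 0, 0, [], [])))
    print(unpack_classification_result((0.5, 0, (1.2, 10, 5), [])))
```

A guard like this only hides the mismatch at the call site. The cleaner fix is probably to make `do_classification` and `fairness_evaluation.py` agree on a single return shape.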