|
47 | 47 | parser.add_argument("--num-prompts", "-n", type=int, default=8) |
48 | 48 | parser.add_argument("--compare-with-cpu", |
49 | 49 | action=argparse.BooleanOptionalAction) |
| 50 | +parser.add_argument("--trunc_print_len", |
| 51 | + "--trunc-print-len", |
| 52 | + type=int, |
| 53 | + required=False) |
50 | 54 | args = parser.parse_args() |
51 | 55 |
|
| 56 | +trunc = args.trunc_print_len |
| 57 | + |
52 | 58 | max_num_seqs = args.max_num_seqs # defines the max batch size |
53 | 59 | assert args.max_prompt_len < args.max_model_len |
54 | 60 |
|
@@ -144,8 +150,8 @@ def round_up(t): |
144 | 150 | print("Time elapsed for all prompts is %.2f sec" % (time.time() - t0)) |
145 | 151 | print("===============") |
146 | 152 | for output, prompt in zip(outputs, prompts): |
147 | | - generated_text = output.outputs[0].text[:100] |
148 | | - prompt = prompt[:100] |
| 153 | + generated_text = output.outputs[0].text[:trunc] |
| 154 | + prompt = prompt[:trunc] |
149 | 155 | print(f"\nPrompt:\n {prompt!r}") |
150 | 156 | print(f"\nGenerated text (truncated):\n {generated_text!r}\n") |
151 | 157 | print("-----------------------------------") |
@@ -177,9 +183,9 @@ def round_up(t): |
177 | 183 | any_differ = True |
178 | 184 | spyre_output = outputs[i].outputs[0].text |
179 | 185 | print(f"Results for prompt {i} differ on cpu") |
180 | | - print(f"\nPrompt:\n {prompt[:100]!r}") |
181 | | - print(f"\nSpyre generated text:\n {spyre_output[:100]!r}\n") |
182 | | - print(f"\nCPU generated text:\n {hf_generated_text[:100]!r}\n") |
| 186 | + print(f"\nPrompt:\n {prompt[:trunc]!r}") |
| 187 | + print(f"\nSpyre generated text:\n {spyre_output[:trunc]!r}\n") |
| 188 | + print(f"\nCPU generated text:\n {hf_generated_text[:trunc]!r}\n") |
183 | 189 | print("-----------------------------------") |
184 | 190 |
|
185 | 191 | if not any_differ: |
|
0 commit comments