diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..0d74570
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,9 @@
+# this target runs checks on all files
+quality:
+	ruff check ochat
+	ruff format --check ochat
+
+# this target runs checks on all files and potentially modifies some of them
+style:
+	ruff check ochat --fix
+	ruff format ochat
\ No newline at end of file
diff --git a/ochat/config/__init__.py b/ochat/config/__init__.py
index d221042..f90cfa3 100644
--- a/ochat/config/__init__.py
+++ b/ochat/config/__init__.py
@@ -3,10 +3,9 @@
 import torch
 import transformers
 
-from ochat.config.model_config import ModelConfig
-from ochat.config.conversation_template import Message, Conversation, ConversationTemplate
 import ochat.models
-
+from ochat.config.conversation_template import Conversation, ConversationTemplate, Message
+from ochat.config.model_config import ModelConfig
 
 _V3_2_PREFIXES = {
     # OpenAI mapping
diff --git a/ochat/config/conversation_template.py b/ochat/config/conversation_template.py
index 3e0e79b..43f9068 100644
--- a/ochat/config/conversation_template.py
+++ b/ochat/config/conversation_template.py
@@ -1,4 +1,4 @@
-from typing import Optional, Callable, Iterable, List, Dict
+from typing import Callable, Iterable, List, Optional
 
 from pydantic import BaseModel
 
diff --git a/ochat/data/generate_dataset.py b/ochat/data/generate_dataset.py
index f1ccf39..90d6928 100644
--- a/ochat/data/generate_dataset.py
+++ b/ochat/data/generate_dataset.py
@@ -4,17 +4,15 @@ Usage: python -m ochat.data.generate_data --in-file sharegpt_gpt4.jsonl --tokenizer-name HF_REPO_NAME --out-dir .
 """
 
-from typing import Optional
 import argparse
 import os
 import random
 
-import ray
 import orjson
 import pyarrow
+import ray
 from pyarrow import parquet
 
-
 PAD_TOKEN_ID = 0
@@ -106,11 +104,11 @@ def generate_split(model_type: str, model_path: str, conversations: list, split_
         pyarrow.field("total_length", pyarrow.int32()),
         pyarrow.field("num_seqs", pyarrow.float32()),
 
-        pyarrow.field(f"seqlens", pyarrow.list_(pyarrow.int32())),
-        pyarrow.field(f"nz_input_ids", pyarrow.list_(pyarrow.int32())),
-        pyarrow.field(f"nz_position_ids", pyarrow.list_(pyarrow.int32())),
-        pyarrow.field(f"nz_shifted_label_ids", pyarrow.list_(pyarrow.int32())),
-        pyarrow.field(f"nz_shifted_loss_weights", pyarrow.list_(pyarrow.float32()))
+        pyarrow.field("seqlens", pyarrow.list_(pyarrow.int32())),
+        pyarrow.field("nz_input_ids", pyarrow.list_(pyarrow.int32())),
+        pyarrow.field("nz_position_ids", pyarrow.list_(pyarrow.int32())),
+        pyarrow.field("nz_shifted_label_ids", pyarrow.list_(pyarrow.int32())),
+        pyarrow.field("nz_shifted_loss_weights", pyarrow.list_(pyarrow.float32()))
     ]
 
     schema = pyarrow.schema(schema, metadata={"metadata_json": orjson.dumps(metadata)})
diff --git a/ochat/evaluation/conv_eval.py b/ochat/evaluation/conv_eval.py
index df436b8..8d35d56 100644
--- a/ochat/evaluation/conv_eval.py
+++ b/ochat/evaluation/conv_eval.py
@@ -1,14 +1,13 @@
-from typing import OrderedDict
-import signal
+import argparse
 import os
-import json
+import re
+import signal
 import subprocess
-import argparse
 import time
-import requests
-import re
-import coolname
+from typing import OrderedDict
 
+import coolname
+import requests
 
 MAX_CONTEXT = 4096
diff --git a/ochat/evaluation/convert_to_evalplus.py b/ochat/evaluation/convert_to_evalplus.py
index 30eae8c..9421aa6 100644
--- a/ochat/evaluation/convert_to_evalplus.py
+++ b/ochat/evaluation/convert_to_evalplus.py
@@ -1,9 +1,9 @@
 import argparse
 import os
-import orjson
-
 from glob import glob
 
+import orjson
+
 
 def convert_to_evalplus(results_path: str, output_path: str):
     os.makedirs(output_path, exist_ok=True)
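Note: most hunks in this patch are mechanical ruff fixes — `I` (isort) import sorting plus, in the `generate_dataset.py` hunk above, `F541` removals of the `f` prefix from f-strings that contain no placeholders. As a reference for the convention the `I` rules enforce (stdlib block first, then third-party, alphabetized, straight imports ahead of `from`-imports), this is the resulting header of `ochat/data/generate_dataset.py`:

```python
import argparse
import os
import random

import orjson
import pyarrow
import ray
from pyarrow import parquet

PAD_TOKEN_ID = 0
```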
diff --git a/ochat/evaluation/grading/math_grader.py b/ochat/evaluation/grading/math_grader.py
index 3d6310c..4f8d614 100644
--- a/ochat/evaluation/grading/math_grader.py
+++ b/ochat/evaluation/grading/math_grader.py
@@ -4,16 +4,16 @@ Call grade_answer(given_answer: str, ground_truth: str).
 """
 
 import re
+
 import sympy
 from pylatexenc import latex2text
 from sympy.parsing import sympy_parser
 
 from ochat.evaluation.grading import math_normalize
 
-
 # sympy might hang -- we don't care about trying to be lenient in these cases
 BAD_SUBSTRINGS = ["^{", "^("]
-BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"]
+BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"]
 
 TUPLE_CHARS = "()[]"
@@ -93,7 +93,7 @@ def _inject_implicit_mixed_number(step: str):
 
 def _strip_properly_formatted_commas(expr: str):
     # We want to be careful because we don't want to strip tuple commas
-    p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
+    p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)")
     while True:
         next_expr = p1.sub("\\1\\3\\4", expr)
         if next_expr == expr:
@@ -108,7 +108,7 @@ def _normalize(expr: str) -> str:
         return None
 
     # Remove enclosing `\text{}`.
-    m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr)
+    m = re.search("^\\\\text\\{(?P<text>.+?)\\}$", expr)
     if m is not None:
         expr = m.group("text")
@@ -141,8 +141,8 @@ def _normalize(expr: str) -> str:
         "inch",
         "yard",
     ]:
-        expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
-    expr = re.sub(f"\^ *\\\\circ", "", expr)
+        expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
+    expr = re.sub("\\^ *\\\\circ", "", expr)
 
     if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
         expr = expr[1:-1]
diff --git a/ochat/evaluation/grading/math_normalize.py b/ochat/evaluation/grading/math_normalize.py
index 4bb753b..9f57d8e 100644
--- a/ochat/evaluation/grading/math_normalize.py
+++ b/ochat/evaluation/grading/math_normalize.py
@@ -11,7 +11,7 @@ def normalize_answer(answer: Optional[str]) -> Optional[str]:
     answer = answer.strip()
     try:
         # Remove enclosing `\text{}`.
-        m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
+        m = re.search("^\\\\text\\{(?P<text>.+?)\\}$", answer)
         if m is not None:
             answer = m.group("text").strip()
         return _strip_string(answer)
@@ -126,7 +126,7 @@ def _strip_string(string):
 
     # remove percentage
     string = string.replace("\\%", "")
-    string = string.replace("\%", "")
+    string = string.replace(r"\%", "")
 
     # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
     string = string.replace(" .", " 0.")
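The grading changes above are `W605` fixes: sequences like `\d` or `\^` in a non-raw string literal are invalid escape sequences (a `DeprecationWarning` that became a `SyntaxWarning` in Python 3.12), and the raw-string form compiles to an identical pattern, so behavior is unchanged. A minimal self-check using the thousands-separator pattern from `_strip_properly_formatted_commas` (the sample input is illustrative):

```python
import re

# Identical pattern to the old "(\d)(,)(\d\d\d)($|\D)", minus the invalid escapes.
p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)")

# One substitution pass removes one well-formed comma, which is why the
# source function loops until the string stops changing.
assert p1.sub(r"\1\3\4", "1,234,567") == "1234,567"
assert p1.sub(r"\1\3\4", "1234,567") == "1234567"
```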
diff --git a/ochat/evaluation/match_answer.py b/ochat/evaluation/match_answer.py
index 00ddef7..ce3b0db 100644
--- a/ochat/evaluation/match_answer.py
+++ b/ochat/evaluation/match_answer.py
@@ -1,5 +1,5 @@
-import re
 import ast
+import re
 
 from ochat.evaluation.grading.math_grader import grade_answer
 
@@ -50,7 +50,7 @@ def _last_boxed_only_string(string):
                 break
             i += 1
 
-    
+
     if left_brace_idx is None or right_brace_idx is None:
         return None
@@ -119,7 +119,7 @@ def fs_cothub_gsm8k_match_answer(task_data, response):
     # CoT hub match answer for GSM8k, match last numeric value
     # https://github.com/FranxYao/chain-of-thought-hub/blob/main/gsm8k/gpt3.5turbo_gsm8k_complex.ipynb
 
-    pattern = '\d*\.?\d+'
+    pattern = r'\d*\.?\d+'
     pred = re.findall(pattern, response)
     if len(pred) >= 1:
         return True, pred[-1]
@@ -135,7 +135,7 @@ def fs_cothub_mmlu_match_answer(task_data, response):
         return False, "(C)"
     else:
         ans = ans_line[-1].strip()
-    
+
         options = ['(A)', '(B)', '(C)', '(D)']
         for option in options:
             if option in ans:
@@ -174,12 +174,12 @@ def _try_match(content, prefix, entrypoint):
     include_prefix = humaneval_task['prompt'].split('def')[0].strip() + "\n\n"
 
     result = _try_match(response, include_prefix, humaneval_task["entry_point"])
-    if result: 
+    if result:
         return True, {"task_id": humaneval_task["task_id"], "completion": result}
 
     # If fail then match with function signature
     result = _try_match(response, humaneval_task["prompt"], humaneval_task["entry_point"])
-    if result: 
+    if result:
         return True, {"task_id": humaneval_task["task_id"], "completion": result}
 
     return False, {"task_id": humaneval_task["task_id"], "completion": response}
diff --git a/ochat/evaluation/run_eval.py b/ochat/evaluation/run_eval.py
index 633b90c..8ed4f82 100644
--- a/ochat/evaluation/run_eval.py
+++ b/ochat/evaluation/run_eval.py
@@ -1,20 +1,19 @@
-from typing import Optional
 import argparse
-import os
 import asyncio
+import os
 from glob import glob
+from typing import Optional
 
-import orjson
 import openai
-from tqdm import tqdm
+import orjson
 from openai.error import RateLimitError, ServiceUnavailableError
-from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
-from vllm import LLM, SamplingParams
-
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
+from tqdm import tqdm
 from transformers.utils.hub import cached_file
+from vllm import LLM, SamplingParams
 
-from ochat.evaluation.match_answer import MATCH_ANSWER_FUNCTION
 from ochat.config import MODEL_CONFIG_MAP
+from ochat.evaluation.match_answer import MATCH_ANSWER_FUNCTION
 
 
 @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(20), retry=retry_if_exception_type((RateLimitError, ServiceUnavailableError, )))
@@ -46,7 +45,7 @@ async def chat_completion_thread(model, progress_bar, queue):
                 e = e._exception
 
             print(type(e), str(e))
-        
+
         # Progress
         progress_bar.update()
diff --git a/ochat/evaluation/view_results.py b/ochat/evaluation/view_results.py
index 0a8ce78..b57333c 100644
--- a/ochat/evaluation/view_results.py
+++ b/ochat/evaluation/view_results.py
@@ -1,10 +1,10 @@
 import argparse
 import os
+from glob import glob
 from pathlib import Path
 
 import orjson
 import pandas as pd
-from glob import glob
 
 
 def view_results(result_path: str):
diff --git a/ochat/experimental/generate_dataset_old.py b/ochat/experimental/generate_dataset_old.py
index 5530f32..a2be712 100644
--- a/ochat/experimental/generate_dataset_old.py
+++ b/ochat/experimental/generate_dataset_old.py
@@ -4,17 +4,17 @@ Usage: python -m ochat.data.generate_data --in-file sharegpt_gpt4.json --tokenizer-name HF_REPO_NAME --out-dir .
 """
 
-from typing import Optional
-from dataclasses import dataclass
 import argparse
 import json
 import os
 import random
+from dataclasses import dataclass
+from typing import Optional
 
 import numpy as np
 import transformers
-from transformers.trainer_pt_utils import LabelSmoother
 from ray.util.multiprocessing import Pool
+from transformers.trainer_pt_utils import LabelSmoother
 
 
 @dataclass
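The `@retry` decorator visible in the `run_eval.py` hunk above is a pattern worth noting: tenacity retries only the listed transient API errors, with randomized exponential backoff. A minimal sketch, assuming the pre-1.0 `openai` SDK that the `openai.error` import implies; `chat_completion_with_backoff` is a hypothetical stand-in, not the repository's function:

```python
import openai
from openai.error import RateLimitError, ServiceUnavailableError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential


@retry(
    wait=wait_random_exponential(min=1, max=60),  # sleep 1-60 s, growing exponentially with jitter
    stop=stop_after_attempt(20),                  # give up after 20 attempts
    retry=retry_if_exception_type((RateLimitError, ServiceUnavailableError)),
)
async def chat_completion_with_backoff(**kwargs):
    # Only RateLimitError / ServiceUnavailableError trigger a retry;
    # other exceptions propagate immediately.
    return await openai.ChatCompletion.acreate(**kwargs)
```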
diff --git a/ochat/models/unpadded_llama.py b/ochat/models/unpadded_llama.py
index f7c236a..9636141 100644
--- a/ochat/models/unpadded_llama.py
+++ b/ochat/models/unpadded_llama.py
@@ -24,16 +24,15 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
 from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.utils import logging
 
 try:
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func
     from flash_attn.bert_padding import pad_input
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func
 except ImportError:
     print ("FlashAttention not found. Install it if you need to train models.")
@@ -313,7 +312,7 @@ def forward(
         else:
             nz_hidden_states = decoder_layer(
                 cos_sin,
-                
+
                 nz_hidden_states,
                 nz_position_ids,
                 cu_seqlens,
@@ -355,7 +354,7 @@ def set_decoder(self, decoder):
 
     def get_decoder(self):
         return self.model
-    
+
     def forward(
         self,
         # Unpadded inputs
diff --git a/ochat/models/unpadded_mistral.py b/ochat/models/unpadded_mistral.py
index 2abb216..4b3cdd0 100644
--- a/ochat/models/unpadded_mistral.py
+++ b/ochat/models/unpadded_mistral.py
@@ -24,16 +24,15 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
 from transformers.models.mistral.configuration_mistral import MistralConfig
+from transformers.utils import logging
 
 try:
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func
     from flash_attn.bert_padding import pad_input
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func
 except ImportError:
     print ("FlashAttention not found. Install it if you need to train models.")
@@ -313,7 +312,7 @@ def forward(
         else:
             nz_hidden_states = decoder_layer(
                 cos_sin,
-                
+
                 nz_hidden_states,
                 nz_position_ids,
                 cu_seqlens,
@@ -352,7 +351,7 @@ def set_decoder(self, decoder):
 
     def get_decoder(self):
         return self.model
-    
+
     def forward(
         self,
         # Unpadded inputs
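Both model files keep FlashAttention optional with the same guard, since it is only needed for training. A generic sketch of the pattern — the `None` sentinels are an illustrative addition, not the repository's code, which only prints a warning:

```python
try:
    from flash_attn.bert_padding import pad_input
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    # Inference-only installs can proceed; training code paths can check for None.
    pad_input = flash_attn_varlen_func = None
    print("FlashAttention not found. Install it if you need to train models.")
```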
diff --git a/ochat/scripts/hf_add_tokens.py b/ochat/scripts/hf_add_tokens.py
index 3d27f98..231da46 100644
--- a/ochat/scripts/hf_add_tokens.py
+++ b/ochat/scripts/hf_add_tokens.py
@@ -1,7 +1,7 @@
 import argparse
 
-import transformers
 import torch
+import transformers
 
 
 def add_tokens_to_embedding(added_special_tokens, embedding):
diff --git a/ochat/serving/async_tokenizer.py b/ochat/serving/async_tokenizer.py
index adde68b..df22ffa 100644
--- a/ochat/serving/async_tokenizer.py
+++ b/ochat/serving/async_tokenizer.py
@@ -1,6 +1,6 @@
 import ray
 
-from ochat.config import Message, Conversation
+from ochat.config import Conversation, Message
 
 
 @ray.remote
@@ -38,7 +38,7 @@ def tokenize(self, messages, condition, enable_sys_prompt=False):
 
         tokens, _ = self.conv_template.tokenize_conversations([Conversation(items=items, system=system_message, condition=condition)], inference=True)
         return tokens[0]
-    
+
     def get_eot_tokens(self):
         assert len(self.conv_template.eot_tokens_) == 1
diff --git a/ochat/serving/openai_api_protocol.py b/ochat/serving/openai_api_protocol.py
index 3be884c..65aec4c 100644
--- a/ochat/serving/openai_api_protocol.py
+++ b/ochat/serving/openai_api_protocol.py
@@ -4,7 +4,6 @@
 from typing import Dict, List, Literal, Optional, Union
 
 from pydantic import BaseModel, Field
-
 from vllm.utils import random_uuid
diff --git a/ochat/serving/openai_api_server.py b/ochat/serving/openai_api_server.py
index 413db30..4a2009b 100644
--- a/ochat/serving/openai_api_server.py
+++ b/ochat/serving/openai_api_server.py
@@ -3,24 +3,23 @@
 
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
-import time
 import logging
+import time
+from dataclasses import dataclass
+from http import HTTPStatus
 from logging.handlers import RotatingFileHandler
 from typing import AsyncGenerator, Optional
-from dataclasses import dataclass
 
 import fastapi
+import ray
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
-
-import uvicorn
-import ray
-
+from transformers.utils.hub import cached_file
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.outputs import RequestOutput
@@ -28,10 +27,7 @@
 from vllm.utils import random_uuid
 
 from ochat.config import MODEL_CONFIG_MAP
-from ochat.serving import openai_api_protocol, async_tokenizer
-
-from transformers.utils.hub import cached_file
-
+from ochat.serving import async_tokenizer, openai_api_protocol
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
diff --git a/ochat/training_deepspeed/multipack_sampler.py b/ochat/training_deepspeed/multipack_sampler.py
index af372f1..6a3d57c 100644
--- a/ochat/training_deepspeed/multipack_sampler.py
+++ b/ochat/training_deepspeed/multipack_sampler.py
@@ -1,5 +1,5 @@
-import numpy as np
 import numba
+import numpy as np
 
 
 @numba.njit
@@ -95,7 +95,7 @@ def allocate(lengths: np.ndarray, numseqs: np.ndarray, lengths_cumsum: np.ndarra
 class MultipackDistributedSampler:
     """Unpadded data loading using Multipack.
        Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard."""
-    
+
     def __init__(
         self,
         lengths: np.ndarray,
@@ -139,7 +139,7 @@ def generate_batches(self, epoch, set_stats=False):
                                                        rank=self.rank,
                                                        c=self.batch_max_length,
                                                        n=self.num_replicas)
-        
+
         curseqs = [np.sum(numseqs[batch]) for batch in batches]
         batches = [indices[batch] for batch in batches]
@@ -149,7 +149,7 @@ def generate_batches(self, epoch, set_stats=False):
             self.eff_total_slots += total_slots
 
         return batches, totseqs, curseqs
-    
+
     def iter(self, epoch):
         all_batches, all_totseqs, all_curseqs = self.generate_batches(epoch, set_stats=True)
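The `MultipackDistributedSampler` docstring above frames batch packing as identical-machines scheduling. For intuition only — this is not the repository's numba-compiled `allocate()` — a first-fit-decreasing sketch of the same packing idea, placing sequences (longest first) into batches capped at `batch_max_length` tokens:

```python
import numpy as np


def first_fit_decreasing(lengths: np.ndarray, batch_max_length: int) -> list:
    """Greedy packing: put each sequence into the first batch with room."""
    batches, loads = [], []
    for idx in np.argsort(-lengths):       # longest sequences first
        for b, load in enumerate(loads):
            if load + lengths[idx] <= batch_max_length:
                batches[b].append(int(idx))
                loads[b] += int(lengths[idx])
                break
        else:                              # no existing batch had room: open a new one
            batches.append([int(idx)])
            loads.append(int(lengths[idx]))
    return batches
```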
diff --git a/ochat/training_deepspeed/openchat_dataset.py b/ochat/training_deepspeed/openchat_dataset.py
index dae3e78..275d502 100644
--- a/ochat/training_deepspeed/openchat_dataset.py
+++ b/ochat/training_deepspeed/openchat_dataset.py
@@ -1,9 +1,8 @@
-import torch
 import numpy as np
-from torch.utils.data import IterableDataset, get_worker_info
-
-import pyarrow.parquet as pq
 import orjson
+import pyarrow.parquet as pq
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
 
 from ochat.training_deepspeed.multipack_sampler import MultipackDistributedSampler
diff --git a/ochat/training_deepspeed/train.py b/ochat/training_deepspeed/train.py
index c0caa45..6e95844 100644
--- a/ochat/training_deepspeed/train.py
+++ b/ochat/training_deepspeed/train.py
@@ -1,16 +1,15 @@
 import argparse
-import os
-import math
 import json
+import math
+import os
 from functools import partial
 
+import numpy as np
 import torch
 import torch.distributed as dist
-from torch.utils.data import DataLoader
-
 import tqdm
 import wandb
-import numpy as np
+from torch.utils.data import DataLoader
 
 from ochat.config import MODEL_CONFIG_MAP
 from ochat.training_deepspeed.openchat_dataset import OpenchatDataset
@@ -151,7 +150,7 @@ def save_openchat_metadata(args, epoch, save_path):
 def calculate_auto_lr(lr, batch_max_len, model_type, train_dataset):
     if lr is not None:
         return lr
-    
+
     # Llama hyperparameters
     # FIXME: Only 7B/13B is supported
     base_lr = 3e-4
diff --git a/pyproject.toml b/pyproject.toml
index b6a3119..1acd764 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,3 +48,27 @@ exclude = ["assets*", "ochat/experimental*"]
 exclude = ["assets*", "ochat/experimental*"]
 
 [tool.setuptools_scm]
+
+[tool.ruff]
+# Never enforce `E501` (line length violations).
+ignore = ["C901", "E501", "E741", "F402", "F823"]
+select = ["C", "E", "F", "I", "W"]
+line-length = 118
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+"src/diffusers/utils/dummy_*.py" = ["F401"]
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
\ No newline at end of file
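One configuration note: the top-level `select`/`ignore`/`per-file-ignores` keys above use the layout accepted by ruff releases of this era. Newer ruff (>= 0.2) expects lint settings under `[tool.ruff.lint]` and emits a deprecation warning for the top-level form; an equivalent sketch, should the config need updating later:

```toml
[tool.ruff]
line-length = 118

[tool.ruff.lint]
select = ["C", "E", "F", "I", "W"]
ignore = ["C901", "E501", "E741", "F402", "F823"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
```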