
Commit

Fix imports and type errors (#12)
* Remove old code for mnemonic classification

* Enhance type hints and error handling 

* Update imports for consistency
chiffonng authored Feb 12, 2025
1 parent bc00af1 commit e3b661a
Showing 7 changed files with 82 additions and 413 deletions.
41 changes: 28 additions & 13 deletions src/app/app.py
@@ -11,10 +11,11 @@
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from unsloth import FastLanguageModel

DESCRIPTION = """
This is a demo for the Google Gemma 2 9B IT model. Use it to generate mnemonics for English words you want to learn and remember.
Input your instructions or start with one of the examples provided. The input supports a subset of markdown formatting such as bold, italics, code, tables. You can also use the following special tokens to customize the mnemonic:
Input your instructions or start with one of the examples provided. The input supports a subset of markdown formatting such as bold, italics, code, tables.
"""

MAX_MAX_NEW_TOKENS = 2048
@@ -23,18 +24,32 @@

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
model.config.sliding_window = 4096
model.eval()

# model_id = "google/gemma-2-9b-it"
model_id = "unsloth/gemma-2-9b-it"

@spaces.GPU(duration=90)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_id,
max_seq_length=MAX_INPUT_TOKEN_LENGTH,
dtype=None,
load_in_4bit=True,
device=device,
cache_dir="models",
)
FastLanguageModel.for_inference(model)
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(
# model_id,
# device_map="auto",
# load_in_4bit=True,
# torch_dtype=torch.bfloat16,
# cache_dir="models",
# )
# model.config.sliding_window = 4096
# model.eval()


# Uncomment to use Hugging Face Spaces GPU
# @spaces.GPU(duration=90)
def generate(
message: str,
chat_history: list[dict],
@@ -159,4 +174,4 @@ def generate(


if __name__ == "__main__":
demo.queue(max_size=20).launch()
demo.queue(max_size=20).launch(share=True)
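
The body of `generate` itself is collapsed in this diff. For orientation, here is a minimal sketch of how the unsloth-loaded model and the imported `TextIteratorStreamer` are typically wired for streaming chat; everything beyond the names visible above (sampling parameters, helper structure, defaults) is an assumption, not the repository's actual implementation.

```python
# Hedged sketch only: not the repository's generate(). It assumes the
# module-level `model` and `tokenizer` created by FastLanguageModel above.
from threading import Thread

from transformers import TextIteratorStreamer


def generate_sketch(message: str, chat_history: list[dict], max_new_tokens: int = 1024):
    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.6,  # assumed sampling values, not taken from the diff
        top_p=0.9,
    )
    # Generate in a background thread so partial text can be yielded as it arrives.
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
```
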
17 changes: 12 additions & 5 deletions src/data/data_loaders.py
@@ -10,10 +10,10 @@

from datasets import Dataset

import utils.constants as c
from utils.aliases import ExtensionsType, PathLike
from utils.common import login_hf_hub
from utils.error_handling import check_dir_path, check_file_path
from src.utils import constants as c
from src.utils.aliases import ExtensionsType, PathLike
from src.utils.common import login_hf_hub
from src.utils.error_handling import check_dir_path, check_file_path

# Set up logging to console
logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ def load_local_dataset(file_path: PathLike, **kwargs) -> "Dataset":
def load_hf_dataset(
repo_id: Optional[str] = None,
to_csv: bool = False,
file_path: PathLike = None,
file_path: "Optional[PathLike]" = None,
**kwargs,
) -> "DatasetDict":
"""Load a dataset from the Hugging Face hub.
@@ -87,6 +87,10 @@ def load_hf_dataset(

if to_csv:
file_path = check_file_path(file_path, new_ok=True, extensions=c.CSV_EXT)
if not file_path:
raise ValueError(
"Invalid file path. Must be a valid path of csv to save the dataset to."
)
dataset.to_csv(file_path)
logger.info(f"Saved dataset to {file_path}.")
else:
@@ -106,4 +110,7 @@ def load_hf_dataset(
if __name__ == "__main__":
# Load a dataset from the Hugging Face hub
mnemonic_dataset: "Dataset" = load_hf_dataset()
test_dataset: "Dataset" = load_hf_dataset(
repo_id="nbalepur/Mnemonic_Test", split="train"
)
logger.info(f"\n\n{mnemonic_dataset}")
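
A quick usage sketch of the tightened `load_hf_dataset` signature; the CSV path below is a placeholder, while the repo id and split come from the `__main__` block above.

```python
# Usage sketch; "data/mnemonics.csv" is a placeholder path, and the exact
# behaviour of check_file_path is assumed from the guard shown in the diff.
from src.data.data_loaders import load_hf_dataset

# file_path is now Optional[PathLike]; with to_csv=True it is validated first,
# and a ValueError is raised early if the checked path comes back empty.
dataset = load_hf_dataset(to_csv=True, file_path="data/mnemonics.csv")

# Extra keyword arguments such as split are forwarded via **kwargs.
test_split = load_hf_dataset(repo_id="nbalepur/Mnemonic_Test", split="train")
```
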
20 changes: 12 additions & 8 deletions src/data/data_processing.py
@@ -16,11 +16,11 @@
if TYPE_CHECKING:
from datasets import Dataset, DatasetDict

import utils.constants as c
from data.data_loaders import load_local_dataset
from utils.aliases import ExtensionsType, PathLike
from utils.common import login_hf_hub
from utils.error_handling import check_dir_path, check_file_path
from src.data.data_loaders import load_local_dataset
from src.utils import constants as c
from src.utils.aliases import ExtensionsType, PathLike
from src.utils.common import login_hf_hub
from src.utils.error_handling import check_dir_path, check_file_path

# Set up logging to console
logger = logging.getLogger(__name__)
@@ -45,6 +45,10 @@ def load_clean_txt_csv_data(dir_path: PathLike) -> pd.DataFrame:
"""
df = pd.DataFrame()
file_paths = check_dir_path(dir_path, extensions=[c.TXT_EXT, c.CSV_EXT])

if not file_paths or isinstance(file_paths, Path):
raise FileNotFoundError(f"No txt or csv files found in '{dir_path}'.")

logger.info(f"Loading txt/csv files from {[str(p) for p in file_paths]}.")

# Read only the first two columns, skipping the first two rows
@@ -126,10 +130,10 @@ def combine_datasets(
Raises:
ValueError: If the provided output format is not 'csv' or 'parquet'.
"""
input_dir = check_dir_path(input_dir)
checked_input_dir = check_dir_path(input_dir)

# Load and combine the datasets
combined_df = load_clean_txt_csv_data(input_dir)
combined_df = load_clean_txt_csv_data(checked_input_dir)

# Clean the data
combined_df.drop_duplicates(subset=[c.TERM_COL], inplace=True, keep="first")
@@ -176,7 +180,7 @@ def train_test_split(dataset: "Dataset", test_size: float = 0.2) -> "DatasetDict

def push_to_hf_hub(
dataset: "Dataset",
repo_id: str = c.HF_DATASET_REPO,
repo_id: str = c.HF_DATASET_NAME,
private: bool = False,
**kwargs,
):
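
Since the rest of this file's diff is truncated here, a hedged sketch of how the split-and-push helpers above fit together; the split name and argument values are assumptions rather than the repository's defaults.

```python
# Hedged sketch based only on the signatures visible in this diff; the "train"
# split name and the argument values are assumptions.
from src.data.data_loaders import load_hf_dataset
from src.data.data_processing import push_to_hf_hub, train_test_split
from src.utils import constants as c

dataset = load_hf_dataset()["train"]  # assumes the hub dataset has a "train" split
splits = train_test_split(dataset, test_size=0.2)

# repo_id now defaults to c.HF_DATASET_NAME (renamed from HF_DATASET_REPO in this commit).
push_to_hf_hub(splits["train"], repo_id=c.HF_DATASET_NAME, private=False)
```
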

0 comments on commit e3b661a
