Merge pull request #12 from openml/simple_dataloaders
New version
Showing 48 changed files with 15,937 additions and 1,070 deletions.
New file — a GitHub Actions workflow that builds the MkDocs documentation and publishes it to GitHub Pages:

@@ -0,0 +1,38 @@
name: ci
on:
  push:
    branches:
      - master
      - main
      - simple_dataloaders
permissions:
  contents: write
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Configure Git Credentials
        run: |
          git config user.name github-actions[bot]
          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
      - uses: actions/setup-python@v5
        with:
          python-version: 3.x
      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
      - uses: actions/cache@v4
        with:
          key: mkdocs-material-${{ env.cache_id }}
          path: .cache
          restore-keys: |
            mkdocs-material-
      - run: pip install mkdocs-material
      - run: pip install mkdocs-material-extensions
      - run: pip install mkdocs-jupyter
      - run: pip install mkdocs-redirects
      - run: pip install mkdocs-autorefs
      - run: pip install mkdocs-awesome-pages-plugin
      - run: pip install mkdocstrings
      - run: pip install mkdocstrings-python
      - run: pip install mknotebooks
      - run: mkdocs gh-deploy --force
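The cache_id step keys the MkDocs cache to the current ISO week number, so the cache rotates weekly rather than growing stale. As a sketch, the shell expression `date --utc '+%V'` computes the same value as this Python:

from datetime import datetime, timezone

# ISO week number (01-53) of the current UTC date -- the value the
# workflow stores in $GITHUB_ENV as cache_id
week = datetime.now(timezone.utc).isocalendar().week
print(f"mkdocs-material-{week:02d}")  # the cache key the workflow builds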
Update to the project's .gitignore:

@@ -7,4 +7,10 @@ openml_pytorch/layers/__pycache__/
venv

model.onnx
model.onnx

build/
dist/
openml_pytorch.egg-info/

datasets/
New file — the openml_pytorch package __init__, which registers the PyTorch extension with openml-python and exposes a helper for attaching a serialized ONNX model to a run:

@@ -0,0 +1,27 @@
from .extension import PytorchExtension
from . import config
from . import layers
from . import trainer
from . import data
from openml.extensions import register_extension
import torch
import io
import onnx


__all__ = ['PytorchExtension', 'config', 'layers', 'add_onnx_to_run', 'trainer']

register_extension(PytorchExtension)


def add_onnx_to_run(run):
    # Wrap the run's _get_file_elements so the last serialized ONNX model is
    # uploaded alongside the run's usual files. Note that `extension` below
    # resolves to the openml_pytorch.extension submodule, which Python binds
    # in the package namespace as a side effect of the
    # `from .extension import PytorchExtension` line above.
    run._old_get_file_elements = run._get_file_elements

    def modified_get_file_elements():
        elements = run._old_get_file_elements()
        elements["onnx_model"] = ("model.onnx", extension.last_models)
        return elements

    run._get_file_elements = modified_get_file_elements
    return run
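A hedged usage sketch for add_onnx_to_run: get_task and run_model_on_task are the standard openml-python entry points, and the model and task id below are illustrative placeholders, not values from this commit.

import torch.nn
import openml
import openml_pytorch

# Placeholders: any torch.nn.Module works, and 31 stands in for a real task id
model = torch.nn.Sequential(torch.nn.Linear(20, 2))
task = openml.tasks.get_task(31)

run = openml.runs.run_model_on_task(model, task)
run = openml_pytorch.add_onnx_to_run(run)  # upload model.onnx with the run
run.publish()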
New file — the openml_pytorch.config module. Every setting and hook is a module-level global with a sensible default:

@@ -0,0 +1,170 @@
import logging

import torch.nn
import torch.nn.functional
import torch.optim

from openml import OpenMLTask, OpenMLClassificationTask, OpenMLRegressionTask

from typing import Any, Callable

# logger is the default logger for the PyTorch extension
logger = logging.getLogger(__name__)  # type: logging.Logger


# _default_criterion_gen returns a criterion based on the task type - regressions use
# torch.nn.SmoothL1Loss while classifications use torch.nn.CrossEntropyLoss
def _default_criterion_gen(task: OpenMLTask) -> torch.nn.Module:
    if isinstance(task, OpenMLRegressionTask):
        return torch.nn.SmoothL1Loss()
    elif isinstance(task, OpenMLClassificationTask):
        return torch.nn.CrossEntropyLoss()
    else:
        raise ValueError(task)


# criterion_gen returns the criterion based on the task type
criterion_gen = _default_criterion_gen  # type: Callable[[OpenMLTask], torch.nn.Module]


# _default_optimizer_gen returns the torch.optim.Adam optimizer for the given model
def _default_optimizer_gen(model: torch.nn.Module, _: OpenMLTask) -> torch.optim.Optimizer:
    return torch.optim.Adam(params=model.parameters())


# optimizer_gen returns the optimizer to be used for a given torch.nn.Module
optimizer_gen = _default_optimizer_gen \
    # type: Callable[[torch.nn.Module, OpenMLTask], torch.optim.Optimizer]


# _default_scheduler_gen returns the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler
# for the given optimizer
def _default_scheduler_gen(optim: torch.optim.Optimizer, _: OpenMLTask) -> Any:
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optim)


# scheduler_gen returns the scheduler to be used for a given torch.optim.Optimizer
scheduler_gen = _default_scheduler_gen  # type: Callable[[torch.optim.Optimizer, OpenMLTask], Any]

# batch_size represents the processing batch size for training
batch_size = 64  # type: int

# epoch_count represents the number of epochs the model should be trained for
epoch_count = 1  # type: int

# filename_col is the name of the column in the dataset that contains the filenames
filename_col = "Filename"

# file_dir is the absolute path of the directory where the image files are stored
file_dir = "images"

# target_mode is the mode of the target column, either "categorical" or "numerical"
target_mode = "categorical"

# image_size is the size of the images that are fed into the model
image_size = 128


# _default_predict turns the outputs into predictions by returning the argmax of the output tensor
# for classification tasks, and by flattening the prediction in case of regression
def _default_predict(output: torch.Tensor, task: OpenMLTask) -> torch.Tensor:
    output_axis = output.dim() - 1
    if isinstance(task, OpenMLClassificationTask):
        output = torch.argmax(output, dim=output_axis)
    elif isinstance(task, OpenMLRegressionTask):
        output = output.view(-1)
    else:
        raise ValueError(task)
    return output


# predict turns the outputs of the model into actual predictions
predict = _default_predict  # type: Callable[[torch.Tensor, OpenMLTask], torch.Tensor]


# _default_predict_proba turns the outputs into probabilities using softmax
def _default_predict_proba(output: torch.Tensor) -> torch.Tensor:
    output_axis = output.dim() - 1
    output = output.softmax(dim=output_axis)
    return output


# predict_proba turns the outputs of the model into probabilities for each class
predict_proba = _default_predict_proba  # type: Callable[[torch.Tensor], torch.Tensor]


# _default_sanitize replaces NaNs with 1e-6
def _default_sanitize(tensor: torch.Tensor) -> torch.Tensor:
    tensor = torch.where(torch.isnan(tensor), torch.ones_like(tensor) * torch.tensor(1e-6), tensor)
    return tensor


# sanitize sanitizes the input data in order to ensure that models can be
# trained safely
sanitize = _default_sanitize  # type: Callable[[torch.Tensor], torch.Tensor]


# _default_retype_labels turns the labels into torch.(cuda)LongTensor if the task is classification
# or torch.(cuda)FloatTensor if the task is regression
def _default_retype_labels(tensor: torch.Tensor, task: OpenMLTask) -> torch.Tensor:
    if isinstance(task, OpenMLClassificationTask):
        return tensor.long()
    elif isinstance(task, OpenMLRegressionTask):
        return tensor.float()
    else:
        raise ValueError(task)


# retype_labels changes the types of the labels in order to ensure type compatibility
retype_labels = _default_retype_labels  # type: Callable[[torch.Tensor, OpenMLTask], torch.Tensor]


# _default_progress_callback reports the current fold, rep, epoch, step and loss for every
# training iteration to the default logger
def _default_progress_callback(fold: int, rep: int, epoch: int,
                               step: int, loss: float, accuracy: float):
    logger.info('[%d, %d, %d, %d] loss: %.4f, accuracy: %.4f' %
                (fold, rep, epoch, step, loss, accuracy))


# progress_callback is called when a training step is finished, in order to
# report the current progress
progress_callback = _default_progress_callback \
    # type: Callable[[int, int, int, int, float, float], None]

# data_augmentation holds an optional transform applied to training images
data_augmentation = None

# perform_validation and validation_split control the optional hold-out split
perform_validation = False
validation_split = 0.1


# get_device picks CUDA if available, then Apple's MPS backend, then the CPU
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    return device


device = get_device()


# _setup only declares the module-level globals, keeping them listed in one place
def _setup():
    global logger
    global criterion_gen
    global optimizer_gen
    global scheduler_gen
    global batch_size
    global epoch_count
    global predict
    global predict_proba
    global sanitize
    global retype_labels
    global progress_callback
    global file_dir
    global filename_col
    global target_mode
    global data_augmentation
    global perform_validation
    global validation_split


_setup()
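Because every hook is a plain module-level global, callers reconfigure the extension by assignment. A minimal sketch, assuming the package is installed as openml_pytorch (the SGD optimizer here is illustrative, not the package default):

import torch.optim
import openml_pytorch.config as config

config.batch_size = 32
config.epoch_count = 5
config.file_dir = "datasets/images"  # where the image files live

# Swap the default Adam optimizer for SGD with momentum
def sgd_optimizer_gen(model, task):
    return torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

config.optimizer_gen = sgd_optimizer_gen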
New file — a torch.utils.data.Dataset over a DataFrame of image filenames, used by the new data loaders:

@@ -0,0 +1,43 @@
import os
from typing import Any
from sklearn import preprocessing
import torch
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor, Lambda


class OpenMLImageDataset(Dataset):
    # transform defaults to the sentinel value 1, meaning "convert the image
    # to float but apply no other transform"; pass a callable to run a real
    # transform first
    def __init__(self, image_size, annotations_df, img_dir, transform=1, target_transform=None):
        self.img_labels = annotations_df
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.image_size = image_size
        self.has_labels = 'encoded_labels' in annotations_df.columns

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])

        try:
            image = read_image(img_path)
        except RuntimeError as error:
            # Fall back to a black placeholder image if the file cannot be read
            image = torch.zeros((3, self.image_size, self.image_size), dtype=torch.uint8)

        if self.transform:
            if not self.transform == 1:
                image = self.transform(image)
            image = image.float()

        if self.has_labels:
            label = self.img_labels.iloc[idx, 1]
            return image, label
        else:
            return image
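A hedged sketch of how this dataset might be wired up. The DataFrame layout (filename column first, 'encoded_labels' second) follows what __getitem__ indexes; the file names and sizes are made up:

import pandas as pd
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize

df = pd.DataFrame({
    "Filename": ["img_0001.jpg", "img_0002.jpg"],  # paths relative to img_dir
    "encoded_labels": [0, 1],                       # integer class labels
})

ds = OpenMLImageDataset(
    image_size=128,
    annotations_df=df,
    img_dir="images",
    transform=Compose([Resize((128, 128))]),  # resize; .float() happens inside
)
loader = DataLoader(ds, batch_size=64, shuffle=True)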