Skip to content

Commit

Permalink
Remove opencv dependency and replace with torch logic
Browse files Browse the repository at this point in the history
  • Loading branch information
liord committed Jan 28, 2025
1 parent c3ec981 commit 9e1f019
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
from functools import partial
from typing import Tuple, Union, List, Callable, Dict

import cv2
from torch import Tensor
from torchvision.transforms.transforms import _setup_size
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from model_compression_toolkit.data_generation.common.enums import DataInitType
Expand Down Expand Up @@ -97,9 +96,8 @@ def diverse_sample(size: Tuple[int, ...]) -> Tensor:
sample = random_std * torch.randn(size) + random_mean

# filtering to make the image a bit smoother
kernel = np.ones((5, 5), np.float32) / 16
if sample.shape[1] < 500 and sample.shape[2] < 500:
sample = torch.from_numpy(cv2.filter2D(sample.float().detach().cpu().numpy(), -1, kernel))
kernel = torch.ones(NUM_INPUT_CHANNELS, NUM_INPUT_CHANNELS, 5, 5) / 16
sample = F.conv2d(sample, kernel, padding=1)
return sample.float()

def default_data_init_fn(
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ networkx!=2.8.1
tqdm
Pillow
numpy<2.0
opencv-python
scikit-image
scikit-learn
tensorboard
Expand Down

0 comments on commit 9e1f019

Please sign in to comment.