From 6104903f6debfc8e11018dac7b490b87fd9ccb67 Mon Sep 17 00:00:00 2001 From: Lior Dikstein <78903511+lior-dikstein@users.noreply.github.com> Date: Tue, 28 Jan 2025 12:46:13 +0200 Subject: [PATCH] Remove opencv dependency and replace with torch logic (#1348) Co-authored-by: liord --- .../pytorch/optimization_functions/image_initilization.py | 8 +++----- requirements.txt | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py b/model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py index f9d09a65d..fd5ee8a92 100644 --- a/model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +++ b/model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py @@ -15,11 +15,10 @@ from functools import partial from typing import Tuple, Union, List, Callable, Dict -import cv2 from torch import Tensor from torchvision.transforms.transforms import _setup_size import torch -import numpy as np +import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from model_compression_toolkit.data_generation.common.enums import DataInitType @@ -97,9 +96,8 @@ def diverse_sample(size: Tuple[int, ...]) -> Tensor: sample = random_std * torch.randn(size) + random_mean # filtering to make the image a bit smoother - kernel = np.ones((5, 5), np.float32) / 16 - if sample.shape[1] < 500 and sample.shape[2] < 500: - sample = torch.from_numpy(cv2.filter2D(sample.float().detach().cpu().numpy(), -1, kernel)) + kernel = torch.ones(NUM_INPUT_CHANNELS, NUM_INPUT_CHANNELS, 5, 5) / 16 + sample = F.conv2d(sample, kernel, padding=1) return sample.float() def default_data_init_fn( diff --git a/requirements.txt b/requirements.txt index 4c68dd252..31bb5db02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ networkx!=2.8.1 tqdm Pillow numpy<2.0 -opencv-python scikit-image scikit-learn tensorboard