-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpre-processing.py
85 lines (71 loc) · 2.4 KB
/
pre-processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.autoaugment import RandAugment
import albumentations as A
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
class WeatherAugmentation:
    """Randomly apply a subset of weather-style albumentations transforms.

    Four transforms are built once in ``__init__``: sun flare, rain, shadow
    and fog. Each call to :meth:`apply_transforms` independently keeps each
    transform with probability ``selection_prob``. It then runs the kept
    transforms in order through ``A.Compose``.

    Each transform also has its own ``p``. Fog sets ``p=0.3`` explicitly;
    the others use albumentations' default (presumably 0.5, so confirm this
    for the installed version). A selected transform is therefore not
    guaranteed to change the image.
    """

    def __init__(self):
        # Built once and reused on every call; order here is the order in
        # which the selected transforms are applied.
        self.transforms = [
            self.solar_illumination(),
            self.rain_effect(),
            self.shadow_effect(),
            self.fog_effect(),
        ]

    def solar_illumination(self):
        """Return a sun-flare transform anchored near the right edge, upper half."""
        # flare_roi is (x_min, y_min, x_max, y_max) in normalized [0, 1]
        # coordinates. This is a thin strip along the right edge in the top
        # half of the image.
        # RandomSunFlare takes a single ROI, and its source radius is called
        # `src_radius`. The earlier `flare_roi_2` / `flare_radius` keywords
        # are not RandomSunFlare parameters.
        return A.RandomSunFlare(
            flare_roi=(0.9, 0, 1, 0.5),
            src_radius=300,
        )

    def rain_effect(self):
        """Return a 'drizzle' rain transform with brightness scaled by 0.6."""
        # RandomRain has no `rain_drop_size` parameter. Drop thickness is
        # `drop_width`, an integer number of pixels.
        return A.RandomRain(
            drop_width=1,
            rain_type='drizzle',
            brightness_coefficient=0.6,
        )

    def shadow_effect(self):
        """Return a shadow transform casting between 1 and 5 shadows."""
        # shadow_dimension is the number of polygon edges per shadow.
        return A.RandomShadow(
            num_shadows_lower=1,
            num_shadows_upper=5,
            shadow_dimension=6,
        )

    def fog_effect(self):
        """Return a fog transform with intensity drawn from [0.25, 0.8], p=0.3."""
        # In the albumentations 1.x API (the same API family as the
        # num_shadows_lower/upper keywords used in shadow_effect), the
        # fog-intensity range is given as fog_coef_lower / fog_coef_upper.
        # `fog_coef` and `alpha_coef_range` are not RandomFog parameters.
        # `alpha_coef` is a single float, so the library default is kept.
        # NOTE(review): newer releases replace the lower/upper pairs with
        # `fog_coef_range` / `num_shadows_limit`. Update both together when
        # upgrading albumentations.
        return A.RandomFog(
            fog_coef_lower=0.25,
            fog_coef_upper=0.8,
            p=0.3,
        )

    def apply_transforms(self, image, selection_prob=0.6):
        """Apply a randomly chosen subset of the weather transforms to *image*.

        Args:
            image: Array passed to albumentations as ``image=``. This is
                presumably an HxWx3 RGB uint8 array; confirm for each caller.
            selection_prob: Probability that each transform is included in
                this call's pipeline. The default 0.6 matches the previous
                hard-coded value.

        Returns:
            The ``'image'`` entry of the composed pipeline's output.
        """
        # Selection uses the stdlib `random` module, independently of the
        # RNG albumentations uses internally for each transform's own `p`.
        active_transforms = [
            transform for transform in self.transforms
            if random.random() < selection_prob
        ]
        # p=1.0 so the pipeline itself always runs; each member transform
        # still applies its own probability.
        composed_transform = A.Compose(active_transforms, p=1.0)
        return composed_transform(image=image)['image']
if __name__ == "__main__":
    # Demo: augment one image and show it next to the original.
    # Usage: python pre-processing.py [IMAGE_PATH]
    import sys

    # Take the image path from the command line. Without an argument, fall
    # back to the original placeholder path.
    image_path = sys.argv[1] if len(sys.argv) > 1 else 'path_to_your_image.jpg'

    # Load as an RGB numpy array. The context manager closes the file handle
    # as soon as the pixel data has been copied out.
    with Image.open(image_path) as img:
        image = np.array(img.convert('RGB'))

    # Build the augmenter and apply a random subset of weather effects.
    weather_aug = WeatherAugmentation()
    augmented_image = weather_aug.apply_transforms(image)

    # Show original (left) and augmented (right) side by side.
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.title('Original Image')
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.title('Augmented Image')
    plt.imshow(augmented_image)
    plt.axis('off')

    plt.show()