@@ -17,227 +17,10 @@ def __getitem__(self, key):
def __setitem__(self, key, item):
    """Support ``obj[key] = item`` by storing ``item`` as the attribute ``key``."""
    # Item assignment is attribute assignment under a different syntax.
    setattr(self, key, item)
    return None
1919
20-
def cubic_kernel(x, a: float = -0.75):
    """Evaluate a piecewise cubic interpolation kernel at the offsets ``x``.

    Args:
        x: Tensor of signed distances between a sample position and a tap.
        a: Kernel coefficient used in both polynomial pieces.

    Returns:
        Tensor shaped like ``x``: the first polynomial where ``|x| <= 1``,
        the second where ``1 < |x| < 2``, and zero elsewhere.
    """
    dist = x.abs()
    dist_sq, dist_cu = dist ** 2, dist ** 3

    # Polynomial used for |x| <= 1.
    near = (a + 2) * dist_cu - (a + 3) * dist_sq + 1
    # Polynomial used for 1 < |x| < 2.
    far = a * dist_cu - 5 * a * dist_sq + 8 * a * dist - 4 * a

    outside = torch.zeros_like(x)
    far_or_zero = torch.where(dist < 2, far, outside)
    return torch.where(dist <= 1, near, far_or_zero)
30-
def get_indices_weights(in_size, out_size, scale):
    """Build the 4-tap gather indices and weights for a 1-D cubic resize.

    Args:
        in_size: Length of the source axis; indices are clamped to
            ``[0, in_size - 1]``.
        out_size: Length of the destination axis.
        scale: Ratio ``out_size / in_size`` used to map destination
            positions back onto the source axis.

    Returns:
        ``(indices, weights)``, both of shape ``(out_size, 4)``. ``indices``
        is an integer tensor of source positions and ``weights`` is a
        float32 tensor whose rows sum to 1. Both tensors are created on the
        default (CPU) device.
    """
    # Half-pixel mapping: destination pixel centres projected onto the source.
    dst = torch.arange(out_size, dtype=torch.float32)
    src = (dst + 0.5) / scale - 0.5

    # Four neighbouring taps around each projected position: floor-1 .. floor+2.
    base = src.floor().long()
    taps = base.unsqueeze(1) + torch.arange(-1, 3)

    # Weights use the unclamped tap positions; only the gather indices are
    # clamped, so out-of-range taps reuse the nearest edge sample.
    weights = cubic_kernel(src.unsqueeze(1) - taps)
    weights = weights / weights.sum(dim=1, keepdim=True)

    indices = taps.clamp(0, in_size - 1)
    return indices, weights
46-
def resize_cubic_1d(x, out_size, dim):
    """Resize one spatial axis of a 4-D tensor with 4-tap cubic interpolation.

    Args:
        x: Tensor of shape ``(b, c, h, w)``.
        out_size: Target length of the resized axis.
        dim: ``2`` resizes the height axis; any other value resizes the
            width axis.

    Returns:
        Tensor of shape ``(b, c, out_size, w)`` when ``dim == 2``, otherwise
        ``(b, c, h, out_size)``.
    """
    b, c, h, w = x.shape
    along_height = dim == 2
    in_size = h if along_height else w
    scale = out_size / in_size

    indices, weights = get_indices_weights(in_size, out_size, scale)
    # The helper builds its tensors on the default device; move them next to
    # ``x`` so the weighted sum below does not mix devices.
    indices = indices.to(x.device)
    weights = weights.to(x.device)

    # Flatten to (rows, in_size) with the resized axis last.
    if along_height:
        rows = x.permute(0, 1, 3, 2).reshape(-1, h)
    else:
        rows = x.reshape(-1, w)

    # (rows, out_size, 4) gathered taps, reduced over the tap axis.
    gathered = rows[:, indices]
    out = (gathered * weights.unsqueeze(0)).sum(dim=2)

    if along_height:
        return out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
    return out.reshape(b, c, h, out_size)
69-
def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """Resize an image with separable cubic interpolation in pure PyTorch.

    Intended to match OpenCV's INTER_CUBIC.

    Args:
        img: Channels-last tensor, either 3-D or 4-D with a leading batch
            axis. A 3-D input gets a batch axis of size 1.
        size: ``(out_h, out_w)`` target size.

    Returns:
        4-D tensor with the channel axis moved to position 1, i.e. the
        result is not permuted back to channels-last.
    """
    batch = img.unsqueeze(0) if img.ndim == 3 else img

    # Move the last axis to position 1 for the per-axis resize helpers.
    batch = batch.permute(0, 3, 1, 2)

    target_h, target_w = size
    resized = resize_cubic_1d(batch, target_h, dim=2)
    return resize_cubic_1d(resized, target_w, dim=3)
85-
def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """Resize an image by area-weighted averaging in pure PyTorch.

    Intended as a counterpart of OpenCV's INTER_AREA. Each output pixel is
    the average of the input pixels its footprint covers, weighted by the
    overlap area.

    Args:
        img: 3-D or 4-D tensor. A 4-D tensor is read as ``(B, C, H, W)``.
            A 3-D tensor is read as ``(C, H, W)`` when its first dimension
            is at most 4 and as ``(H, W, C)`` otherwise. This heuristic
            misreads channels-last images that are 4 or fewer rows tall;
            pass such images as 4-D.
        size: ``(out_h, out_w)`` target size.

    Returns:
        float32 tensor in the same layout and rank as the input, on the
        input's device.

    Raises:
        ValueError: If ``img`` is not 3-D or 4-D, or if any input or output
            spatial size is not positive.
    """
    is_hwc = False
    squeeze_batch = False

    if img.ndim == 3:
        if img.shape[0] <= 4:
            # Treated as CHW: add a batch axis and remove it again on return.
            squeeze_batch = True
            img = img.unsqueeze(0)
        else:
            is_hwc = True
            img = img.permute(2, 0, 1).unsqueeze(0)
    elif img.ndim != 4:
        raise ValueError("Expected image with 3 or 4 dims.")

    B, C, H, W = img.shape
    out_h, out_w = size
    if min(H, W, out_h, out_w) <= 0:
        raise ValueError("Input and output spatial sizes must be positive.")

    device = img.device

    # Per-axis taps: for every window offset, the source index and overlap
    # weight of each output pixel along that axis.
    y_taps = _area_taps(H, out_h, device)
    x_taps = _area_taps(W, out_w, device)

    # Accumulate weighted pixels and total weight over all window offsets.
    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)

    for y_idx, y_weight in y_taps:
        y_grid = y_idx.unsqueeze(1).expand(out_h, out_w)
        for x_idx, x_weight in x_taps:
            x_grid = x_idx.unsqueeze(0).expand(out_h, out_w)
            weight = y_weight.unsqueeze(1) * x_weight.unsqueeze(0)

            output += img[:, :, y_grid, x_grid] * weight
            area += weight

    # Normalize by the covered area.
    output /= area

    if is_hwc:
        return output[0].permute(1, 2, 0)
    if squeeze_batch:
        return output[0]
    return output


def _area_taps(in_size, out_size, device):
    """Return ``(clamped_index, overlap_weight)`` pairs, one per window offset, for one axis."""
    scale = in_size / out_size

    # Footprint boundaries of every output pixel on the input axis.
    start = torch.arange(out_size, device=device).float() * scale
    end = start + scale
    first = torch.floor(start).long()
    last = torch.ceil(end).long()

    taps = []
    for offset in range(int(torch.max(last - first).item())):
        idx = first + offset
        lower = idx.float()
        overlap = (torch.min(end, lower + 1.0) - torch.max(start, lower)).clamp(min=0)
        # Offsets past the last input sample contribute nothing. Weighting
        # the clamped index instead would count the edge sample twice.
        overlap = torch.where(idx < in_size, overlap, torch.zeros_like(overlap))
        taps.append((idx.clamp(0, in_size - 1), overlap))
    return taps
166-
def recenter(image, border_ratio: float = 0.2):
    """Crop an image to its mask, centre it on a square canvas, and flatten it onto white.

    Args:
        image: ``(H, W, C)`` tensor with values in the 0-255 range. When the
            last dimension is 4, channel 3 is used as the mask; otherwise a
            fully opaque channel is appended.
        border_ratio: Fraction of the canvas side kept free around the
            object; the longer side of the crop is scaled to
            ``int(size * (1 - border_ratio))``.

    Returns:
        uint8 tensor of shape ``(size, size, 3)`` with ``size = max(H, W)``,
        composited over a white background using the resized mask as alpha.

    Raises:
        ValueError: If the mask has no non-zero pixel, or its bounding box
            has zero height or width.
    """
    if image.shape[-1] == 4:
        mask = image[..., 3]
    else:
        mask = torch.ones_like(image[..., 0:1]) * 255
        image = torch.concatenate([image, mask], axis=-1)
        mask = mask[..., 0]

    H, W, C = image.shape

    size = max(H, W)
    result = torch.zeros((size, size, C), dtype=torch.uint8)

    # as_tuple gives one index tensor per axis (rows first), like numpy.nonzero.
    rows, cols = torch.nonzero(mask, as_tuple=True)
    if rows.numel() == 0:
        # .min()/.max() on an empty tensor would raise an opaque RuntimeError.
        raise ValueError('input image is empty')

    # Plain Python ints keep the size arithmetic below in float64.
    x_min, x_max = int(rows.min()), int(rows.max())
    y_min, y_max = int(cols.min()), int(cols.max())

    # NOTE(review): extents are max - min, so the crop excludes the last
    # occupied row/column. Kept as-is to preserve existing output.
    h = x_max - x_min
    w = y_max - y_min

    if h == 0 or w == 0:
        raise ValueError('input image is empty')

    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)

    # At least one pixel per side, so a thin crop cannot ask for a zero-sized resize.
    h2 = max(1, int(h * scale))
    w2 = max(1, int(w * scale))

    x2_min = (size - h2) // 2
    x2_max = x2_min + h2

    y2_min = (size - w2) // 2
    y2_max = y2_min + w2

    # Pass the crop as explicit (1, C, h, w). resize_area guesses the layout
    # of 3-D input from its first dimension, which would misread a crop that
    # is 4 or fewer rows tall as channels-first.
    crop = image[x_min:x_max, y_min:y_max].permute(2, 0, 1).unsqueeze(0)
    resized = resize_area(crop, (h2, w2))
    # Accept the result with or without a leading batch axis.
    resized = resized.reshape(-1, h2, w2).permute(1, 2, 0)
    result[x2_min:x2_max, y2_min:y2_max] = resized

    bg = torch.ones((size, size, 3), dtype=torch.uint8) * 255

    # Alpha-composite the colour channels over the white background.
    alpha = result[..., 3:].to(torch.float32) / 255
    result = result[..., :3] * alpha + bg * (1 - alpha)

    return result.clip(0, 255).to(torch.uint8)
218-
219- def clip_preprocess (image , size = 224 , mean = [0.48145466 , 0.4578275 , 0.40821073 ], std = [0.26862954 , 0.26130258 , 0.27577711 ],
220- crop = True , value_range = (- 1 , 1 ), border_ratio : float = None , recenter_size : int = 512 ):
221-
222- if border_ratio is not None :
223-
224- image = (image * 255 ).clamp (0 , 255 ).to (torch .uint8 )
225- image = [recenter (img , border_ratio = border_ratio ) for img in image ]
226-
227- image = torch .stack (image , dim = 0 )
228- image = resize_cubic (image , size = (recenter_size , recenter_size ))
229-
230- image = image / 255 * 2 - 1
231- low , high = value_range
232-
233- image = (image - low ) / (high - low )
234- image = image .permute (0 , 2 , 3 , 1 )
235-
20+ def clip_preprocess (image , size = 224 , mean = [0.48145466 , 0.4578275 , 0.40821073 ], std = [0.26862954 , 0.26130258 , 0.27577711 ], crop = True ):
23621 image = image [:, :, :, :3 ] if image .shape [3 ] > 3 else image
237-
23822 mean = torch .tensor (mean , device = image .device , dtype = image .dtype )
23923 std = torch .tensor (std , device = image .device , dtype = image .dtype )
240-
24124 image = image .movedim (- 1 , 1 )
24225 if not (image .shape [2 ] == size and image .shape [3 ] == size ):
24326 if crop :
@@ -246,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
24629 else :
24730 scale_size = (size , size )
24831
249- image = torch .nn .functional .interpolate (image , size = scale_size , mode = "bilinear" if border_ratio is not None else " bicubic" , antialias = True )
32+ image = torch .nn .functional .interpolate (image , size = scale_size , mode = "bicubic" , antialias = True )
25033 h = (image .shape [2 ] - size )// 2
25134 w = (image .shape [3 ] - size )// 2
25235 image = image [:,:,h :h + size ,w :w + size ]
@@ -288,9 +71,9 @@ def load_sd(self, sd):
def get_sd(self):
    """Return the state dict of the wrapped ``self.model``."""
    model = self.model
    return model.state_dict()
29073
291- def encode_image (self , image , crop = True , border_ratio : float = None ):
74+ def encode_image (self , image , crop = True ):
29275 comfy .model_management .load_model_gpu (self .patcher )
293- pixel_values = clip_preprocess (image .to (self .load_device ), size = self .image_size , mean = self .image_mean , std = self .image_std , crop = crop , border_ratio = border_ratio ).float ()
76+ pixel_values = clip_preprocess (image .to (self .load_device ), size = self .image_size , mean = self .image_mean , std = self .image_std , crop = crop ).float ()
29477 out = self .model (pixel_values = pixel_values , intermediate_output = 'all' if self .return_all_hidden_states else - 2 )
29578
29679 outputs = Output ()
0 commit comments