-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSelectedEnergyRatio.py
executable file
·78 lines (71 loc) · 2.84 KB
/
SelectedEnergyRatio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 27 11:57:23 2020
@author: Rui LIN
"""
import torch
import math
def torch_fftshift(real, imag):
    '''
    Circularly shift both parts of a spectrum by half its size along every
    axis, so that the element at index 0 of each axis moves to index n // 2.

    Input:
    - real: a matrix of size [h, w], which is the real number part of the feature map slice in frequency domain.
    - imag: a matrix of size [h, w], which is the imaginary number part of the feature map slice in frequency domain.
    Output:
    - real: a matrix of size [h, w], which is the real number part of the feature map slice in frequency domain
            after shift operation.
    - imag: a matrix of size [h, w], which is the imaginary number part of the feature map slice in frequency domain
            after shift operation.
    '''
    # The axes to shift are taken from `real`; `imag` is rolled along the same
    # axes using its own sizes.
    axes = tuple(range(real.dim()))
    if not axes:
        # 0-dim input: there is nothing to shift.
        return real, imag
    # One multi-axis roll per tensor gives the same result as rolling axis by
    # axis, because rolls along different axes are independent.
    real = torch.roll(real, shifts=tuple(real.size(axis) // 2 for axis in axes), dims=axes)
    imag = torch.roll(imag, shifts=tuple(imag.size(axis) // 2 for axis in axes), dims=axes)
    return real, imag
def StepDecision(h, w, alpha):
    '''
    Decide the half-width (in bins) of the square window selected around the
    centre of a shifted spectrum.

    Input:
    - h: a scalar, which is the height of the given feature map slice.
    - w: a scalar, which is the width of the given feature map slice.
    - alpha: a scalar between 0 and 1, which determines the size of selected area.
    Output:
    - step: a scalar, which decides the selected area of the given feature map in frequency domain.
            The selected window spans `step` bins on each side of the centre.
    '''
    # Centre index of each axis, computed independently per axis: n / 2 for an
    # even size and (n - 1) / 2 for an odd size, i.e. n // 2.  A joint parity
    # test would give a fractional centre to an even axis paired with an odd one.
    xc = h // 2
    yc = w // 2
    # Number of bins available beyond the centre on the far side of each axis.
    max_h = h - (xc + 1)
    max_w = w - (yc + 1)
    if xc - 1 == 0 or yc - 1 == 0:
        # The centre sits at index 1 on at least one axis: select the centre only.
        step = 0
    else:
        # Use the more restrictive axis so the window stays square.
        step = min(int(math.ceil(max_h * alpha)), int(math.ceil(max_w * alpha)))
    return step
def EnergyRatio(fm_slice, alpha=1/4):
    '''
    Measure the share of a feature map slice's spectral energy that lies
    outside a square window centred on the shifted spectrum.

    The energy of a frequency bin is taken as log(1 + |F|), where F is the
    2-D Fourier transform of fm_slice after the half-size shift.

    Input:
    - fm_slice: a matrix of size [h, w], which is a slice of a given feature map in spatial domain.
    - alpha: a scalar between 0 and 1, passed to StepDecision to size the selected window.
    Output:
    - ratio: a scalar, which is the ratio of the energy of the unselected area of the feature map
             and the total energy of the feature map (both in frequency domain).
             When the total energy is zero the ratio is undefined (NaN) and a tensor of
             shape [1] holding 0 is returned instead.
    '''
    # torch.rfft was removed in PyTorch 1.8; use torch.fft.fft2 when it exists
    # and fall back to the legacy call on older releases.
    fft_module = getattr(torch, 'fft', None)
    if hasattr(fft_module, 'fft2'):
        spectrum = fft_module.fft2(fm_slice)
        real_part, imag_part = spectrum.real, spectrum.imag
    else:
        spectrum = torch.rfft(fm_slice, signal_ndim=2, onesided=False)
        real_part, imag_part = spectrum[:, :, 0], spectrum[:, :, 1]

    shift_real, shift_imag = torch_fftshift(real_part, imag_part)
    magnitude = (shift_real ** 2 + shift_imag ** 2) ** (1 / 2)
    log_magnitude = torch.log(magnitude + 1)

    h, w = log_magnitude.shape
    step = StepDecision(h, w, alpha)

    # After the half-size shift the zero-frequency bin is at index n // 2 on
    # each axis, whatever the parity of the other axis.
    xc = h // 2
    yc = w // 2
    # Clamp the lower bounds so an oversized step cannot turn into a negative
    # (wrap-around) slice start.
    top = max(xc - step, 0)
    left = max(yc - step, 0)
    selected = log_magnitude[top:xc + step + 1, left:yc + step + 1]

    total_energy = log_magnitude.sum()
    selected_energy = selected.sum()
    ratio = 1 - selected_energy / total_energy
    if torch.isnan(ratio):
        # 0 / 0 (e.g. an all-zero slice): report zero, keeping the input's
        # dtype and device.
        ratio = ratio.new_zeros(1)
    return ratio