Skip to content

Commit 8f7adc1

Browse files
authored
Merge pull request #3 from adcroft/tmp-pavel-ann
Re-factor of MOM_ANN
2 parents bf3d44f + da17218 commit 8f7adc1

File tree

4 files changed

+933
-238
lines changed

4 files changed

+933
-238
lines changed
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
program time_MOM_ANN
2+
3+
! This file is part of MOM6. See LICENSE.md for the license.
4+
5+
use MOM_ANN, only : ANN_CS
6+
use MOM_ANN, only : ANN_allocate, ANN_apply, ANN_end
7+
use MOM_ANN, only : ANN_apply_vector_orig, ANN_apply_vector_oi
8+
use MOM_ANN, only : ANN_apply_array_sio
9+
use MOM_ANN, only : ANN_random
10+
11+
implicit none
12+
13+
! Command line options
14+
integer :: nargs ! Number of command line arguments
15+
character(len=12) :: cmd_ln_arg !< Command line argument (if any)
16+
17+
! ANN parameters
18+
integer :: nlayers ! Number of layers
19+
integer :: nin ! Number of inputs
20+
integer :: layer_width ! Width of hidden layers
21+
integer :: nout ! Number of outputs
22+
! Timing parameters
23+
integer :: nsamp ! Number of measurements
24+
integer :: nits ! Number of calls to time
25+
integer :: nxy ! Spatial dimension
26+
27+
nlayers = 7; nin = 4; layer_width = 16; nout = 1 ! Deep network
28+
!nlayers = 4; nin = 4; layer_width = 48; nout = 1 ! Shallow-wide network
29+
!nlayers = 3; nin = 4; layer_width = 20; nout = 1 ! Small network
30+
31+
nsamp = 100
32+
nits = 20000
33+
!nits = 300000 ! Needed for robust measurements on small networks
34+
nxy = 100 ! larger array
35+
!nxy = 10 ! small array
36+
37+
! Optionally grab ANN and timing parameters from the command line
38+
nargs = command_argument_count()
39+
if (nargs==7) then
40+
call get_command_argument(1, cmd_ln_arg)
41+
read(cmd_ln_arg,*) nlayers
42+
call get_command_argument(2, cmd_ln_arg)
43+
read(cmd_ln_arg,*) nin
44+
call get_command_argument(3, cmd_ln_arg)
45+
read(cmd_ln_arg,*) layer_width
46+
call get_command_argument(4, cmd_ln_arg)
47+
read(cmd_ln_arg,*) nout
48+
call get_command_argument(5, cmd_ln_arg)
49+
read(cmd_ln_arg,*) nsamp
50+
call get_command_argument(6, cmd_ln_arg)
51+
read(cmd_ln_arg,*) nits
52+
call get_command_argument(7, cmd_ln_arg)
53+
read(cmd_ln_arg,*) nxy
54+
endif
55+
56+
! Fastest variants on Intel Xeon W-2223 CPU @ 3.60GHz (gfortran-13.2 -O3)
57+
! | vector(nxy=1) | nxy = 10 | nxy = 100
58+
! ----------------------------------------------------------------------------
59+
! Small ANN | vector_oi | array_soi | array_sio
60+
! Shallow-wide ANN | vector_oi | array_ois | array_sio
61+
! Deep ANN | vector_oi | array_ois | array_sio
62+
63+
write(*,'(a)') "{"
64+
65+
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
66+
0, "MOM_ANN:ANN_apply(vector)")
67+
write(*,"(',')")
68+
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
69+
1, "MOM_ANN:ANN_apply_vector_orig(array)")
70+
write(*,"(',')")
71+
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
72+
2, "MOM_ANN:ANN_apply_vector_oi(array)")
73+
write(*,"(',')")
74+
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
75+
12, "MOM_ANN:ANN_apply_array_sio(array)")
76+
write(*,"()")
77+
78+
write(*,'(a)') "}"
79+
80+
contains
81+
82+
!> Time ANN inference.
83+
!!
84+
!! Times are measured over the "nits effective calls" and appropriately scaled to the
85+
!! time per call per single vector of input features. For array inputs, the number of
86+
!! actual calls is reduced by the size of the array. The timing measurement is repeated
87+
!! "nsamp" times, to check the statistics of the timing measurement.
88+
subroutine time_ANN(nlayers, nin, width, nout, nsamp, nits, nxy, impl, label)
89+
integer, intent(in) :: nlayers !< Number of layers
90+
integer, intent(in) :: nin !< Number of inputs
91+
integer, intent(in) :: width !< Width of hidden layers
92+
integer, intent(in) :: nout !< Number of outputs
93+
integer, intent(in) :: nsamp !< Number of measurements
94+
integer, intent(in) :: nits !< Number of calls to time
95+
integer, intent(in) :: nxy !< Spatial dimension
96+
integer, intent(in) :: impl !< Implementation to time
97+
character(len=*), intent(in) :: label !< Label for YAML output
98+
! Local variables
99+
type(ANN_CS) :: ANN ! ANN
100+
integer :: widths(nlayers) ! Width of each layer
101+
real :: x_s(nin) ! Inputs (just features) [nondim]
102+
real :: y_s(nin) ! Outputs (just features) [nondim]
103+
real :: x_fs(nin,nxy) ! Inputs (feature, space) [nondim]
104+
real :: y_fs(nin,nxy) ! Outputs (feature, space) [nondim]
105+
real :: x_sf(nin,nxy) ! Inputs (space, feature) [nondim]
106+
real :: y_sf(nin,nxy) ! Outputs (space, feature) [nondim]
107+
integer :: iter, samp ! Loop counters
108+
integer :: ij ! Horizontal loop index
109+
real :: start, finish, timing ! CPU times [s]
110+
real :: tmin, tmax, tmean, tstd ! Min, max, mean, and standard deviation, of CPU times [s]
111+
integer :: asamp ! Actual samples of timings
112+
integer :: aits ! Actual iterations
113+
real :: words_per_sec ! Operations per sec estimated from parameters [# s-1]
114+
115+
widths(:) = width
116+
widths(1) = nin
117+
widths(nlayers) = nout
118+
119+
call ANN_random(ANN, nlayers, widths)
120+
call random_number(x_fs)
121+
call random_number(x_sf)
122+
123+
124+
tmin = 1e9
125+
tmax = 0.
126+
tmean = 0.
127+
tstd = 0.
128+
asamp = nits ! Most cases below use this
129+
aits = nits / nxy ! Most cases below use this
130+
131+
do samp = 1, nsamp
132+
select case (impl)
133+
case (0)
134+
aits = nits
135+
call cpu_time(start)
136+
do iter = 1, nits ! Make many passes to reduce sampling error
137+
call ANN_apply(x_s, y_s, ANN)
138+
enddo
139+
call cpu_time(finish)
140+
case (1)
141+
call cpu_time(start)
142+
do iter = 1, aits ! Make many passes to reduce sampling error
143+
do ij = 1, nxy
144+
call ANN_apply_vector_orig(x_fs(:,ij), y_fs(:,ij), ANN)
145+
enddo
146+
enddo
147+
call cpu_time(finish)
148+
case (2)
149+
call cpu_time(start)
150+
do iter = 1, aits ! Make many passes to reduce sampling error
151+
do ij = 1, nxy
152+
call ANN_apply_vector_oi(x_fs(:,ij), y_fs(:,ij), ANN)
153+
enddo
154+
enddo
155+
call cpu_time(finish)
156+
case (12)
157+
call cpu_time(start)
158+
do iter = 1, aits ! Make many passes to reduce sampling error
159+
call ANN_apply_array_sio(nxy, x_sf(:,:), y_sf(:,:), ANN)
160+
enddo
161+
call cpu_time(finish)
162+
asamp = nsamp * aits ! Account for working on whole arrays
163+
end select
164+
165+
timing = ( finish - start ) / real(nits) ! Average time per call
166+
167+
tmin = min( tmin, timing )
168+
tmax = max( tmax, timing )
169+
tmean = tmean + timing
170+
tstd = tstd + timing**2
171+
enddo
172+
173+
tmean = tmean / real(nsamp)
174+
tstd = tstd / real(nsamp) ! convert to mean of squares
175+
tstd = tstd - tmean**2 ! convert to variance
176+
tstd = sqrt( tstd * real(nsamp) / real(nsamp-1) ) ! convert to standard deviation
177+
words_per_sec = ANN%parameters / ( tmean * 1024 * 1024 )
178+
179+
write(*,"(2x,3a)") '"', trim(label), '": {'
180+
write(*,"(4x,a,1pe11.4,',')") '"min": ', tmin
181+
write(*,"(4x,a,1pe11.4,',')") '"mean":', tmean
182+
write(*,"(4x,a,1pe11.4,',')") '"std": ', tstd
183+
write(*,"(4x,a,i0,',')") '"n_samples": ', asamp
184+
write(*,"(4x,a,1pe11.4,',')") '"max": ', tmax
185+
write(*,"(4x,a,1pe11.4,'}')", advance="no") '"MBps": ', words_per_sec
186+
187+
end subroutine time_ANN
188+
189+
end program time_MOM_ANN
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
program test_MOM_ANN
2+
3+
use MOM_ANN, only : ANN_unit_tests
4+
use MOM_error_handler, only : set_skip_mpi
5+
6+
call set_skip_mpi(.true.) ! This unit tests is not expecting MPI to be used
7+
8+
if ( ANN_unit_tests(.true.) ) stop 1
9+
10+
end program test_MOM_ANN

0 commit comments

Comments
 (0)