|
| 1 | +program time_MOM_ANN |
| 2 | + |
| 3 | +! This file is part of MOM6. See LICENSE.md for the license. |
| 4 | + |
| 5 | +use MOM_ANN, only : ANN_CS |
| 6 | +use MOM_ANN, only : ANN_allocate, ANN_apply, ANN_end |
| 7 | +use MOM_ANN, only : ANN_apply_vector_orig, ANN_apply_vector_oi |
| 8 | +use MOM_ANN, only : ANN_apply_array_sio |
| 9 | +use MOM_ANN, only : ANN_random |
| 10 | + |
| 11 | +implicit none |
| 12 | + |
| 13 | +! Command line options |
| 14 | +integer :: nargs ! Number of command line arguments |
| 15 | +character(len=12) :: cmd_ln_arg !< Command line argument (if any) |
| 16 | + |
| 17 | +! ANN parameters |
| 18 | +integer :: nlayers ! Number of layers |
| 19 | +integer :: nin ! Number of inputs |
| 20 | +integer :: layer_width ! Width of hidden layers |
| 21 | +integer :: nout ! Number of outputs |
| 22 | +! Timing parameters |
| 23 | +integer :: nsamp ! Number of measurements |
| 24 | +integer :: nits ! Number of calls to time |
| 25 | +integer :: nxy ! Spatial dimension |
| 26 | + |
| 27 | +nlayers = 7; nin = 4; layer_width = 16; nout = 1 ! Deep network |
| 28 | +!nlayers = 4; nin = 4; layer_width = 48; nout = 1 ! Shallow-wide network |
| 29 | +!nlayers = 3; nin = 4; layer_width = 20; nout = 1 ! Small network |
| 30 | + |
| 31 | +nsamp = 100 |
| 32 | +nits = 20000 |
| 33 | +!nits = 300000 ! Needed for robust measurements on small networks |
| 34 | +nxy = 100 ! larger array |
| 35 | +!nxy = 10 ! small array |
| 36 | + |
| 37 | +! Optionally grab ANN and timing parameters from the command line |
| 38 | +nargs = command_argument_count() |
| 39 | +if (nargs==7) then |
| 40 | + call get_command_argument(1, cmd_ln_arg) |
| 41 | + read(cmd_ln_arg,*) nlayers |
| 42 | + call get_command_argument(2, cmd_ln_arg) |
| 43 | + read(cmd_ln_arg,*) nin |
| 44 | + call get_command_argument(3, cmd_ln_arg) |
| 45 | + read(cmd_ln_arg,*) layer_width |
| 46 | + call get_command_argument(4, cmd_ln_arg) |
| 47 | + read(cmd_ln_arg,*) nout |
| 48 | + call get_command_argument(5, cmd_ln_arg) |
| 49 | + read(cmd_ln_arg,*) nsamp |
| 50 | + call get_command_argument(6, cmd_ln_arg) |
| 51 | + read(cmd_ln_arg,*) nits |
| 52 | + call get_command_argument(7, cmd_ln_arg) |
| 53 | + read(cmd_ln_arg,*) nxy |
| 54 | +endif |
| 55 | + |
| 56 | +! Fastest variants on Intel Xeon W-2223 CPU @ 3.60GHz (gfortran-13.2 -O3) |
| 57 | +! | vector(nxy=1) | nxy = 10 | nxy = 100 |
| 58 | +! ---------------------------------------------------------------------------- |
| 59 | +! Small ANN | vector_oi | array_soi | array_sio |
| 60 | +! Shallow-wide ANN | vector_oi | array_ois | array_sio |
| 61 | +! Deep ANN | vector_oi | array_ois | array_sio |
| 62 | + |
| 63 | +write(*,'(a)') "{" |
| 64 | + |
| 65 | +call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, & |
| 66 | + 0, "MOM_ANN:ANN_apply(vector)") |
| 67 | +write(*,"(',')") |
| 68 | +call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, & |
| 69 | + 1, "MOM_ANN:ANN_apply_vector_orig(array)") |
| 70 | +write(*,"(',')") |
| 71 | +call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, & |
| 72 | + 2, "MOM_ANN:ANN_apply_vector_oi(array)") |
| 73 | +write(*,"(',')") |
| 74 | +call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, & |
| 75 | + 12, "MOM_ANN:ANN_apply_array_sio(array)") |
| 76 | +write(*,"()") |
| 77 | + |
| 78 | +write(*,'(a)') "}" |
| 79 | + |
| 80 | +contains |
| 81 | + |
| 82 | +!> Time ANN inference. |
| 83 | +!! |
| 84 | +!! Times are measured over the "nits effective calls" and appropriately scaled to the |
| 85 | +!! time per call per single vector of input features. For array inputs, the number of |
| 86 | +!! actual calls is reduced by the size of the array. The timing measurement is repeated |
| 87 | +!! "nsamp" times, to check the statistics of the timing measurement. |
| 88 | +subroutine time_ANN(nlayers, nin, width, nout, nsamp, nits, nxy, impl, label) |
| 89 | + integer, intent(in) :: nlayers !< Number of layers |
| 90 | + integer, intent(in) :: nin !< Number of inputs |
| 91 | + integer, intent(in) :: width !< Width of hidden layers |
| 92 | + integer, intent(in) :: nout !< Number of outputs |
| 93 | + integer, intent(in) :: nsamp !< Number of measurements |
| 94 | + integer, intent(in) :: nits !< Number of calls to time |
| 95 | + integer, intent(in) :: nxy !< Spatial dimension |
| 96 | + integer, intent(in) :: impl !< Implementation to time |
| 97 | + character(len=*), intent(in) :: label !< Label for YAML output |
| 98 | + ! Local variables |
| 99 | + type(ANN_CS) :: ANN ! ANN |
| 100 | + integer :: widths(nlayers) ! Width of each layer |
| 101 | + real :: x_s(nin) ! Inputs (just features) [nondim] |
| 102 | + real :: y_s(nin) ! Outputs (just features) [nondim] |
| 103 | + real :: x_fs(nin,nxy) ! Inputs (feature, space) [nondim] |
| 104 | + real :: y_fs(nin,nxy) ! Outputs (feature, space) [nondim] |
| 105 | + real :: x_sf(nin,nxy) ! Inputs (space, feature) [nondim] |
| 106 | + real :: y_sf(nin,nxy) ! Outputs (space, feature) [nondim] |
| 107 | + integer :: iter, samp ! Loop counters |
| 108 | + integer :: ij ! Horizontal loop index |
| 109 | + real :: start, finish, timing ! CPU times [s] |
| 110 | + real :: tmin, tmax, tmean, tstd ! Min, max, mean, and standard deviation, of CPU times [s] |
| 111 | + integer :: asamp ! Actual samples of timings |
| 112 | + integer :: aits ! Actual iterations |
| 113 | + real :: words_per_sec ! Operations per sec estimated from parameters [# s-1] |
| 114 | + |
| 115 | + widths(:) = width |
| 116 | + widths(1) = nin |
| 117 | + widths(nlayers) = nout |
| 118 | + |
| 119 | + call ANN_random(ANN, nlayers, widths) |
| 120 | + call random_number(x_fs) |
| 121 | + call random_number(x_sf) |
| 122 | + |
| 123 | + |
| 124 | + tmin = 1e9 |
| 125 | + tmax = 0. |
| 126 | + tmean = 0. |
| 127 | + tstd = 0. |
| 128 | + asamp = nits ! Most cases below use this |
| 129 | + aits = nits / nxy ! Most cases below use this |
| 130 | + |
| 131 | + do samp = 1, nsamp |
| 132 | + select case (impl) |
| 133 | + case (0) |
| 134 | + aits = nits |
| 135 | + call cpu_time(start) |
| 136 | + do iter = 1, nits ! Make many passes to reduce sampling error |
| 137 | + call ANN_apply(x_s, y_s, ANN) |
| 138 | + enddo |
| 139 | + call cpu_time(finish) |
| 140 | + case (1) |
| 141 | + call cpu_time(start) |
| 142 | + do iter = 1, aits ! Make many passes to reduce sampling error |
| 143 | + do ij = 1, nxy |
| 144 | + call ANN_apply_vector_orig(x_fs(:,ij), y_fs(:,ij), ANN) |
| 145 | + enddo |
| 146 | + enddo |
| 147 | + call cpu_time(finish) |
| 148 | + case (2) |
| 149 | + call cpu_time(start) |
| 150 | + do iter = 1, aits ! Make many passes to reduce sampling error |
| 151 | + do ij = 1, nxy |
| 152 | + call ANN_apply_vector_oi(x_fs(:,ij), y_fs(:,ij), ANN) |
| 153 | + enddo |
| 154 | + enddo |
| 155 | + call cpu_time(finish) |
| 156 | + case (12) |
| 157 | + call cpu_time(start) |
| 158 | + do iter = 1, aits ! Make many passes to reduce sampling error |
| 159 | + call ANN_apply_array_sio(nxy, x_sf(:,:), y_sf(:,:), ANN) |
| 160 | + enddo |
| 161 | + call cpu_time(finish) |
| 162 | + asamp = nsamp * aits ! Account for working on whole arrays |
| 163 | + end select |
| 164 | + |
| 165 | + timing = ( finish - start ) / real(nits) ! Average time per call |
| 166 | + |
| 167 | + tmin = min( tmin, timing ) |
| 168 | + tmax = max( tmax, timing ) |
| 169 | + tmean = tmean + timing |
| 170 | + tstd = tstd + timing**2 |
| 171 | + enddo |
| 172 | + |
| 173 | + tmean = tmean / real(nsamp) |
| 174 | + tstd = tstd / real(nsamp) ! convert to mean of squares |
| 175 | + tstd = tstd - tmean**2 ! convert to variance |
| 176 | + tstd = sqrt( tstd * real(nsamp) / real(nsamp-1) ) ! convert to standard deviation |
| 177 | + words_per_sec = ANN%parameters / ( tmean * 1024 * 1024 ) |
| 178 | + |
| 179 | + write(*,"(2x,3a)") '"', trim(label), '": {' |
| 180 | + write(*,"(4x,a,1pe11.4,',')") '"min": ', tmin |
| 181 | + write(*,"(4x,a,1pe11.4,',')") '"mean":', tmean |
| 182 | + write(*,"(4x,a,1pe11.4,',')") '"std": ', tstd |
| 183 | + write(*,"(4x,a,i0,',')") '"n_samples": ', asamp |
| 184 | + write(*,"(4x,a,1pe11.4,',')") '"max": ', tmax |
| 185 | + write(*,"(4x,a,1pe11.4,'}')", advance="no") '"MBps": ', words_per_sec |
| 186 | + |
| 187 | +end subroutine time_ANN |
| 188 | + |
| 189 | +end program time_MOM_ANN |
0 commit comments