Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 232 additions & 0 deletions config_src/drivers/timing_tests/time_MOM_ANN.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
program time_MOM_ANN

! This file is part of MOM6. See LICENSE.md for the license.

use MOM_ANN, only : ANN_CS
use MOM_ANN, only : ANN_allocate, ANN_apply, ANN_end
use MOM_ANN, only : ANN_apply_vector_v1, ANN_apply_vector_v2, ANN_apply_vector_v3
use MOM_ANN, only : ANN_apply_vector_v4
use MOM_ANN, only : ANN_apply_array_v1, ANN_apply_array_v2, ANN_apply_array_v3
use MOM_ANN, only : ANN_random

implicit none

! Command line options
integer :: nargs ! Number of command line arguments
character(len=12) :: cmd_ln_arg !< Command line argument (if any)

! ANN parameters
integer :: nlayers ! Number of layers
integer :: nin ! Number of inputs
integer :: layer_width ! Width of hidden layers
integer :: nout ! Number of outputs
! Timing parameters
integer :: nsamp ! Number of measurements
integer :: nits ! Number of calls to time
integer :: nxy ! Spatial dimension

nlayers = 7; nin = 4; layer_width = 16; nout = 1 ! Deep network
!nlayers = 4; nin = 4; layer_width = 48; nout = 1 ! Shallow-wide network
!nlayers = 3; nin = 4; layer_width = 20; nout = 1 ! Small network

nsamp = 100
nits = 20000
!nits = 300000 ! Needed for robust measurements on small networks
nxy = 100 ! larger array
!nxy = 10 ! small array

! Optionally grab ANN and timing parameters from the command line
nargs = command_argument_count()
if (nargs==7) then
call get_command_argument(1, cmd_ln_arg)
read(cmd_ln_arg,*) nlayers
call get_command_argument(2, cmd_ln_arg)
read(cmd_ln_arg,*) nin
call get_command_argument(3, cmd_ln_arg)
read(cmd_ln_arg,*) layer_width
call get_command_argument(4, cmd_ln_arg)
read(cmd_ln_arg,*) nout
call get_command_argument(5, cmd_ln_arg)
read(cmd_ln_arg,*) nsamp
call get_command_argument(6, cmd_ln_arg)
read(cmd_ln_arg,*) nits
call get_command_argument(7, cmd_ln_arg)
read(cmd_ln_arg,*) nxy
endif

! Fastest variants on Intel Xeon W-2223 CPU @ 3.60GHz (gfortran-13.2 -O3)
! | vector(nxy=1) | nxy = 10 | nxy = 100
! ----------------------------------------------------------------------------
! Small ANN | vector_v2 | array_v1 | array_v2
! Shallow-wide ANN | vector_v2 | array_v3 | array_v2
! Deep ANN | vector_v2 | array_v3 | array_v2

write(*,'(a)') "{"

call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
0, "MOM_ANN:ANN_apply(vector)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
1, "MOM_ANN:ANN_apply_vector_v1(array)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
2, "MOM_ANN:ANN_apply_vector_v2(array)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
3, "MOM_ANN:ANN_apply_vector_v3(array)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
4, "MOM_ANN:ANN_apply_vector_v4(array)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
11, "MOM_ANN:ANN_apply_array_v1(array)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
12, "MOM_ANN:ANN_apply_array_v2(array)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
13, "MOM_ANN:ANN_apply_array_v3(array)")
write(*,"()")

write(*,'(a)') "}"

contains

!> Time ANN inference.
!!
!! Times are measured over the "nits effective calls" and appropriately scaled to the
!! time per call per single vector of input features. For array inputs, the number of
!! actual calls is reduced by the size of the array. The timing measurement is repeated
!! "nsamp" times, to check the statistics of the timing measurement.
subroutine time_ANN(nlayers, nin, width, nout, nsamp, nits, nxy, impl, label)
integer, intent(in) :: nlayers !< Number of layers
integer, intent(in) :: nin !< Number of inputs
integer, intent(in) :: width !< Width of hidden layers
integer, intent(in) :: nout !< Number of outputs
integer, intent(in) :: nsamp !< Number of measurements
integer, intent(in) :: nits !< Number of calls to time
integer, intent(in) :: nxy !< Spatial dimension
integer, intent(in) :: impl !< Implementation to time
character(len=*), intent(in) :: label !< Label for YAML output
! Local variables
type(ANN_CS) :: ANN ! ANN
integer :: widths(nlayers) ! Width of each layer
real :: x_s(nin) ! Inputs (just features) [nondim]
real :: y_s(nin) ! Outputs (just features) [nondim]
real :: x_fs(nin,nxy) ! Inputs (feature, space) [nondim]
real :: y_fs(nin,nxy) ! Outputs (feature, space) [nondim]
real :: x_sf(nin,nxy) ! Inputs (space, feature) [nondim]
real :: y_sf(nin,nxy) ! Outputs (space, feature) [nondim]
integer :: iter, samp ! Loop counters
integer :: ij ! Horizontal loop index
real :: start, finish, timing ! CPU times [s]
real :: tmin, tmax, tmean, tstd ! Min, max, mean, and standard deviation, of CPU times [s]
integer :: asamp ! Actual samples of timings
integer :: aits ! Actual iterations
real :: flops ! Operations per sec estimated from parameters [# s-1]

widths(:) = width
widths(1) = nin
widths(nlayers) = nout

call ANN_random(ANN, nlayers, widths)
call random_number(x_fs)
call random_number(x_sf)


tmin = 1e9
tmax = 0.
tmean = 0.
tstd = 0.
asamp = nits ! Most cases below use this
aits = nits / nxy ! Most cases below use this

do samp = 1, nsamp
select case (impl)
case (0)
aits = nits
call cpu_time(start)
do iter = 1, nits ! Make many passes to reduce sampling error
call ANN_apply(x_s, y_s, ANN)
enddo
call cpu_time(finish)
case (1)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
do ij = 1, nxy
call ANN_apply_vector_v1(x_fs(:,ij), y_fs(:,ij), ANN)
enddo
enddo
call cpu_time(finish)
case (2)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
do ij = 1, nxy
call ANN_apply_vector_v2(x_fs(:,ij), y_fs(:,ij), ANN)
enddo
enddo
call cpu_time(finish)
case (3)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
do ij = 1, nxy
call ANN_apply_vector_v3(x_fs(:,ij), y_fs(:,ij), ANN)
enddo
enddo
call cpu_time(finish)
case (4)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
do ij = 1, nxy
call ANN_apply_vector_v4(x_fs(:,ij), y_fs(:,ij), ANN)
enddo
enddo
call cpu_time(finish)
case (11)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
call ANN_apply_array_v1(nxy, x_sf(:,:), y_sf(:,:), ANN)
enddo
call cpu_time(finish)
asamp = nsamp * aits ! Account for working on whole arrays
case (12)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
call ANN_apply_array_v2(nxy, x_sf(:,:), y_sf(:,:), ANN)
enddo
call cpu_time(finish)
asamp = nsamp * aits ! Account for working on whole arrays
case (13)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
call ANN_apply_array_v3(nxy, x_fs(:,:), y_fs(:,:), ANN)
enddo
call cpu_time(finish)
asamp = nsamp * aits ! Account for working on whole arrays
end select

timing = ( finish - start ) / real(nits) ! Average time per call

tmin = min( tmin, timing )
tmax = max( tmax, timing )
tmean = tmean + timing
tstd = tstd + timing**2
enddo

tmean = tmean / real(nsamp)
tstd = tstd / real(nsamp) ! convert to mean of squares
tstd = tstd - tmean**2 ! convert to variance
tstd = sqrt( tstd * real(nsamp) / real(nsamp-1) ) ! convert to standard deviation
flops = ANN%parameters / tmean

write(*,"(2x,3a)") '"', trim(label), '": {'
write(*,"(4x,a,1pe11.4,',')") '"min": ', tmin
write(*,"(4x,a,1pe11.4,',')") '"mean":', tmean
write(*,"(4x,a,1pe11.4,',')") '"std": ', tstd
write(*,"(4x,a,i0,',')") '"n_samples": ', asamp
write(*,"(4x,a,1pe11.4,',')") '"max": ', tmax
write(*,"(4x,a,1pe11.4,'}')", advance="no") '"flops": ', flops

end subroutine time_ANN

end program time_MOM_ANN
10 changes: 10 additions & 0 deletions config_src/drivers/unit_tests/test_MOM_ANN.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
program test_MOM_ANN

use MOM_ANN, only : ANN_unit_tests
use MOM_error_handler, only : set_skip_mpi

call set_skip_mpi(.true.) ! This unit tests is not expecting MPI to be used

if ( ANN_unit_tests(.true.) ) stop 1

end program test_MOM_ANN
Loading