!! File: diahsb_gpu.h90 (last commit by Guillaume Samson)
MODULE diahsb
   !!======================================================================
   !!                       ***  MODULE  diahsb  ***
   !! Ocean diagnostics: Heat, salt and volume budgets
   !!======================================================================
   !! History :  3.3  ! 2010-09  (M. Leclair)  Original code
   !!                 ! 2012-10  (C. Rousset)  add iom_put
   !!----------------------------------------------------------------------

   !!----------------------------------------------------------------------
   !!   dia_hsb       : Diagnose the conservation of ocean heat and salt contents, and volume
   !!   dia_hsb_rst   : Read or write DIA file in restart file
   !!   dia_hsb_init  : Initialization of the conservation diagnostic
   !!----------------------------------------------------------------------
   USE oce            ! ocean dynamics and tracers
   USE dom_oce        ! ocean space and time domain
   USE phycst         ! physical constants
   USE sbc_oce        ! surface thermohaline fluxes
   USE isf_oce        ! ice shelf fluxes
   USE sbcrnf         ! river runoff
   USE domvvl         ! vertical scale factors
   USE traqsr         ! penetrative solar radiation
   USE trabbc         ! bottom boundary condition
   USE trabbc         ! bottom boundary condition
   USE restart        ! ocean restart
   USE bdy_oce , ONLY : ln_bdy
   !
   USE iom            ! I/O manager
   USE in_out_manager ! I/O manager
   USE gpu_manager    ! GPU manager
   USE cudafor        ! CUDA toolkit libs
   USE cuda_fortran   ! CUDA routines
   !USE nvtx          ! CUDA profiling/DEGUG tools
   USE lib_fortran    ! glob_sum
   USE lib_mpp        ! distributed memory computing library
   USE timing         ! preformance summary

   IMPLICIT NONE
   PRIVATE

   PUBLIC   dia_hsb        ! routine called by step.F90
   PUBLIC   dia_hsb_init   ! routine called by nemogcm.F90

   LOGICAL, PUBLIC ::   ln_diahsb   !: check the heat and salt budgets

   REAL(wp)                      :: surf_tot ! ocean surface
   ! The accumulators below hold two entries ("tiles"); dia_hsb alternates between them from one call to the next
   REAL(wp) , DIMENSION(2), SAVE :: frc_t, frc_s, frc_v ! global forcing trends
   REAL(wp) , DIMENSION(2), SAVE :: frc_wn_t, frc_wn_s ! global forcing trends
   !
   ! Host arrays (PINNED = page-locked host memory, CUDA Fortran attribute)
   REAL(wp), DIMENSION(:,:)  , ALLOCATABLE ::   surf
   REAL(wp), DIMENSION(:,:)  , ALLOCATABLE ::   surf_ini      , ssh_ini          !
   REAL(wp), DIMENSION(:,:)  , ALLOCATABLE ::   ssh_hc_loc_ini, ssh_sc_loc_ini   !
   REAL(wp), DIMENSION(:,:,:), ALLOCATABLE, PINNED :: hc_loc_ini, sc_loc_ini !
   REAL(wp), DIMENSION(:,:,:), ALLOCATABLE :: e3t_ini !
   REAL(wp), DIMENSION(:) , ALLOCATABLE, PINNED, SAVE :: h_ztmpv, h_ztmph, h_ztmps, h_ztmp ! host side of the device sums
   REAL(wp), DIMENSION(:,:,:), ALLOCATABLE :: tmask_ini


   ! Device data associated with PUBLIC arrays
   REAL(8), DIMENSION(:,:,:,:) , ALLOCATABLE, DEVICE :: d_e3t !
   REAL(8), DIMENSION(:,:,:) , ALLOCATABLE, DEVICE :: d_tmask !
   REAL(8), DIMENSION(:,:) , ALLOCATABLE, DEVICE :: d_tmask_i !
   REAL(8), DIMENSION(:,:,:) , ALLOCATABLE, DEVICE :: d_tmask_ini !
   REAL(8), DIMENSION(:,:,:,:,:), ALLOCATABLE, DEVICE :: d_ts !
   ! Device data associated with LOCAL/DEVICE arrays
   REAL(8), DEVICE , DIMENSION(:,:) , ALLOCATABLE :: d_surf !
   REAL(8), DEVICE , DIMENSION(:,:) , ALLOCATABLE :: d_surf_ini !
   REAL(8), DEVICE , DIMENSION(:,:,:) , ALLOCATABLE :: d_hc_loc_ini !
   REAL(8), DEVICE , DIMENSION(:,:,:) , ALLOCATABLE :: d_sc_loc_ini !
   REAL(8), DEVICE , DIMENSION(:,:,:) , ALLOCATABLE :: d_e3t_ini !
   REAL(8), DEVICE , DIMENSION(:,:,:) , ALLOCATABLE :: d_zwrkv, d_zwrkh, d_zwrks, d_zwrk ! 3D GPU workspace
   REAL(8), DEVICE :: ztmpv, ztmph, ztmps, ztmp ! Device Reduction
   !
   INTEGER :: globsize ! 3D workspace size
   type(dim3) :: dimGrid, dimBlock ! cuda parameters
   INTEGER, parameter :: nstreams = 3 ! Streams Number
   INTEGER(kind=cuda_stream_kind) :: stream(nstreams), str ! Stream ID (dia_hsb declares its own local str)
   !DEBUG
   !REAL(8) , save , DIMENSION(:,:,:) , ALLOCATABLE :: prev_3d
   !REAL(8) :: accum

   !! * Substitutions
#  include "domzgr_substitute.h90"
   !!----------------------------------------------------------------------
   !! NEMO/OCE 4.0 , NEMO Consortium (2018)
   !! $Id$
   !! Software governed by the CeCILL license (see ./LICENSE)
   !!----------------------------------------------------------------------
CONTAINS

   SUBROUTINE dia_hsb( kt, Kbb, Kmm )
      !!---------------------------------------------------------------------------
      !!                  ***  ROUTINE dia_hsb  ***
      !!
      !! ** Purpose: Compute the ocean global heat content, salt content and volume conservation
      !!
      !! ** Method : - Compute the deviation of heat content, salt content and volume
      !!               at the current time step from their values at nit000
      !!             - Compute the contribution of forcing and remove it from these deviations
      !!             - The 3D sums are queued on a CUDA stream. Two host "tiles" (index 1/2)
      !!               alternate from one call to the next: the sums queued at one call
      !!               (tile_n) are collected and output at the following call (tile_b).
      !!               At kt == nitend the current tile is collected and output as well.
      !!
      !! ** Arguments : kt       : ocean time-step index
      !!                Kbb, Kmm : ocean time level indices
      !!---------------------------------------------------------------------------
      INTEGER, INTENT(in) ::   kt         ! ocean time-step index
      INTEGER, INTENT(in) ::   Kbb, Kmm   ! ocean time level indices
      !
      INTEGER                        :: ji, jj, jk ! dummy loop indices
      INTEGER                        :: istep ! number of calls since nit000 (1 at the first call of the run)
      INTEGER                        :: istat ! CUDA error check
      INTEGER(kind=cuda_stream_kind) :: str ! stream used for the work queued at this call
      INTEGER                        :: tile_n, tile_b ! tile index: _n now, _b before
      REAL(wp) , DIMENSION(2), SAVE  :: zdiff_hc1, zdiff_sc1 ! heat and salt content variations
      REAL(wp) , DIMENSION(2), SAVE  :: zdiff_hc, zdiff_sc ! - - - -
      REAL(wp) , DIMENSION(2), SAVE  :: zdiff_v2 ! volume variation
      REAL(wp) , DIMENSION(2), SAVE  :: zdiff_v1 ! volume variation
      REAL(wp) , DIMENSION(2), SAVE  :: zerr_hc1, zerr_sc1 ! heat and salt content misfit
      REAL(wp) , DIMENSION(2), SAVE  :: zvol_tot ! volume
      REAL(wp) , DIMENSION(2), SAVE  :: z_frc_trd_t, z_frc_trd_s ! - -
      REAL(wp) , DIMENSION(2), SAVE  :: z_frc_trd_v ! - -
      REAL(wp) , DIMENSION(2), SAVE  :: z_wn_trd_t, z_wn_trd_s ! - -
      REAL(wp) , DIMENSION(2), SAVE  :: z_ssh_hc, z_ssh_sc ! - -
# 147 "diahsb_new.F90"
      REAL(wp), DIMENSION(jpi,jpj) :: z2d0, z2d1 ! 2D workspace
      !!---------------------------------------------------------------------------
      IF( ln_timing )   CALL timing_start('dia_hsb')

      ! Tile selection is relative to the first step of this run (nit000), so that a
      ! restarted run (nit000 > 1) starts on tile 1, where dia_hsb_init stored the restart values.
      istep = kt - nit000 + 1
      IF( istep == 1 ) THEN
         tile_n = 1
         tile_b = 1
      ELSE IF( MOD( istep, 2 ) == 0 ) THEN
         tile_n = 2
         tile_b = 1
      ELSE
         tile_n = 1
         tile_b = 2
      END IF

      ! Mask tracers 1 and 2 at both time levels (in place, on the model arrays)
      ts(:,:,:,1,Kmm) = ts(:,:,:,1,Kmm) * tmask(:,:,:)
      ts(:,:,:,1,Kbb) = ts(:,:,:,1,Kbb) * tmask(:,:,:)
      ts(:,:,:,2,Kmm) = ts(:,:,:,2,Kmm) * tmask(:,:,:)
      ts(:,:,:,2,Kbb) = ts(:,:,:,2,Kbb) * tmask(:,:,:)
      ! ------------------------- !
      ! 1 - Trends due to forcing !
      ! ------------------------- !
      z_frc_trd_v(tile_n) = r1_rho0 * glob_sum( 'diahsb',                                             &   ! volume fluxes
         &                  - ( emp(:,:) - rnf(:,:) + fwfisf_cav(:,:) + fwfisf_par(:,:) ) * surf(:,:) )
      z_frc_trd_t(tile_n) = glob_sum( 'diahsb', sbc_tsc(:,:,jp_tem) * surf(:,:) ) ! heat fluxes
      z_frc_trd_s(tile_n) = glob_sum( 'diahsb', sbc_tsc(:,:,jp_sal) * surf(:,:) ) ! salt fluxes
      !                    ! Add runoff heat & salt input
      IF( ln_rnf    )   z_frc_trd_t(tile_n) = z_frc_trd_t(tile_n) + glob_sum( 'diahsb', rnf_tsc(:,:,jp_tem) * surf(:,:) )
      IF( ln_rnf_sal)   z_frc_trd_s(tile_n) = z_frc_trd_s(tile_n) + glob_sum( 'diahsb', rnf_tsc(:,:,jp_sal) * surf(:,:) )
      !                    ! Add ice shelf heat & salt input
      IF( ln_isf    )   z_frc_trd_t(tile_n) = z_frc_trd_t(tile_n) &
         &                          + glob_sum( 'diahsb', ( risf_cav_tsc(:,:,jp_tem) + risf_par_tsc(:,:,jp_tem) ) * surf(:,:) )
      !                    ! Add penetrative solar radiation
      IF( ln_traqsr )   z_frc_trd_t(tile_n) = z_frc_trd_t(tile_n) + r1_rho0_rcp * glob_sum( 'diahsb', qsr (:,:) * surf(:,:) )
      !                    ! Add geothermal heat flux
      IF( ln_trabbc )   z_frc_trd_t(tile_n) = z_frc_trd_t(tile_n) + glob_sum( 'diahsb', qgh_trd0(:,:) * surf(:,:) )
      !
      IF( ln_linssh ) THEN
         IF( ln_isfcav ) THEN
            DO jj = 1, jpj
               DO ji = 1, jpi
                  z2d0(ji,jj) = surf(ji,jj) * ww(ji,jj,mikt(ji,jj)) * ts(ji,jj,mikt(ji,jj),jp_tem,Kbb)
                  z2d1(ji,jj) = surf(ji,jj) * ww(ji,jj,mikt(ji,jj)) * ts(ji,jj,mikt(ji,jj),jp_sal,Kbb)
               END DO
            END DO
         ELSE
            z2d0(:,:) = surf(:,:) * ww(:,:,1) * ts(:,:,1,jp_tem,Kbb)
            z2d1(:,:) = surf(:,:) * ww(:,:,1) * ts(:,:,1,jp_sal,Kbb)
         END IF
         z_wn_trd_t(tile_n) = - glob_sum( 'diahsb', z2d0 )
         z_wn_trd_s(tile_n) = - glob_sum( 'diahsb', z2d1 )
      ENDIF

      ! Carry the accumulated forcing over from the previous tile, then add this step
      IF( istep > 1 ) THEN
         frc_v(tile_n) = frc_v(tile_b)
         frc_t(tile_n) = frc_t(tile_b)
         frc_s(tile_n) = frc_s(tile_b)
         frc_wn_t(tile_n) = frc_wn_t(tile_b)
         frc_wn_s(tile_n) = frc_wn_s(tile_b)
      END IF
      frc_v(tile_n) = frc_v(tile_n) + z_frc_trd_v(tile_n) * rn_Dt
      frc_t(tile_n) = frc_t(tile_n) + z_frc_trd_t(tile_n) * rn_Dt
      frc_s(tile_n) = frc_s(tile_n) + z_frc_trd_s(tile_n) * rn_Dt
      ! ! Advection flux through fixed surface (z=0)
      IF( ln_linssh ) THEN
         frc_wn_t(tile_n) = frc_wn_t(tile_n) + z_wn_trd_t(tile_n) * rn_Dt
         frc_wn_s(tile_n) = frc_wn_s(tile_n) + z_wn_trd_s(tile_n) * rn_Dt
      ENDIF

      ! ------------------------ !
      ! 2 -  Content variations  !
      ! ------------------------ !
      ! glob_sum_full is needed because you keep the full interior domain to compute the sum (iscpl)

      !                    ! volume variation (calculated with ssh)
      zdiff_v1(tile_n) = glob_sum_full( 'diahsb', surf(:,:)*ssh(:,:,Kmm) - surf_ini(:,:)*ssh_ini(:,:) )

      !                    ! heat & salt content variation (associated with ssh)
      IF( ln_linssh ) THEN       ! linear free surface case
         IF( ln_isfcav ) THEN          ! ISF case
            DO jj = 1, jpj
               DO ji = 1, jpi
                  z2d0(ji,jj) = surf(ji,jj) * ( ts(ji,jj,mikt(ji,jj),jp_tem,Kmm) * ssh(ji,jj,Kmm) - ssh_hc_loc_ini(ji,jj) )
                  z2d1(ji,jj) = surf(ji,jj) * ( ts(ji,jj,mikt(ji,jj),jp_sal,Kmm) * ssh(ji,jj,Kmm) - ssh_sc_loc_ini(ji,jj) )
               END DO
            END DO
         ELSE                          ! no under ice-shelf seas
            z2d0(:,:) = surf(:,:) * ( ts(:,:,1,jp_tem,Kmm) * ssh(:,:,Kmm) - ssh_hc_loc_ini(:,:) )
            z2d1(:,:) = surf(:,:) * ( ts(:,:,1,jp_sal,Kmm) * ssh(:,:,Kmm) - ssh_sc_loc_ini(:,:) )
         END IF
         z_ssh_hc(tile_n) = glob_sum_full( 'diahsb', z2d0 )
         z_ssh_sc(tile_n) = glob_sum_full( 'diahsb', z2d1 )
      ENDIF

      !                    ! 3D content variations: queued on the stream of the current tile
      str = stream(tile_n)
      istat = 0
      istat = cudaMemcpyAsync( d_e3t, e3t , jpi*jpj*jpk*jpt , str ) + istat
      istat = cudaMemcpyAsync(d_hc_loc_ini, hc_loc_ini, jpi*jpj*jpk , str ) + istat
      istat = cudaMemcpyAsync(d_sc_loc_ini, sc_loc_ini, jpi*jpj*jpk , str ) + istat
      istat = cudaMemcpyAsync( d_ts, ts , jpi*jpj*jpk*2*jpt, str ) + istat
      IF( istat /= 0 ) THEN
         CALL ctl_stop( 'dia_hsb: unable to async GPU copy H2D' ) ; RETURN
      ENDIF
      dimBlock = dim3(4,4,4)
      dimGrid = dim3( ceiling( real( jpi ) / dimBlock%x ) , ceiling( real( jpj ) / dimBlock%y ) , &
                                                             ceiling( real( jpkm1 ) / dimBlock%z ) )
      !
      CALL dia_hsb_kernel<<<dimGrid, dimBlock, 0, str>>> (d_surf , d_e3t, d_surf_ini, d_e3t_ini, &
      & d_ts, d_hc_loc_ini, d_sc_loc_ini, d_tmask, d_tmask_ini, d_zwrkv, d_zwrkh, d_zwrks, d_zwrk, jpi, jpj, jpk, jpt, Kmm)

      CALL filter_cuda<<<dimGrid, dimBlock, 0, str>>>(d_zwrkv , d_tmask_i , jpi, jpj, jpk)
      CALL filter_cuda<<<dimGrid, dimBlock, 0, str>>>(d_zwrkh , d_tmask_i , jpi, jpj, jpk)
      CALL filter_cuda<<<dimGrid, dimBlock, 0, str>>>(d_zwrks , d_tmask_i , jpi, jpj, jpk)
      CALL filter_cuda<<<dimGrid, dimBlock, 0, str>>>(d_zwrk , d_tmask_i , jpi, jpj, jpk)

      ztmpv = 0.e0
      ztmph = 0.e0
      ztmps = 0.e0
      ztmp = 0.e0

      ! Device reductions. The workspaces are allocated as (jpi,jpj,jpkm1) in dia_hsb_init,
      ! so the sums run over exactly that extent with one subscript per dimension.
      istat = 0
      !$cuf kernel do(3) <<< *, *, stream=str >>>
      do jk = 1, jpkm1
         do jj = 1, jpj
            do ji = 1, jpi
               ztmpv = ztmpv + d_zwrkv(ji,jj,jk)
            end do
         end do
      end do
      istat = cudaMemcpyAsync( h_ztmpv(tile_n) , ztmpv , 1 , str ) + istat
      !$cuf kernel do(3) <<< *, *, stream=str >>>
      do jk = 1, jpkm1
         do jj = 1, jpj
            do ji = 1, jpi
               ztmph = ztmph + d_zwrkh(ji,jj,jk)
            end do
         end do
      end do
      istat = cudaMemcpyAsync( h_ztmph(tile_n) , ztmph , 1 , str ) + istat
      !$cuf kernel do(3) <<< *, *, stream=str >>>
      do jk = 1, jpkm1
         do jj = 1, jpj
            do ji = 1, jpi
               ztmps = ztmps + d_zwrks(ji,jj,jk)
            end do
         end do
      end do
      istat = cudaMemcpyAsync( h_ztmps(tile_n) , ztmps , 1 , str ) + istat
      !$cuf kernel do(3) <<< *, *, stream=str >>>
      do jk = 1, jpkm1
         do jj = 1, jpj
            do ji = 1, jpi
               ztmp = ztmp + d_zwrk(ji,jj,jk)
            end do
         end do
      end do
      istat = cudaMemcpyAsync( h_ztmp (tile_n) , ztmp , 1 , str ) + istat
      !
      IF( istat /= 0 ) THEN
         CALL ctl_stop( 'dia_hsb: unable to async GPU copy D2H' ) ; RETURN
      ENDIF
      !
      !                    ! collect the sums queued at the previous call
      istat = cudaStreamSynchronize(stream(tile_b))
      IF( istat /= 0 ) THEN
         CALL ctl_stop( 'dia_hsb: unable to stream synchronize' ) ; RETURN
      ENDIF
      CALL hsb_collect( tile_b )
      !
      IF ( kt == nitend ) THEN   ! last step: no later call will collect the current tile
         istat = cudaStreamSynchronize(stream(tile_n))
         IF( istat /= 0 ) THEN
            CALL ctl_stop( 'dia_hsb: unable to stream synchronize' ) ; RETURN
         ENDIF
         CALL hsb_collect( tile_n )
      ENDIF

      ! ------------------------ !
      ! 3 - Drifts               !
      ! ------------------------ !
      IF( istep > 1 )      CALL hsb_drift( tile_b, kt - 1, .FALSE. )   ! diagnostics of the previous step
      IF( kt == nitend )   CALL hsb_drift( tile_n, kt    , .TRUE.  )   ! diagnostics of the last step
      !
      ! The restart holds host-side quantities of the current step only (forcing accumulators
      ! of tile_n and the initial-state arrays), so it is written with the current kt.
      IF( lrst_oce )   CALL dia_hsb_rst( kt, Kmm, tile_n, 'WRITE' )
      !
      IF( ln_timing ) CALL timing_stop('dia_hsb')
      !
   CONTAINS

      SUBROUTINE hsb_collect( ktile )
         !! Global (mpp_sum) sum of the four partial sums copied back from the device for tile ktile
         INTEGER, INTENT(in) :: ktile   ! host tile index
         !
         COMPLEX(8) :: ctmp   ! kind-8 complex carrier passed to mpp_sum
         !
         ctmp = CMPLX( h_ztmpv(ktile) , 0.e0, 8 )
         CALL mpp_sum('diahsb', ctmp )
         zdiff_v2(ktile) = REAL( ctmp, 8 )

         ctmp = CMPLX( h_ztmph(ktile) , 0.e0, 8 )
         CALL mpp_sum('diahsb', ctmp )
         zdiff_hc(ktile) = REAL( ctmp, 8 )

         ctmp = CMPLX( h_ztmps(ktile) , 0.e0, 8 )
         CALL mpp_sum('diahsb', ctmp )
         zdiff_sc(ktile) = REAL( ctmp, 8 )

         ctmp = CMPLX( h_ztmp(ktile) , 0.e0, 8 )
         CALL mpp_sum('diahsb', ctmp )
         zvol_tot(ktile) = REAL( ctmp, 8 )
      END SUBROUTINE hsb_collect

      SUBROUTINE hsb_drift( ktile, kit, ldprt )
         !! Remove the forcing contribution from the content variations of tile ktile and output the drifts
         INTEGER, INTENT(in) :: ktile   ! host tile index
         INTEGER, INTENT(in) :: kit     ! time-step index the tile refers to (used for the W/m2 conversions)
         LOGICAL, INTENT(in) :: ldprt   ! print the summary in ocean.output (non-linear free surface case only)
         !
         zdiff_v1(ktile) = zdiff_v1(ktile) - frc_v(ktile)
         IF( .NOT.ln_linssh ) zdiff_v2(ktile) = zdiff_v2(ktile) - frc_v(ktile)
         zdiff_hc(ktile) = zdiff_hc(ktile) - frc_t(ktile)
         zdiff_sc(ktile) = zdiff_sc(ktile) - frc_s(ktile)
         IF( ln_linssh ) THEN
            zdiff_hc1(ktile) = zdiff_hc (ktile) + z_ssh_hc(ktile)
            zdiff_sc1(ktile) = zdiff_sc (ktile) + z_ssh_sc(ktile)
            zerr_hc1 (ktile) = z_ssh_hc(ktile) - frc_wn_t(ktile)
            zerr_sc1 (ktile) = z_ssh_sc(ktile) - frc_wn_s(ktile)
         ENDIF
         !!gm to be added ?
         ! IF( ln_linssh ) THEN ! fixed volume, add the ssh contribution
         ! zvol_tot = zvol_tot + glob_sum( 'diahsb', surf(:,:) * sshn(:,:) )
         ! ENDIF
         !!gm end

         CALL iom_put( 'bgfrcvol' , frc_v(ktile) * 1.e-9 ) ! vol - surface forcing (km3)
         CALL iom_put( 'bgfrctem' , frc_t(ktile) * rho0 * rcp * 1.e-20 ) ! hc - surface forcing (1.e20 J)
         CALL iom_put( 'bgfrchfx' , frc_t(ktile) * rho0 * rcp / & ! hc - surface forcing (W/m2)
            & ( surf_tot * kit * rn_Dt ) )
         CALL iom_put( 'bgfrcsal' , frc_s(ktile) * 1.e-9 ) ! sc - surface forcing (psu*km3)
         IF( .NOT. ln_linssh ) THEN
            CALL iom_put( 'bgtemper' , zdiff_hc(ktile) / zvol_tot(ktile) ) ! Temperature drift (C)
            CALL iom_put( 'bgsaline' , zdiff_sc(ktile) / zvol_tot(ktile) ) ! Salinity drift (PSU)
            CALL iom_put( 'bgheatco' , zdiff_hc(ktile) * 1.e-20 * rho0 * rcp ) ! Heat content drift (1.e20 J)
            CALL iom_put( 'bgheatfx' , zdiff_hc(ktile) * rho0 * rcp / & ! Heat flux drift (W/m2)
               & ( surf_tot * kit * rn_Dt ) )
            CALL iom_put( 'bgsaltco' , zdiff_sc(ktile) * 1.e-9 ) ! Salt content drift (psu*km3)
            CALL iom_put( 'bgvolssh' , zdiff_v1(ktile) * 1.e-9 ) ! volume ssh drift (km3)
            CALL iom_put( 'bgvole3t' , zdiff_v2(ktile) * 1.e-9 ) ! volume e3t drift (km3)
            !
            IF( ldprt .AND. lwp ) THEN
               WRITE(numout,*)
               WRITE(numout,*) 'dia_hsb : last time step hsb diagnostics: at it= ', kit,' date= ', ndastp
               WRITE(numout,*) '~~~~~~~'
               WRITE(numout,*) '   Temperature drift = ', zdiff_hc(ktile) / zvol_tot(ktile), ' C'
               WRITE(numout,*) '   Salinity  drift   = ', zdiff_sc(ktile) / zvol_tot(ktile), ' PSU'
               WRITE(numout,*) '   volume ssh  drift = ', zdiff_v1(ktile) * 1.e-9 , ' km^3'
               WRITE(numout,*) '   volume e3t  drift = ', zdiff_v2(ktile) * 1.e-9 , ' km^3'
            ENDIF
            !
         ELSE
            CALL iom_put( 'bgtemper' , zdiff_hc1(ktile) / zvol_tot(ktile)) ! Heat content drift (C)
            CALL iom_put( 'bgsaline' , zdiff_sc1(ktile) / zvol_tot(ktile)) ! Salt content drift (PSU)
            CALL iom_put( 'bgheatco' , zdiff_hc1(ktile) * 1.e-20 * rho0 * rcp ) ! Heat content drift (1.e20 J)
            CALL iom_put( 'bgheatfx' , zdiff_hc1(ktile) * rho0 * rcp / & ! Heat flux drift (W/m2)
               & ( surf_tot * kit * rn_Dt ) )
            CALL iom_put( 'bgsaltco' , zdiff_sc1(ktile) * 1.e-9 ) ! Salt content drift (psu*km3)
            CALL iom_put( 'bgvolssh' , zdiff_v1(ktile) * 1.e-9 ) ! volume ssh drift (km3)
            CALL iom_put( 'bgmistem' , zerr_hc1(ktile) / zvol_tot(ktile) ) ! hc - error due to free surface (C)
            CALL iom_put( 'bgmissal' , zerr_sc1(ktile) / zvol_tot(ktile) ) ! sc - error due to free surface (psu)
         ENDIF
         !
      END SUBROUTINE hsb_drift

   END SUBROUTINE dia_hsb


   SUBROUTINE dia_hsb_rst( kt, Kmm, tile, cdrw )
      !!---------------------------------------------------------------------
      !!                   ***  ROUTINE dia_hsb_rst  ***
      !!
      !! ** Purpose : Read or write DIA file in restart file
      !!
      !! ** Method  : use of IOM library
      !!              'READ'  : fill the forcing accumulators of entry "tile" and the
      !!                        initial-state arrays, either from the restart file
      !!                        (ln_rstart) or from the current model state, then copy
      !!                        the initial-state arrays to their device counterparts.
      !!              'WRITE' : put the same quantities in the restart file.
      !!              Any other value of cdrw does nothing.
      !!----------------------------------------------------------------------
      INTEGER         , INTENT(in) ::   kt     ! ocean time-step
      INTEGER         , INTENT(in) ::   Kmm    ! ocean time level index
      CHARACTER(len=*), INTENT(in) ::   cdrw   ! "READ"/"WRITE" flag
      INTEGER , INTENT(in)         :: tile     ! host tile
      !
      INTEGER ::   ji, jj, jk   ! dummy loop indices
      !!----------------------------------------------------------------------
      !
      IF( TRIM(cdrw) == 'READ' ) THEN        ! Read/initialise
         IF( ln_rstart ) THEN                   !* Read the restart file
            !
            IF(lwp) WRITE(numout,*)
            IF(lwp) WRITE(numout,*) '   dia_hsb_rst : read hsb restart at it= ', kt,' date= ', ndastp
            IF(lwp) WRITE(numout,*)
            CALL iom_get( numror, 'frc_v', frc_v(tile) )
            CALL iom_get( numror, 'frc_t', frc_t(tile) )
            CALL iom_get( numror, 'frc_s', frc_s(tile) )
            IF( ln_linssh ) THEN
               CALL iom_get( numror, 'frc_wn_t', frc_wn_t(tile) )
               CALL iom_get( numror, 'frc_wn_s', frc_wn_s(tile) )
            ENDIF
            CALL iom_get( numror, jpdom_auto, 'surf_ini'  , surf_ini   ) ! ice sheet coupling
            CALL iom_get( numror, jpdom_auto, 'ssh_ini'   , ssh_ini    )
            CALL iom_get( numror, jpdom_auto, 'e3t_ini'   , e3t_ini    )
            CALL iom_get( numror, jpdom_auto, 'tmask_ini' , tmask_ini  )
            CALL iom_get( numror, jpdom_auto, 'hc_loc_ini', hc_loc_ini )
            CALL iom_get( numror, jpdom_auto, 'sc_loc_ini', sc_loc_ini )
            IF( ln_linssh ) THEN
               CALL iom_get( numror, jpdom_auto, 'ssh_hc_loc_ini', ssh_hc_loc_ini )
               CALL iom_get( numror, jpdom_auto, 'ssh_sc_loc_ini', ssh_sc_loc_ini )
            ENDIF
         ELSE
            IF(lwp) WRITE(numout,*)
            IF(lwp) WRITE(numout,*) '   dia_hsb_rst : initialise hsb at initial state '
            IF(lwp) WRITE(numout,*)
            surf_ini(:,:) = e1e2t(:,:) * tmask_i(:,:)         ! initial ocean surface
            ssh_ini(:,:) = ssh(:,:,Kmm)                          ! initial ssh
            DO jk = 1, jpk
               ! if ice sheet/ocean coupling, need to mask ini variables here (mask could change at the next NEMO instance).
               e3t_ini   (:,:,jk) = e3t(:,:,jk,Kmm)                      * tmask(:,:,jk)  ! initial vertical scale factors
               tmask_ini (:,:,jk) = tmask(:,:,jk)                                       ! initial mask
               hc_loc_ini(:,:,jk) = ts(:,:,jk,jp_tem,Kmm) * e3t(:,:,jk,Kmm) * tmask(:,:,jk)  ! initial heat content
               sc_loc_ini(:,:,jk) = ts(:,:,jk,jp_sal,Kmm) * e3t(:,:,jk,Kmm) * tmask(:,:,jk)  ! initial salt content
            END DO
            frc_v(tile) = 0._wp                                           ! volume       trend due to forcing
            frc_t(tile) = 0._wp                                           ! heat content   -    -   - -
            frc_s(tile) = 0._wp                                           ! salt content   -    -   - -
            IF( ln_linssh ) THEN
               IF( ln_isfcav ) THEN
                  DO jj = 1, jpj
                     DO ji = 1, jpi
                        ssh_hc_loc_ini(ji,jj) = ts(ji,jj,mikt(ji,jj),jp_tem,Kmm) * ssh(ji,jj,Kmm) ! initial heat content in ssh
                        ssh_sc_loc_ini(ji,jj) = ts(ji,jj,mikt(ji,jj),jp_sal,Kmm) * ssh(ji,jj,Kmm) ! initial salt content in ssh
                     END DO
                  END DO
               ELSE
                  ssh_hc_loc_ini(:,:) = ts(:,:,1,jp_tem,Kmm) * ssh(:,:,Kmm) ! initial heat content in ssh
                  ssh_sc_loc_ini(:,:) = ts(:,:,1,jp_sal,Kmm) * ssh(:,:,Kmm) ! initial salt content in ssh
               END IF
               frc_wn_t(tile) = 0._wp ! initial heat content misfit due to free surface
               frc_wn_s(tile) = 0._wp ! initial salt content misfit due to free surface
            ENDIF
         ENDIF
         !
         ! Refresh the device copies whichever way the host arrays were filled
         ! (restart file or initial state), so the GPU never keeps stale values.
         d_surf_ini = surf_ini
         d_e3t_ini = e3t_ini
         d_tmask_ini = tmask_ini
         d_hc_loc_ini = hc_loc_ini
         d_sc_loc_ini = sc_loc_ini
         !
      ELSEIF( TRIM(cdrw) == 'WRITE' ) THEN   ! Create restart file
         !                                   ! -------------------
         IF(lwp) WRITE(numout,*)
         IF(lwp) WRITE(numout,*) '   dia_hsb_rst : write restart at it= ', kt,' date= ', ndastp
         IF(lwp) WRITE(numout,*)
         !
         CALL iom_rstput( kt, nitrst, numrow, 'frc_v', frc_v(tile) )
         CALL iom_rstput( kt, nitrst, numrow, 'frc_t', frc_t(tile) )
         CALL iom_rstput( kt, nitrst, numrow, 'frc_s', frc_s(tile) )
         IF( ln_linssh ) THEN
            CALL iom_rstput( kt, nitrst, numrow, 'frc_wn_t', frc_wn_t(tile) )
            CALL iom_rstput( kt, nitrst, numrow, 'frc_wn_s', frc_wn_s(tile) )
         ENDIF
         CALL iom_rstput( kt, nitrst, numrow, 'surf_ini'  , surf_ini   ) ! ice sheet coupling
         CALL iom_rstput( kt, nitrst, numrow, 'ssh_ini'   , ssh_ini    )
         CALL iom_rstput( kt, nitrst, numrow, 'e3t_ini'   , e3t_ini    )
         CALL iom_rstput( kt, nitrst, numrow, 'tmask_ini' , tmask_ini  )
         CALL iom_rstput( kt, nitrst, numrow, 'hc_loc_ini', hc_loc_ini )
         CALL iom_rstput( kt, nitrst, numrow, 'sc_loc_ini', sc_loc_ini )
         IF( ln_linssh ) THEN
            CALL iom_rstput( kt, nitrst, numrow, 'ssh_hc_loc_ini', ssh_hc_loc_ini )
            CALL iom_rstput( kt, nitrst, numrow, 'ssh_sc_loc_ini', ssh_sc_loc_ini )
         ENDIF
         !
      ENDIF
      !
   END SUBROUTINE dia_hsb_rst


   SUBROUTINE dia_hsb_init( Kmm )
      !!---------------------------------------------------------------------------
      !!                  ***  ROUTINE dia_hsb_init  ***
      !!
      !! ** Purpose: Initialization for the heat salt volume budgets
      !!
      !! ** Method : Compute initial heat content, salt content and volume
      !!
      !! ** Action : - Read namelist namhsb; return at once when ln_diahsb is false
      !!             - Allocate host and device arrays, create the CUDA streams
      !!             - Compute initial heat content, salt content and volume
      !!             - Initialize forcing trends
      !!---------------------------------------------------------------------------
      INTEGER, INTENT(in) :: Kmm ! time level index
      !
      INTEGER ::   ierror, ios   ! local integer
      INTEGER ::   i, istat      ! local integer
      !!
      NAMELIST/namhsb/ ln_diahsb
      !!----------------------------------------------------------------------
      !
      IF(lwp) THEN
         WRITE(numout,*)
         WRITE(numout,*) 'dia_hsb_init : heat and salt budgets diagnostics'
         WRITE(numout,*) '~~~~~~~~~~~~ '
      ENDIF
      READ  ( numnam_ref, namhsb, IOSTAT = ios, ERR = 901)
901   IF( ios /= 0 )   CALL ctl_nam ( ios , 'namhsb in reference namelist' )
      READ  ( numnam_cfg, namhsb, IOSTAT = ios, ERR = 902 )
902   IF( ios >  0 )   CALL ctl_nam ( ios , 'namhsb in configuration namelist' )
      IF(lwm) WRITE( numond, namhsb )

      IF(lwp) THEN
         WRITE(numout,*) '   Namelist  namhsb :'
         WRITE(numout,*) '      check the heat and salt budgets (T) or not (F)       ln_diahsb = ', ln_diahsb
      ENDIF
      !
      IF( .NOT. ln_diahsb )   RETURN

      ! ------------------- !
      ! 1 - Allocate memory !
      ! ------------------- !

      CALL setdevice()
      ! Device data associated with PUBLIC arrays
      ALLOCATE(d_e3t (jpi,jpj,jpk,jpt) ) !
      ALLOCATE(d_tmask (jpi,jpj,jpk) ) !
      ALLOCATE(d_tmask_ini (jpi,jpj,jpk) ) !
      ALLOCATE(d_tmask_i (jpi,jpj) ) !
      ! last extent is the number of time levels (jpt), matching the jpi*jpj*jpk*2*jpt copy done in dia_hsb
      ALLOCATE(d_ts (jpi,jpj,jpk,2,jpt) ) !
      ! Device data associated with LOCAL/DEVICE arrays
      ALLOCATE(d_surf (jpi,jpj) ) !
      ALLOCATE(d_surf_ini (jpi,jpj) ) !
      ALLOCATE(d_hc_loc_ini (jpi,jpj,jpk) ) !
      ALLOCATE(d_sc_loc_ini (jpi,jpj,jpk) ) !
      ALLOCATE(d_e3t_ini (jpi,jpj,jpk) ) !
      ALLOCATE(d_zwrkv (jpi,jpj,jpkm1) ) !
      ALLOCATE(d_zwrkh (jpi,jpj,jpkm1) ) !
      ALLOCATE(d_zwrks (jpi,jpj,jpkm1) ) !
      ALLOCATE(d_zwrk (jpi,jpj,jpkm1) ) !
      ALLOCATE(h_ztmpv(2),h_ztmph(2),h_ztmps(2),h_ztmp(2)) !

      DO i = 1, nstreams !Create Streams
         istat = cudaStreamCreate(stream(i))
         IF( istat /= 0 ) THEN
            CALL ctl_stop( 'dia_hsb_init: error in Stream creation' ) ; RETURN
         ENDIF
      END DO
      !
      ! Register ts and e3t as pinned host memory; both statuses are accumulated
      ! so that a failure of the first call is not hidden by the second.
      istat = cudaHostRegister(C_LOC(ts ), sizeof(ts ), cudaHostRegisterMapped)
      istat = cudaHostRegister(C_LOC(e3t), sizeof(e3t), cudaHostRegisterMapped) + istat
      IF( istat /= 0 ) THEN
         CALL ctl_stop( 'dia_hsb_init: unable to pin host memory to GPU' ) ; RETURN
      ENDIF

      ALLOCATE( hc_loc_ini(jpi,jpj,jpk), sc_loc_ini(jpi,jpj,jpk), surf_ini(jpi,jpj), &
         &      e3t_ini(jpi,jpj,jpk), surf(jpi,jpj),  ssh_ini(jpi,jpj), tmask_ini(jpi,jpj,jpk),STAT=ierror  )
      IF( ierror > 0 ) THEN
         CALL ctl_stop( 'dia_hsb_init: unable to allocate hc_loc_ini' )   ;   RETURN
      ENDIF

      IF( ln_linssh )   ALLOCATE( ssh_hc_loc_ini(jpi,jpj), ssh_sc_loc_ini(jpi,jpj),STAT=ierror )
      IF( ierror > 0 ) THEN
         CALL ctl_stop( 'dia_hsb: unable to allocate ssh_hc_loc_ini' )   ;   RETURN
      ENDIF

      ! ----------------------------------------------- !
      ! 2 - Time independent variables and file opening !
      ! ----------------------------------------------- !
      surf(:,:) = e1e2t(:,:) * tmask_i(:,:)               ! masked surface grid cell area
      surf_tot  = glob_sum( 'diahsb', surf(:,:) )         ! total ocean surface area

      d_surf = surf
      d_tmask = tmask
      d_tmask_i = tmask_i
      !                                                   ! host side of the device sums: start from zero
      h_ztmpv = 0._wp
      h_ztmph = 0._wp
      h_ztmps = 0._wp
      h_ztmp  = 0._wp

      IF( ln_bdy ) CALL ctl_warn( 'dia_hsb_init: heat/salt budget does not consider open boundary fluxes' )
      !
      ! ---------------------------------- !
      ! 3 - initial conservation variables !
      ! ---------------------------------- !
      CALL dia_hsb_rst( nit000, Kmm, 1, 'READ' ) !* read or initialize all required files
      !
      ! The initial-state arrays are only defined once dia_hsb_rst has returned:
      ! copy them to the device here, whether they came from the restart file or
      ! from the current model state.
      d_surf_ini = surf_ini
      d_e3t_ini = e3t_ini
      d_tmask_ini = tmask_ini
      !
   END SUBROUTINE dia_hsb_init

   !!======================================================================
END MODULE diahsb