From a4da548e830e8ad75fbdf4fd3b29033251c770ad Mon Sep 17 00:00:00 2001
From: Sibylle TECHENE <techenes@irene193.c-irene.tgcc.ccc.cea.fr>
Date: Wed, 27 Jul 2022 12:35:02 +0200
Subject: [PATCH] reduce local memory usage #80

---
 src/OCE/TRA/trazdf.F90 | 262 ++++++++++++++++++++++++-----------------
 1 file changed, 156 insertions(+), 106 deletions(-)

diff --git a/src/OCE/TRA/trazdf.F90 b/src/OCE/TRA/trazdf.F90
index 8280d3181..7123aeba3 100644
--- a/src/OCE/TRA/trazdf.F90
+++ b/src/OCE/TRA/trazdf.F90
@@ -6,6 +6,7 @@ MODULE trazdf
    !! History :  1.0  !  2005-11  (G. Madec)  Original code
    !!            3.0  !  2008-01  (C. Ethe, G. Madec)  merge TRC-TRA
    !!            4.0  !  2017-06  (G. Madec)  remove explict time-stepping option
+   !!            4.5  !  2022-06  (G. Madec)  refactoring to reduce memory usage (j-k-i loops)
    !!----------------------------------------------------------------------
 
    !!----------------------------------------------------------------------
@@ -22,7 +23,7 @@ MODULE trazdf
    USE ldfslp         ! lateral diffusion: iso-neutral slope
    USE trd_oce        ! trends: ocean variables
    USE trdtra         ! trends: tracer trend manager
-   USE eosbn2, ONLY: ln_SEOS, rn_b0
+   USE eosbn2   , ONLY: ln_SEOS, rn_b0
    !
    USE in_out_manager ! I/O manager
    USE prtctl         ! Print control
@@ -77,7 +78,7 @@ CONTAINS
       ENDIF
       !
       !                                      !* compute lateral mixing trend and add it to the general trend
-      CALL tra_zdf_imp( kt, nit000, 'TRA', rDt, Kbb, Kmm, Krhs, pts, Kaa, jpts )
+      CALL tra_zdf_imp( 'TRA', Kbb, Kmm, Krhs, pts, Kaa, jpts )
 
 !!gm WHY here !   and I don't like that !
       ! DRAKKAR SSS control {
@@ -116,7 +117,7 @@ CONTAINS
    END SUBROUTINE tra_zdf
 
 
-   SUBROUTINE tra_zdf_imp( kt, kit000, cdtype, p2dt, Kbb, Kmm, Krhs, pt, Kaa, kjpt )
+   SUBROUTINE tra_zdf_imp( cdtype, Kbb, Kmm, Krhs, pt, Kaa, kjpt )
       !!----------------------------------------------------------------------
       !!                  ***  ROUTINE tra_zdf_imp  ***
       !!
@@ -136,128 +137,177 @@ CONTAINS
       !!
       !! ** Action  : - pt(:,:,:,:,Kaa)  becomes the after tracer
       !!---------------------------------------------------------------------
-      INTEGER                                  , INTENT(in   ) ::   kt       ! ocean time-step index
       INTEGER                                  , INTENT(in   ) ::   Kbb, Kmm, Krhs, Kaa  ! ocean time level indices
-      INTEGER                                  , INTENT(in   ) ::   kit000   ! first time step index
       CHARACTER(len=3)                         , INTENT(in   ) ::   cdtype   ! =TRA or TRC (tracer indicator)
       INTEGER                                  , INTENT(in   ) ::   kjpt     ! number of tracers
-      REAL(wp)                                 , INTENT(in   ) ::   p2dt     ! tracer time-step
       REAL(wp), DIMENSION(jpi,jpj,jpk,kjpt,jpt), INTENT(inout) ::   pt       ! tracers and RHS of tracer equation
       !
       INTEGER  ::  ji, jj, jk, jn   ! dummy loop indices
       REAL(wp) ::  zrhs, zzwi, zzws ! local scalars
-      REAL(wp), DIMENSION(A2D(nn_hls),jpk) ::  zwi, zwt, zwd, zws
+      REAL(wp), DIMENSION(A1Di(0),jpk) ::  zwi, zwt, zwd, zws
       !!---------------------------------------------------------------------
       !
-      !                                               ! ============= !
-      DO jn = 1, kjpt                                 !  tracer loop  !
-         !                                            ! ============= !
-         !  Matrix construction
-         ! --------------------
-         ! Build matrix if temperature or salinity (only in double diffusion case) or first passive tracer
-         !
-         IF(  ( cdtype == 'TRA' .AND. ( jn == jp_tem .OR. ( jn == jp_sal .AND. ln_zdfddm ) ) ) .OR.   &
-            & ( cdtype == 'TRC' .AND. jn == 1 )  )  THEN
+      !                                               ! ================= !
+      DO_1Dj( 0, 0 )                                  !  i-k slices loop  !
+         !                                            ! ================= !
+         DO jn = 1, kjpt                              !    tracer loop    !
+            !                                         ! ================= !
             !
-            ! vertical mixing coef.: avt for temperature, avs for salinity and passive tracers
-            IF( cdtype == 'TRA' .AND. jn == jp_tem ) THEN
-               DO_3D( 1, 1, 1, 1, 2, jpk )
-                  zwt(ji,jj,jk) = avt(ji,jj,jk)
-               END_3D
-            ELSE
-               DO_3D( 1, 1, 1, 1, 2, jpk )
-                  zwt(ji,jj,jk) = avs(ji,jj,jk)
-               END_3D
-            ENDIF
-            zwt(:,:,1) = 0._wp
+            !  Matrix construction
+            ! --------------------
+            ! Build matrix if temperature or salinity (only in double diffusion case) or first passive tracer
             !
-            IF( l_ldfslp ) THEN            ! isoneutral diffusion: add the contribution
-               IF( ln_traldf_msc  ) THEN     ! MSC iso-neutral operator
-                  DO_3D( 0, 0, 0, 0, 2, jpkm1 )
-                     zwt(ji,jj,jk) = zwt(ji,jj,jk) + akz(ji,jj,jk)
-                  END_3D
-               ELSE                          ! standard or triad iso-neutral operator
-                  DO_3D( 0, 0, 0, 0, 2, jpkm1 )
-                     zwt(ji,jj,jk) = zwt(ji,jj,jk) + ah_wslp2(ji,jj,jk)
-                  END_3D
+            IF(  ( cdtype == 'TRA' .AND. ( jn == jp_tem .OR. ( jn == jp_sal .AND. ln_zdfddm ) ) ) .OR.   &
+               & ( cdtype == 'TRC' .AND. jn == 1 )  )  THEN
+               !
+               ! vertical mixing coef.: avt for temperature, avs for salinity and passive tracers
+               !
+               IF( cdtype == 'TRA' .AND. jn == jp_tem ) THEN     ! use avt  for temperature
+                  !
+                  IF( l_ldfslp ) THEN            ! use avt + isoneutral diffusion contribution
+                     IF( ln_traldf_msc  ) THEN        ! MSC iso-neutral operator
+                        DO_2Dik( 0, 0,   2, jpk, 1 )
+                           zwt(ji,jk) = avt(ji,jj,jk) + akz(ji,jj,jk)
+                        END_2D
+                     ELSE                             ! standard or triad iso-neutral operator
+                        DO_2Dik( 0, 0,   2, jpk, 1 )
+                           zwt(ji,jk) = avt(ji,jj,jk) + ah_wslp2(ji,jj,jk)
+                        END_2D
+                     ENDIF
+                  ELSE                          ! use avt only
+                     DO_2Dik( 0, 0,   2, jpk, 1 )
+                        zwt(ji,jk) = avt(ji,jj,jk)
+                     END_2D
+                  ENDIF
+                  !
+               ELSE                                               ! use avs for salinty or passive tracers
+                  !
+                  IF( l_ldfslp ) THEN            ! use avs + isoneutral diffusion contribution
+                     IF( ln_traldf_msc  ) THEN        ! MSC iso-neutral operator
+                        DO_2Dik( 0, 0,   2, jpk, 1 )
+                           zwt(ji,jk) = avs(ji,jj,jk) + akz(ji,jj,jk)
+                        END_2D
+                     ELSE                             ! standard or triad iso-neutral operator
+                        DO_2Dik( 0, 0,   2, jpk, 1 )
+                           zwt(ji,jk) = avs(ji,jj,jk) + ah_wslp2(ji,jj,jk)
+                        END_2D
+                     ENDIF
+                  ELSE                          ! 
+                     DO_2Dik( 0, 0,   2, jpk, 1 )
+                        zwt(ji,jk) = avs(ji,jj,jk)
+                     END_2D
+                  ENDIF
                ENDIF
+               zwt(:,1) = 0._wp
+               !
+               ! Diagonal, lower (i), upper (s)  (including the bottom boundary condition since avt is masked)
+               IF( ln_zad_Aimp ) THEN         ! Adaptive implicit vertical advection
+                  DO_2Dik( 0, 0,   1, jpkm1, 1 )
+                     zzwi = - rDt * zwt(ji,jk  ) / e3w(ji,jj,jk  ,Kmm)
+                     zzws = - rDt * zwt(ji,jk+1) / e3w(ji,jj,jk+1,Kmm)
+                     zwd(ji,jk) = e3t(ji,jj,jk,Kaa) - zzwi - zzws   &
+                        &              + rDt * ( MAX( wi(ji,jj,jk  ) , 0._wp ) &
+                        &                      - MIN( wi(ji,jj,jk+1) , 0._wp ) )
+                     zwi(ji,jk) = zzwi + rDt *   MIN( wi(ji,jj,jk  ) , 0._wp )
+                     zws(ji,jk) = zzws - rDt *   MAX( wi(ji,jj,jk+1) , 0._wp )
+                  END_2D
+               ELSE
+                  DO_2Dik( 0, 0,   1, jpkm1, 1 )
+                     zwi(ji,jk) = - rDt * zwt(ji,jk  ) / e3w(ji,jj,jk,Kmm)
+                     zws(ji,jk) = - rDt * zwt(ji,jk+1) / e3w(ji,jj,jk+1,Kmm)
+                     zwd(ji,jk) = e3t(ji,jj,jk,Kaa) - zwi(ji,jk) - zws(ji,jk)
+                  END_2D
+               ENDIF
+               !
+!!gm  BUG?? : if edmfm is equivalent to a w  ==>>>   just add +/-  rDt * edmfm(ji,jj,jk+1/jk  )
+!!            but edmfm is at t-point !!!!   crazy???  why not keep it at w-point????
+               !
+               IF( ln_zdfmfc ) THEN    ! add upward Mass Flux in the matrix
+                  DO_2Dik( 0, 0,   1, jpkm1, 1 )
+                     zws(ji,jk) = zws(ji,jk) + e3t(ji,jj,jk,Kaa) * rDt * edmfm(ji,jj,jk+1) / e3w(ji,jj,jk+1,Kmm)
+                     zwd(ji,jk) = zwd(ji,jk) - e3t(ji,jj,jk,Kaa) * rDt * edmfm(ji,jj,jk  ) / e3w(ji,jj,jk+1,Kmm)
+                  END_2D
+               ENDIF
+!       DO_3D( 0, 0, 0, 0, 1, jpkm1 )
+!          edmfa(ji,jj,jk) =  0._wp
+!          edmfb(ji,jj,jk) = -edmfm(ji,jj,jk  ) / e3w(ji,jj,jk+1,Kmm)
+!          edmfc(ji,jj,jk) =  edmfm(ji,jj,jk+1) / e3w(ji,jj,jk+1,Kmm)
+!       END_3D
+!!gm    BUG :  level jpk never used in the inversion
+!       DO_2D( 0, 0, 0, 0 )
+!          edmfa(ji,jj,jpk)   = -edmfm(ji,jj,jpk-1) / e3w(ji,jj,jpk,Kmm)
+!          edmfb(ji,jj,jpk)   =  edmfm(ji,jj,jpk  ) / e3w(ji,jj,jpk,Kmm)
+!          edmfc(ji,jj,jpk)   =  0._wp
+!       END_2D
+!!
+!!gm   BUG ???   below  e3t_Kmm  should be used ?  
+!!               or even no multiplication by e3t unless there is a bug in wi calculation
+!!
+!                   DO_3D( 0, 0, 0, 0, 1, jpkm1 )
+!!gm edmfa = 0._wp except at jpk which is not used  ==>>  zdiagi update is useless !
+!                      zdiagi(ji,jj,jk) = zdiagi(ji,jj,jk) + e3t(ji,jj,jk,Kaa) * p2dt *edmfa(ji,jj,jk)
+!                      zdiags(ji,jj,jk) = zdiags(ji,jj,jk) + e3t(ji,jj,jk,Kaa) * p2dt *edmfc(ji,jj,jk)
+!                      zdiagd(ji,jj,jk) = zdiagd(ji,jj,jk) + e3t(ji,jj,jk,Kaa) * p2dt *edmfb(ji,jj,jk)
+!                   END_3D
+!!gm                  CALL diag_mfc( zwi, zwd, zws, rDt, Kaa )
+!!gm   SUBROUTINE diag_mfc( zdiagi, zdiagd, zdiags, p2dt, Kaa )
+               !
+               !! Matrix inversion from the first level
+               !!----------------------------------------------------------------------
+               !   solve m.x = y  where m is a tri diagonal matrix ( jpk*jpk )
+               !
+               !        ( zwd1 zws1   0    0    0  )( zwx1 ) ( zwy1 )
+               !        ( zwi2 zwd2 zws2   0    0  )( zwx2 ) ( zwy2 )
+               !        (  0   zwi3 zwd3 zws3   0  )( zwx3 )=( zwy3 )
+               !        (        ...               )( ...  ) ( ...  )
+               !        (  0    0    0   zwik zwdk )( zwxk ) ( zwyk )
+               !
+               !   m is decomposed in the product of an upper and lower triangular matrix.
+               !   The 3 diagonal terms are in 3d arrays: zwd, zws, zwi.
+               !   Suffices i,s and d indicate "inferior" (below diagonal), diagonal
+               !   and "superior" (above diagonal) components of the tridiagonal system.
+               !   The solution will be in the 4d array pta.
+               !   The 3d array zwt is used as a work space array.
+               !   En route to the solution pt(:,:,:,:,Kaa) is used a to evaluate the rhs and then
+               !   used as a work space array: its value is modified.
+               !
+               DO_1Di( 0, 0 )          !* 1st recurrence:   Tk = Dk - Ik Sk-1 / Tk-1   (increasing k) ! done one for all passive tracers (so included in the IF instruction)
+                  zwt(ji,1) = zwd(ji,1)
+               END_1D
+               DO_2Dik( 0, 0,   2, jpkm1, 1 )
+                  zwt(ji,jk) = zwd(ji,jk) - zwi(ji,jk) * zws(ji,jk-1) / zwt(ji,jk-1)
+               END_2D
+               !
             ENDIF
             !
-            ! Diagonal, lower (i), upper (s)  (including the bottom boundary condition since avt is masked)
-            IF( ln_zad_Aimp ) THEN         ! Adaptive implicit vertical advection
-               DO_3D( 0, 0, 0, 0, 1, jpkm1 )
-                  zzwi = - p2dt * zwt(ji,jj,jk  ) / e3w(ji,jj,jk  ,Kmm)
-                  zzws = - p2dt * zwt(ji,jj,jk+1) / e3w(ji,jj,jk+1,Kmm)
-                  zwd(ji,jj,jk) = e3t(ji,jj,jk,Kaa) - zzwi - zzws   &
-                     &                 + p2dt * ( MAX( wi(ji,jj,jk  ) , 0._wp ) - MIN( wi(ji,jj,jk+1) , 0._wp ) )
-                  zwi(ji,jj,jk) = zzwi + p2dt *   MIN( wi(ji,jj,jk  ) , 0._wp )
-                  zws(ji,jj,jk) = zzws - p2dt *   MAX( wi(ji,jj,jk+1) , 0._wp )
-               END_3D
-            ELSE
-               DO_3D( 0, 0, 0, 0, 1, jpkm1 )
-                  zwi(ji,jj,jk) = - p2dt * zwt(ji,jj,jk  ) / e3w(ji,jj,jk,Kmm)
-                  zws(ji,jj,jk) = - p2dt * zwt(ji,jj,jk+1) / e3w(ji,jj,jk+1,Kmm)
-                  zwd(ji,jj,jk) = e3t(ji,jj,jk,Kaa) - zwi(ji,jj,jk) - zws(ji,jj,jk)
-               END_3D
+            IF( ln_zdfmfc ) THEN    ! add Mass Flux to the RHS 
+               DO_2Dik( 0, 0,   1, jpkm1, 1 )
+                  pt(ji,jj,jk,jn,Krhs) = pt(ji,jj,jk,jn,Krhs) + edmftra(ji,jj,jk,jn)
+               END_2D
+!!gm               CALL rhs_mfc( pt(:,:,:,jn,Krhs), jn )
             ENDIF
             !
-            ! Modification of diagonal to add MF scheme
-            IF ( ln_zdfmfc ) THEN
-               CALL diag_mfc( zwi, zwd, zws, p2dt, Kaa )
-            END IF
-            !
-            !! Matrix inversion from the first level
-            !!----------------------------------------------------------------------
-            !   solve m.x = y  where m is a tri diagonal matrix ( jpk*jpk )
-            !
-            !        ( zwd1 zws1   0    0    0  )( zwx1 ) ( zwy1 )
-            !        ( zwi2 zwd2 zws2   0    0  )( zwx2 ) ( zwy2 )
-            !        (  0   zwi3 zwd3 zws3   0  )( zwx3 )=( zwy3 )
-            !        (        ...               )( ...  ) ( ...  )
-            !        (  0    0    0   zwik zwdk )( zwxk ) ( zwyk )
-            !
-            !   m is decomposed in the product of an upper and lower triangular matrix.
-            !   The 3 diagonal terms are in 3d arrays: zwd, zws, zwi.
-            !   Suffices i,s and d indicate "inferior" (below diagonal), diagonal
-            !   and "superior" (above diagonal) components of the tridiagonal system.
-            !   The solution will be in the 4d array pta.
-            !   The 3d array zwt is used as a work space array.
-            !   En route to the solution pt(:,:,:,:,Kaa) is used a to evaluate the rhs and then
-            !   used as a work space array: its value is modified.
-            !
-            DO_2D( 0, 0, 0, 0 )      !* 1st recurrence:   Tk = Dk - Ik Sk-1 / Tk-1   (increasing k) ! done one for all passive tracers (so included in the IF instruction)
-               zwt(ji,jj,1) = zwd(ji,jj,1)
+            DO_1Di( 0, 0 )             !* 2nd recurrence:    Zk = Yk - Ik / Tk-1  Zk-1
+               pt(ji,jj,1,jn,Kaa) =       e3t(ji,jj,1,Kbb) * pt(ji,jj,1,jn,Kbb )    &
+                  &               + rDt * e3t(ji,jj,1,Kmm) * pt(ji,jj,1,jn,Krhs)
+            END_1D
+            DO_2Dik( 0, 0,   2, jpkm1, 1 )
+               zrhs =       e3t(ji,jj,jk,Kbb) * pt(ji,jj,jk,jn,Kbb )   &
+                  & + rDt * e3t(ji,jj,jk,Kmm) * pt(ji,jj,jk,jn,Krhs)   ! zrhs=right hand side
+               pt(ji,jj,jk,jn,Kaa) = zrhs - zwi(ji,jk) / zwt(ji,jk-1) * pt(ji,jj,jk-1,jn,Kaa)
             END_2D
-            DO_3D( 0, 0, 0, 0, 2, jpkm1 )
-               zwt(ji,jj,jk) = zwd(ji,jj,jk) - zwi(ji,jj,jk) * zws(ji,jj,jk-1) / zwt(ji,jj,jk-1)
-            END_3D
             !
-         ENDIF
-         !
-         ! Modification of rhs to add MF scheme
-         IF ( ln_zdfmfc ) THEN
-            CALL rhs_mfc( pt(:,:,:,jn,Krhs), jn )
-         END IF
-         !
-         DO_2D( 0, 0, 0, 0 )         !* 2nd recurrence:    Zk = Yk - Ik / Tk-1  Zk-1
-            pt(ji,jj,1,jn,Kaa) =        e3t(ji,jj,1,Kbb) * pt(ji,jj,1,jn,Kbb)    &
-               &               + p2dt * e3t(ji,jj,1,Kmm) * pt(ji,jj,1,jn,Krhs)
-         END_2D
-         DO_3D( 0, 0, 0, 0, 2, jpkm1 )
-            zrhs =        e3t(ji,jj,jk,Kbb) * pt(ji,jj,jk,jn,Kbb)    &
-               & + p2dt * e3t(ji,jj,jk,Kmm) * pt(ji,jj,jk,jn,Krhs)   ! zrhs=right hand side
-            pt(ji,jj,jk,jn,Kaa) = zrhs - zwi(ji,jj,jk) / zwt(ji,jj,jk-1) * pt(ji,jj,jk-1,jn,Kaa)
-         END_3D
-         !
-         DO_2D( 0, 0, 0, 0 )         !* 3d recurrence:    Xk = (Zk - Sk Xk+1 ) / Tk   (result is the after tracer)
-            pt(ji,jj,jpkm1,jn,Kaa) = pt(ji,jj,jpkm1,jn,Kaa) / zwt(ji,jj,jpkm1) * tmask(ji,jj,jpkm1)
-         END_2D
-         DO_3DS( 0, 0, 0, 0, jpk-2, 1, -1 )
-            pt(ji,jj,jk,jn,Kaa) = ( pt(ji,jj,jk,jn,Kaa) - zws(ji,jj,jk) * pt(ji,jj,jk+1,jn,Kaa) )   &
-               &             / zwt(ji,jj,jk) * tmask(ji,jj,jk)
-         END_3D
+            DO_1Di( 0, 0 )             !* 3d recurrence:    Xk = (Zk - Sk Xk+1 ) / Tk   (result is the after tracer)
+               pt(ji,jj,jpkm1,jn,Kaa) = pt(ji,jj,jpkm1,jn,Kaa) / zwt(ji,jpkm1) * tmask(ji,jj,jpkm1)
+            END_1D
+            DO_2Dik( 0, 0,   jpk-2, 1, -1 )
+               pt(ji,jj,jk,jn,Kaa) = ( pt(ji,jj,jk,jn,Kaa) - zws(ji,jk) * pt(ji,jj,jk+1,jn,Kaa) )   &
+                  &             / zwt(ji,jk) * tmask(ji,jj,jk)
+            END_2D
+            !                                         ! ================= !
+         END DO                                       !    tracer loop    !
          !                                            ! ================= !
-      END DO                                          !  end tracer loop  !
+      END_1D                                          !  i-k slices loop  !      
       !                                               ! ================= !
    END SUBROUTINE tra_zdf_imp
 
-- 
GitLab