wrf svn trunk commit r4103
[wrffire.git] / wrfv2_fire / var / da / da_minimisation / da_minimise_cg.inc
blob43bda9c46b1175a69a9ed16641515bdd790d14fa
1 subroutine da_minimise_cg(grid, config_flags,            &
2                            it, cv_size, xbx, be, iv, &
3                            j_grad_norm_target, xhat, cv, &
4                            re, y, j_cost)
6    !-------------------------------------------------------------------------
7    ! Purpose:         Main Conjugate Gradient minimisation routine 
8    !
9    ! Here 
10    !    cv   is updated in outer-loop.
11    !    xhat is the control variable in inner-loop.
12    !
13    ! Called from da_solve
14    !
15    ! History: 12/12/08 - Split J and GradJ calculations (Tom Auligne)
16    !          12/12/08 - Re-orthonormalization option   (Tom Auligne)
17    !
18    !-------------------------------------------------------------------------
20    implicit none
22    integer, intent(in)               :: it    ! external iteration.
23    integer, intent(in)               :: cv_size          ! Total cv size
24    type (xbx_type),intent(in)        :: xbx   ! Header & non-gridded vars.
25    type (be_type), intent(in)        :: be    ! background error structure.
26    type (iv_type), intent(inout)     :: iv    ! ob. increment vector.
27    real, intent(inout)               :: j_grad_norm_target ! Target norm.
28    real, intent(inout)               :: xhat(1:cv_size)  ! control variable (local).
29    real, intent(inout)               :: cv(1:cv_size)    ! control variable (local).
30    type (y_type), intent(inout)      :: re    ! residual (o-a) structure.
31    type (y_type), intent(inout)      :: y     ! y = H(x_inc) structure.
33    type (j_type), intent(out)        :: j_cost                 ! cost function
35    type(domain), intent(inout)       :: grid
36    type(grid_config_rec_type), intent(inout) :: config_flags
38    integer                           :: iter            
39    integer                           :: je_start, je_end       ! Start/end indices of Je.
40    integer                           :: cv_size_jb             ! end indices of Jb.
41    integer                           :: mz(7)
42    real                              :: fhat(1:cv_size)        ! cv copy.
43    real                              :: ghat(1:cv_size)        ! cv copy.
44    real                              :: ghat0(1:cv_size)       ! cv copy.
45    real                              :: phat(1:cv_size)        ! cv copy.
46    real, allocatable                 :: qhat(:,:)              ! cv copy.
47    real                              :: apdotp,step,rrmold,rrmnew,ratio 
48    real                              :: ob_grad, rrmnew_norm, gdot
49    real                              :: j_total    
51    ! Variables for Conjugate Gradient preconditioning
52    real                              :: precon(1:cv_size)      ! cv copy.
53    real                              :: g_total, g_partial, jo_partial                          
54    integer                           :: i, ii, nv, nn, istart, iend, sz(5)
55       
56    if (trace_use) call da_trace_entry("da_minimise_cg")
58    write(unit=stdout,fmt='(A)') 'Minimize cost function using CG method'
59    if (calculate_cg_cost_fn) then
60       write(unit=stdout,fmt='(A)') &
61          'For this run cost function diagnostics will be written'
62    else
63       write(unit=stdout,fmt='(A)') &
64          'For this run cost function diagnostics will not be written'
65    end if
66    write(unit=stdout,fmt=*) ' '
68    !-------------------------------------------------------------------------
69    ! [1.0] Initialization:
70    !-------------------------------------------------------------------------
71    mz = (/ be%v1%mz, be%v2%mz, be%v3%mz, be%v4%mz, be%v5%mz, be%alpha%mz, be % ne /)
72    sz = (/ be%cv%size1, be%cv%size2, be%cv%size3, be%cv%size4, be%cv%size5 /)
73    
74    call da_calculate_j(it, 0, cv_size, be % cv % size_jb, be % cv % size_je, &
75                         be % cv % size_jp, xbx, be, iv, xhat, cv, &
76                         re, y, j_cost, grid, config_flags)
78    call da_calculate_gradj(-it, 0, cv_size, be % cv % size_jb, be % cv % size_je, &
79                         be % cv % size_jp, xbx, be, iv,  xhat, cv, re, y, ghat, grid, config_flags)
80    ghat0 = ghat
81    
82    ! [1.1] Preconditioning:
83    !-----------------------
84    precon  = 1.0
85    
86    if (precondition_cg) then
87       g_total = da_dot(cv_size,ghat,ghat)
88       
89       iend    = 0
90       do nv = 1, 5
91          nn = sz(nv) / mz(nv)
92          do ii = 1, mz(nv)
93             istart     = iend + 1
94             iend       = istart + nn - 1
95             g_partial  = da_dot(nn, ghat(istart:iend), ghat(istart:iend))
96             jo_partial = j_cost%total / SUM(mz(1:5))
98             precon(istart:iend)=  1 / &
99                (1 + precondition_factor*(g_partial/g_total)/(jo_partial/j_cost%total)) 
100          end do
101       end do
102    end if
103    
104    phat  = - precon * ghat
106    rrmold = da_dot_cv(cv_size, cv_size_domain, -phat, ghat, grid, mz, use_varbc)
108    if (j_cost%total == 0.0) return
110    !if (it == 1) j_grad_norm_target = sqrt (rrmold)
111    j_grad_norm_target = sqrt (rrmold)
113    if (orthonorm_gradient) then
114       allocate(qhat(1:cv_size, 0:ntmax))
115       qhat(:,0) = ghat / rrmold
116    end if
118    write(unit=stdout,fmt='("Starting outer iteration : ",i3)') it
119    write(unit=stdout,fmt=11) j_cost%total, sqrt(rrmold), eps(it)*j_grad_norm_target
120 11 format('Starting cost function: ' ,1PD15.8,', Gradient= ',1PD15.8,/,&
121           'For this outer iteration gradient target is:       ',1PD15.8)
122    write(unit=stdout,fmt='(A)') &
123       '----------------------------------------------------------'
124    if (calculate_cg_cost_fn) then
125       write(unit=stdout,fmt='(A)') &
126          'Iter    Cost Function         Gradient             Step'
127    else
128       write(unit=stdout,fmt='(A)')'Iter      Gradient             Step'
129    end if
131    !-------------------------------------------------------------------------
132    ! [2.0] iteratively solve for minimum of cost function:
133    !-------------------------------------------------------------------------
135    do iter=1, ntmax
136       if (rrmold == 0.0) exit
138       call da_calculate_gradj(it, iter, cv_size, be%cv%size_jb, be%cv%size_je, be%cv%size_jp, &
139                               xbx, be, iv, phat, cv, re, y, fhat, grid, config_flags  )                          
140       
141       apdotp = da_dot_cv(cv_size, cv_size_domain, fhat, phat, grid, mz, use_varbc)
143       step = 0.0
144       if (apdotp .gt. 0.0) step = rrmold/apdotp
145       
146       ghat = ghat + step * fhat
147       xhat = xhat + step * phat
148       
149     ! Orthonormalize new gradient (using modified Gramm-Schmidt algorithm)
150       if (orthonorm_gradient) then
151          do i = iter-1, 0, -1
152             gdot = da_dot_cv(cv_size, cv_size_domain, ghat, qhat(:,i), grid, mz, use_varbc)
153             ghat = ghat - gdot * qhat(:,i)
154          end do
155       end if
156       
157       rrmnew = da_dot_cv (cv_size, cv_size_domain, precon*ghat, ghat, grid, &
158                           mz, use_varbc)
159                           
160       rrmnew_norm = sqrt(rrmnew)
162       if (rrmnew_norm  < eps(it) * j_grad_norm_target) exit
163       ratio = 0.0
164       if (rrmold .gt. 0.0) ratio = rrmnew/rrmold
166       if (orthonorm_gradient) qhat(:,iter) = ghat / rrmnew_norm
167       phat         = - precon * ghat       + ratio * phat
169       rrmold=rrmnew
171     ! Print Gradient (and Cost Function)
172     !-----------------------------------
173       if (print_detail_grad) then
174          call da_calculate_j(it, iter, cv_size, be % cv % size_jb, be % cv % size_je, &
175                              be % cv % size_jp, xbx, be, iv, xhat, cv, &
176                              re, y, j_cost, grid, config_flags)
177          call da_calculate_gradj(-it, iter, cv_size, be%cv%size_jb, be%cv%size_je, be%cv%size_jp, &
178                                  xbx, be, iv, xhat, cv, re, y, fhat, grid, config_flags  )                      
179          write(unit=stdout,fmt=12)iter, j_cost%total, rrmnew_norm, step
180       elseif (calculate_cg_cost_fn) then                 
181          j_total = j_cost%total + 0.5 * da_dot_cv(cv_size,cv_size_domain,ghat0,xhat,grid,mz,use_varbc)
182          write(unit=stdout,fmt=12)iter, j_total, rrmnew_norm, step               
183       else
184          write(unit=stdout,fmt=14)iter, rrmnew_norm , step
185       end if
187 12    format(i3,5x,1PD15.8,5x,1PD15.8,5x,1PD15.8)
188 14    format(i3,5x,1PD15.8,5x,1PD15.8)
189    end do
191    !-------------------------------------------------------------------------
192    ! End of the minimization of cost function
193    !-------------------------------------------------------------------------
194    iter = MIN(iter, ntmax)
195    if (orthonorm_gradient) deallocate(qhat)
196    
197    write(unit=stdout,fmt='(A)') &
198       '----------------------------------------------------------'
199    write(unit=stdout,fmt='(A)') " "
200    write(unit=stdout, &
201       fmt='("Inner iteration stopped after ",i4," iterations")') iter
202    write(unit=stdout,fmt='(A)') " "
204    call da_calculate_j(it, iter, cv_size, be % cv % size_jb, &
205          be % cv % size_je, be % cv % size_jp, xbx, be, iv, xhat, cv, &
206          re, y, j_cost,grid, config_flags)
208    call da_calculate_gradj(-it, iter, cv_size, be % cv % size_jb, be % cv % size_je, &
209                            be % cv % size_jp, xbx, be, iv,  xhat, cv, &
210                            re, y, ghat, grid, config_flags)
211         
212    rrmnew_norm = SQRT(da_dot_cv(cv_size,cv_size_domain,ghat,ghat,grid,mz,use_varbc))
214     write(unit=stdout,fmt=15) iter, j_cost%total , rrmnew_norm
215 15  format('Final: ',I3,' iter, J=',1PD15.8,', g=',1PD15.8)
216     write(unit=stdout,fmt='(A)') &
217       '----------------------------------------------------------'
219    cv = cv + xhat
221    if (trace_use) call da_trace_exit("da_minimise_cg")
223 end subroutine da_minimise_cg