update_shift Subroutine

public subroutine update_shift(iter_data, replica_pairs)

Uses

Arguments

Type | Intent | Optional | Attributes | Name
type(fcimc_iter_data), intent(in) :: iter_data
logical, intent(in) :: replica_pairs

Contents

Source Code


Source Code

    !> Update the shift (DiagSft) of every replica at the end of an update cycle.
    !>
    !> On the root process this routine:
    !>  * computes the walker growth rate of the cycle (L2 norm, two-point
    !>    "instantaneous" rate, or cycle-averaged rate, depending on input);
    !>  * fixes the trial overlap / reference population when the respective
    !>    targets are reached and uses the trial / projected energy as shift;
    !>  * otherwise handles entry to and exit from the single-particle
    !>    (constant-shift) phase and applies the damped log-growth update;
    !>  * accumulates shift averages, the diagnostic HF and instantaneous
    !>    shifts and the projected energies.
    !> The shift state is then broadcast from root to all processes, the
    !> target growth rate is cleared for replicas that left the
    !> single-particle phase, and the tau search is stopped if its end
    !> criterion is met.
    !>
    !> @param[in] iter_data      Collated iteration data (only valid on Root).
    !> @param[in] replica_pairs  Forwarded to relative_trial_numerator when the
    !>                           trial-shift option is active.
    subroutine update_shift(iter_data, replica_pairs)

        use CalcData, only: tInstGrowthRate, tL2GrowRate

        type(fcimc_iter_data), intent(in) :: iter_data
        logical, intent(in) :: replica_pairs
        integer(int64) :: tot_walkers
        logical, dimension(inum_runs) :: tReZeroShift
        real(dp), dimension(inum_runs) :: AllGrowRateRe, AllGrowRateIm
        real(dp), dimension(inum_runs)  :: AllHFGrowRate, AllWalkers
        real(dp), dimension(inum_runs) :: rel_tot_trial_numerator
        integer :: run, lb, ub
        logical, dimension(inum_runs) :: defer_update
        logical :: start_varying_shift
        character(*), parameter :: this_routine = 'update_shift'

        ! Normally we allow the shift to vary depending on the conditions
        ! tested. Sometimes we want to defer this to the next cycle...
        defer_update(:) = .false.

        ! collate_iter_data --> The values used are only valid on Root
        i_am_root : if (iProcIndex == Root) then

            ! Target total walker number over all processes (loop invariant).
            tot_walkers = int(InitWalkers, int64) * int(nNodes, int64)

            if (tL2GrowRate) then
                ! use the L2 norm to determine the growrate
                AllGrowRate(:) = norm_psi(:) / old_norm_psi(:)
                AllWalkers(:) = norm_psi(:)

            else if (tInstGrowthRate) then

                ! Calculate the growth rate simply using the two points at
                ! the beginning and the end of the update cycle.
                do run = 1, inum_runs
                    lb = min_part_type(run)
                    ub = max_part_type(run)
                    AllGrowRate(run) = (sum(iter_data%update_growth_tot(lb:ub) &
                                            + iter_data%tot_parts_old(lb:ub))) &
                                       / real(sum(iter_data%tot_parts_old(lb:ub)), dp)
                    AllWalkers(run) = (sum(iter_data%update_growth_tot(lb:ub) &
                                           + iter_data%tot_parts_old(lb:ub)))
                end do

            else

                ! Instead attempt to calculate the average growth over every
                ! iteration over the update cycle
                if (all(.not. near_zero(OldAllAvWalkersCyc))) then
                    AllGrowRate(:) = AllSumWalkersCyc(:) / real(StepsSft, dp) &
                                     / OldAllAvWalkersCyc(:)
                end if
                AllWalkers(:) = AllSumWalkersCyc(:) / real(StepsSft, dp)

            end if
            ! For complex case, obtain both Re and Im parts
#ifdef CMPLX_
            ! Default to "no growth" so that a vanishing starting population
            ! does not leave these undefined before they enter log() below.
            AllGrowRateRe(:) = 1.0_dp
            AllGrowRateIm(:) = 1.0_dp
            do run = 1, inum_runs
                lb = min_part_type(run)
                ub = max_part_type(run)
                if (.not. near_zero(iter_data%tot_parts_old(lb))) then
                    AllGrowRateRe(run) = &
                        (iter_data%update_growth_tot(lb) + iter_data%tot_parts_old(lb)) &
                        / iter_data%tot_parts_old(lb)
                end if
                if (.not. near_zero(iter_data%tot_parts_old(ub))) then
                    AllGrowRateIm(run) = &
                        (iter_data%update_growth_tot(ub) + iter_data%tot_parts_old(ub)) &
                        / iter_data%tot_parts_old(ub)
                end if
            end do
#endif
            ! If any run uses the fixtrial option, we need to add the offset to the
            ! trial numerator
            if (tTrialWavefunction .and. tTrialShift) &
                rel_tot_trial_numerator = real(relative_trial_numerator( &
                                          tot_trial_numerator, tot_trial_denom, replica_pairs), dp)

            ! Exit the single particle phase if the number of walkers exceeds
            ! the value in the input file. If particle no has fallen, re-enter
            ! it.
            tReZeroShift = .false.
            do run = 1, inum_runs
                lb = min_part_type(run)
                ub = max_part_type(run)

                if (tTrialShift .and. .not. tFixTrial(run) .and. tTrialWavefunction &
                    .and. abs(tot_trial_denom(run)) >= TrialTarget) then
                    ! When reaching target overlap with trial wavefunction, set flag to keep it fixed.
                    tFixTrial(run) = .true.

                    ! i0 so that replica indices >= 10 are printed correctly
                    write(stdout, '(a,i13,a,i0)') 'Exiting the varaible shift phase on iteration: ' &
                        , iter + PreviousCycles, ' - overlap with trial wavefunction of the following run is now fixed: ', run
                end if

                if (tFixedN0) then
                    if (.not. tSkipRef(run) .and. abs(AllHFCyc(run)) >= N0_Target) then
                        ! When reaching target N0, set flag to keep the population of reference det fixed.
                        tSkipRef(run) = .true.

                        write(stdout, '(a,i13,a,i0)') 'Exiting the fixed shift phase on iteration: ' &
                            , iter + PreviousCycles, ' - reference population of the following run is now fixed: ', run
                        ! Set these parameters because other parts of the code depends on them
                        VaryShiftIter(run) = Iter
                        iBlockingIter(run) = Iter + PreviousCycles
                        tSinglePartPhase(run) = .false.
                    end if

                    if (tSkipRef(run)) then
                        ! Use the projected energy as the shift to fix the
                        ! population of the reference det and thus reduce the
                        ! fluctuations of the projected energy.

                        ! ToDo: Make DiagSft complex
                        DiagSft(run) = real((AllENumCyc(run)) &
                                            / (AllHFCyc(run)) + proje_ref_energy_offsets(run), dp)

                        call accumulate_shift_average(run)
                    else
                        ! Keep shift equal to input till target reference population is reached.
                        DiagSft(run) = InputDiagSft(run)
                    end if

                else if (tFixTrial(run)) then
                    ! Use the trial energy as the shift to fix the
                    ! overlap with trial wavefunction and thus reduce the
                    ! fluctuations of the trial energy.

                    ! ToDo: Make DiagSft complex
                    DiagSft(run) = real((rel_tot_trial_numerator(run) / tot_trial_denom(run)) - Hii, dp)

                    call accumulate_shift_average(run)

                else ! not Fixed-N0 and not Trial-Shift
                    single_part_phase : if (tSinglePartPhase(run)) then

#ifdef CMPLX_
                        if ((sum(AllTotParts(lb:ub)) > tot_walkers) .or. &
                            (abs_sign(AllNoatHF(lb:ub)) > MaxNoatHF)) then
                            write(stdout, '(a,i13,a)') 'Exiting the single particle growth phase on iteration: ', &
                                iter + PreviousCycles, ' - Shift can now change'
                            VaryShiftIter(run) = Iter
                            iBlockingIter(run) = Iter + PreviousCycles
                            tSinglePartPhase(run) = .false.
                            ! Only this replica's target growth rate is reset,
                            ! as in the real-walker branch.
                            if (abs(TargetGrowRate(run)) > EPS) then
                                write(stdout, "(A)") "Setting target growth rate to 1."
                                TargetGrowRate(run) = 0.0_dp
                            end if

                            ! If enabled, jump the shift to the value predicted by the
                            ! projected energy -- but only if that value is finite.
                            if (tJumpShift) then
                                if (.not. isnan(real(proje_iter(run), dp)) .and. &
                                    .not. is_inf(real(proje_iter(run), dp))) then
                                    DiagSft(run) = real(proje_iter(run), dp)
                                    defer_update(run) = .true.
                                end if
                            end if
                        end if
#else
                        start_varying_shift = .false.
                        if (tLetInitialPopDie) then
                            if (AllTotParts(run) < tot_walkers) start_varying_shift = .true.
                        else if (tTargetShiftdamp) then
                            start_varying_shift = .true.
                        else
                            if ((AllTotParts(run) > tot_walkers) .or. &
                                (abs(AllNoatHF(run)) > MaxNoatHF)) start_varying_shift = .true.
                        end if

                        if (start_varying_shift) then
                            write(stdout, '(a,i13,a,i0)') 'Exiting the single particle growth phase on iteration: ' &
                                , iter + PreviousCycles, ' - Shift can now change for population', run
                            VaryShiftIter(run) = Iter
                            iBlockingIter(run) = Iter + PreviousCycles
                            tSinglePartPhase(run) = .false.
                            ! [W.D. 15.5.2017]
                            ! change equal 0 comps
                            if (abs(TargetGrowRate(run)) > EPS) then
                                write(stdout, "(A)") "Setting target growth rate to 1."
                                TargetGrowRate(run) = 0.0_dp
                            end if

                            ! If enabled, jump the shift to the value predicted by the
                            ! projected energy!
                            if (tJumpShift) then
                                DiagSft(run) = real(proje_iter(run), dp)
                                defer_update(run) = .true.
                            end if
                        end if
#endif
                    else ! .not.tSinglePartPhase(run)

#ifdef CMPLX_
                        if (abs_sign(AllNoatHF(lb:ub)) < MaxNoatHF - HFPopThresh) then
#else
                        if (abs(AllNoatHF(run)) < MaxNoatHF - HFPopThresh) then
#endif
                            write(stdout, '(a,i13,a)') 'No at HF has fallen too low - reentering the &
                                         &single particle growth phase on iteration', iter + PreviousCycles, ' - particle number &
                                         &may grow again.'
                            tSinglePartPhase(run) = .true.
                            tReZeroShift(run) = .true.
                        end if

                    end if single_part_phase

                    ! How should the shift change for the entire ensemble of walkers
                    ! over all processors. (Explicit parentheses mirror the
                    ! default .and./.or. precedence.)
                    if (.not. ((tSinglePartPhase(run) .and. near_zero(TargetGrowRate(run))) &
                               .or. defer_update(run))) then

                        ! In case we want to continue growing, TargetGrowRate > 0.0_dp
                        ! [W.D. 15.5.2017]
                        if (abs(TargetGrowRate(run)) > EPS) then
#ifdef CMPLX_
                            if (sum(AllTotParts(lb:ub)) > TargetGrowRateWalk(run)) then
#else
                            if (AllTotParts(run) > TargetGrowRateWalk(run)) then
#endif
                                if (tTargetShiftdamp) then
                                    call stop_all(this_routine, &
                                                  "Target-shiftdamp not compatible with targetgrowrate!")
                                end if
                                ! Only allow targetgrowrate to kick in once we have > TargetGrowRateWalk walkers.
                                DiagSft(run) = DiagSft(run) - (log(AllGrowRate(run) - TargetGrowRate(run)) * SftDamp) / &
                                               (Tau * StepsSft)
                                ! Same for the info shifts for complex walkers
#ifdef CMPLX_
                                DiagSftRe(run) = DiagSftRe(run) &
                                                 - (log(AllGrowRateRe(run) - TargetGrowRate(run)) * SftDamp) / &
                                                 (Tau * StepsSft)
                                DiagSftIm(run) = DiagSftIm(run) &
                                                 - (log(AllGrowRateIm(run) - TargetGrowRate(run)) * SftDamp) / &
                                                 (Tau * StepsSft)
#endif
                            end if
                        else
                            if (tShiftonHFPop) then
                                ! Calculate the shift required to keep the HF population constant

                                AllHFGrowRate(run) = abs(AllHFCyc(run) / real(StepsSft, dp)) / abs(OldAllHFCyc(run))
                                if (.not. near_zero(AllHFGrowRate(run))) then
                                    DiagSft(run) = DiagSft(run) - (log(AllHFGrowRate(run)) * SftDamp) / &
                                                   (Tau * StepsSft)
                                else
                                    call stop_all(this_routine, "Shift undefined because HF growth rate is zero. Aborting.")
                                end if
                            else if (tTargetShiftdamp) then
                                if (.not. near_zero(AllGrowRate(run)) .and. .not. near_zero(AllWalkers(run))) then
                                    DiagSft(run) = DiagSft(run) - (log(AllGrowRate(run)) * SftDamp + &
                                                                   log(AllWalkers(run) / tot_walkers) * SftDamp2) / &
                                                   (Tau * StepsSft)
                                else
                                    call stop_all(this_routine, "Shift undefined because walker growth rate is zero. Aborting.")
                                end if
                            else
                                if (.not. near_zero(AllGrowRate(run))) then
                                    DiagSft(run) = DiagSft(run) - (log(AllGrowRate(run)) * SftDamp) / &
                                                   (Tau * StepsSft)
                                else
                                    call stop_all(this_routine, "Shift undefined because walker growth rate is zero. Aborting.")
                                end if
                            end if
                        end if

                        call accumulate_shift_average(run)

                    end if
                end if !tFixedN0 or not

                ! only update the shift this way if possible
                if (abs_sign(AllNoatHF(lb:ub)) > EPS) then
                    ! Instantaneous 'shifts' from the relative change of the HF
                    ! and total populations over the cycle:
                    !   S = -1/N * (N - N_old) / (tau * StepsSft)
#ifdef CMPLX_
                    HFShift(run) = -1.0_dp / abs_sign(AllNoatHF(lb:ub)) * &
                                   ((abs_sign(AllNoatHF(lb:ub)) - abs_sign(OldAllNoatHF(lb:ub))) / &
                                    (Tau * real(StepsSft, dp)))
                    InstShift(run) = -1.0_dp / sum(AllTotParts(lb:ub)) * &
                                     ((sum(AllTotParts(lb:ub)) - sum(AllTotPartsOld(lb:ub))) / &
                                      (Tau * real(StepsSft, dp)))
#else
                    HFShift(run) = -1.0_dp / abs(AllNoatHF(run)) * &
                                   ((abs(AllNoatHF(run)) - abs(OldAllNoatHF(run))) / &
                                    (Tau * real(StepsSft, dp)))
                    InstShift(run) = -1.0_dp / AllTotParts(run) * &
                                     ((AllTotParts(run) - AllTotPartsOld(run)) / &
                                      (Tau * real(StepsSft, dp)))
#endif
                end if

                ! When using a linear combination, the denominator is summed
                ! directly.
                all_sum_proje_denominator(run) = ARR_RE_OR_CPLX(AllSumNoatHF, run)
                all_cyc_proje_denominator(run) = AllHFCyc(run)

                ! Calculate the projected energy. Guard on the denominator that
                ! is actually divided by (AllSumNoatHF is indexed per sign
                ! component, not per run, in complex builds).
                if (.not. near_zero(abs(all_sum_proje_denominator(run)))) then
                    ProjectionE(run) = (AllSumENum(run)) / (all_sum_proje_denominator(run)) &
                                       + proje_ref_energy_offsets(run)
                end if
                if (abs(AllHFCyc(run)) > EPS) then
                    proje_iter(run) = (AllENumCyc(run)) / (all_cyc_proje_denominator(run)) &
                                      + proje_ref_energy_offsets(run)
                    AbsProjE(run) = (AllENumCycAbs(run)) / (all_cyc_proje_denominator(run)) &
                                    + proje_ref_energy_offsets(run)
                    inits_proje_iter(run) = (AllInitsENumCyc(run)) / (all_cyc_proje_denominator(run)) &
                                            + proje_ref_energy_offsets(run)
                end if
                ! If we are re-zeroing the shift
                if (tReZeroShift(run)) then
                    DiagSft(run) = 0.0_dp
                    VaryShiftCycles(run) = 0
                    SumDiagSft(run) = 0.0_dp
                    AvDiagSft(run) = 0.0_dp
                end if
            end do

            ! Get some totalled values
            if (abs(sum(all_sum_proje_denominator(1:inum_runs))) > EPS) then
                projectionE_tot = sum(AllSumENum(1:inum_runs)) &
                                  / sum(all_sum_proje_denominator(1:inum_runs))
            end if
            if (abs(sum(all_cyc_proje_denominator(1:inum_runs))) > EPS) then
                proje_iter_tot = sum(AllENumCyc(1:inum_runs)) &
                                 / sum(all_cyc_proje_denominator(1:inum_runs))
                inits_proje_iter_tot = sum(AllInitsENumCyc(1:inum_runs)) &
                                       / sum(all_cyc_proje_denominator(1:inum_runs))
            end if

        end if i_am_root

        ! Broadcast the shift from root to all the other processors
        call MPIBcast(tSinglePartPhase)
        call MPIBcast(VaryShiftIter)
        call MPIBcast(DiagSft)
        call MPIBcast(tSkipRef)
        call MPIBcast(tFixTrial)
        call MPIBcast(VaryShiftCycles)
        call MPIBcast(SumDiagSft)
        call MPIBcast(AvDiagSft)

        do run = 1, inum_runs
            if (.not. tSinglePartPhase(run)) then
                TargetGrowRate(run) = 0.0_dp
            end if
        end do

        if (tau_search_method /= possible_tau_search_methods%off) then
            if (end_of_search_reached(tau_search_method, tau_stop_method)) then
                call stop_tau_search(tau_stop_method)
            end if
        end if

    contains

        !> Fold the current shift of replica `irun` into its running average
        !> once nShiftEquilSteps iterations have passed since VaryShiftIter,
        !> announcing the start of averaging during the first update cycle.
        subroutine accumulate_shift_average(irun)
            integer, intent(in) :: irun

            if ((iter - VaryShiftIter(irun)) >= nShiftEquilSteps) then
                if ((iter - VaryShiftIter(irun) - nShiftEquilSteps) < StepsSft) &
                    write(stdout, '(a,i14)') 'Beginning to average shift value on iteration: ', iter + PreviousCycles
                VaryShiftCycles(irun) = VaryShiftCycles(irun) + 1
                SumDiagSft(irun) = SumDiagSft(irun) + DiagSft(irun)
                AvDiagSft(irun) = SumDiagSft(irun) / real(VaryShiftCycles(irun), dp)
            end if
        end subroutine accumulate_shift_average

    end subroutine update_shift