! fcimc_iter_utilities.F90 Source File
!
!
! Source Code

#include "macros.h"

module fcimc_iter_utils

    use SystemData, only: nel, tHPHF, tNoBrillouin, tRef_Not_HF
    use CalcData, only: tSemiStochastic, tChangeProjEDet, tTrialWavefunction, &
                        tCheckHighestPopOnce, tRestartHighPop, StepsSft, &
                        tTruncInitiator, tJumpShift, TargetGrowRate, &
                        tLetInitialPopDie, InitWalkers, tCheckHighestPop, &
                        HFPopThresh, DiagSft, tShiftOnHFPop, iRestartWalkNum, &
                        FracLargerDet, tKP_FCIQMC, MaxNoatHF, SftDamp, SftDamp2, &
                        nShiftEquilSteps, TargetGrowRateWalk, tContTimeFCIMC, &
                        tContTimeFull, pop_change_min, tPositiveHFSign, &
                        qmc_trial_wf, nEquilSteps, &
                        tSkipRef, N0_Target, tSpinProject, &
                        tFixedN0, tEN2, tTrialShift, tFixTrial, TrialTarget, &
                        tDynamicAvMCEx, AvMCExcits, tTargetShiftdamp

    use tau_main, only: t_scale_tau_to_death, scale_tau_to_death_triggered, &
        tau_search_method, input_tau_search_method, possible_tau_search_methods, &
        tau_stop_method, end_of_search_reached, scale_tau_to_death, tau, &
    use tau_search_hist, only: t_fill_frequency_hists

    use cont_time_rates, only: cont_spawn_success, cont_spawn_attempts
    use LoggingData, only: tPrintDataTables, tLogEXLEVELStats, t_spin_measurements
    use semi_stoch_procs, only: recalc_core_hamil_diag
    use bit_rep_data, only: NIfD, NIfTot, test_flag, test_flag_multi
    use hphf_integrals, only: hphf_diag_helement
    use Determinants, only: get_helement
    use LoggingData, only: tFCIMCStats2, t_calc_double_occ, t_calc_double_occ_av, &
                           AllInitsPerExLvl, initsPerExLvl, tCoupleCycleOutput, &
    use tau_search_conventional, only: update_tau
    use rdm_data, only: en_pert_main, InstRDMCorrectionFactor
    use Parallel_neci
    use fcimc_initialisation
    use fcimc_output
    use fcimc_helper
    use FciMCData
    use constants
    use util_mod
    use real_time_procs, only: normalize_gf_overlap
    use real_time_data, only: current_overlap, overlap_real, overlap_imag, &
    use real_time_data, only: allPopSnapshot, popSnapshot
    use double_occ_mod, only: inst_double_occ, all_inst_double_occ, sum_double_occ, &
                              sum_norm_psi_squared, all_inst_spatial_doub_occ, &
                              rezero_double_occ_stats, rezero_spin_diff, &
                              all_inst_spin_diff, inst_spin_diff, inst_spatial_doub_occ

    use tau_search_hist, only: update_tau_hist

    use local_spin, only: all_local_spin, inst_local_spin, rezero_local_spin_stats

    use PopsfileMod, only: ChangeRefDet

    implicit none


    !> Update the per-print-cycle time measures and the acceptance ratio
    !> that is printed in the output.
    !>
    !> Reads StepsPrint, Tau and the (already MPI-summed) counters; updates
    !> TotImagTime, IterTime, allNValidExcits, allNInvalidExcits and AccRat.
    subroutine output_diagnostics()

        ! Accumulate the imaginary time elapsed during the last print cycle.
        TotImagTime = TotImagTime + StepsPrint * Tau

        ! Divide by the number of iterations in the cycle so that IterTime
        ! holds the average time per iteration of the previous update cycle.
        IterTime = IterTime / real(StepsPrint, sp)

        ! Average the excitation-generation counters over the cycle likewise.
        allNValidExcits = nint(real(allNValidExcits, dp) / real(StepsPrint, dp), int64)
        allNInvalidExcits = nint(real(allNInvalidExcits, dp) / real(StepsPrint, dp), int64)

        ! Acceptance ratio. Exactly one of the two branches applies, and each
        ! guards its denominator: with nothing to divide by, report zero.
        if (tContTimeFCIMC .and. .not. tContTimeFull) then
            ! Continuous-time (non-full) mode uses its own spawn counters.
            if (.not. near_zero(real(cont_spawn_attempts, dp))) then
                AccRat = real(cont_spawn_success, dp) / real(cont_spawn_attempts, dp)
            else
                AccRat = 0.0_dp
            end if
        else
            ! Otherwise: accepted spawns relative to the summed walker count
            ! over the output cycle.
            if (.not. any(near_zero(AllSumWalkersOut))) then
                AccRat = real(AllAcceptances, dp) / AllSumWalkersOut
            else
                AccRat = 0.0_dp
            end if
        end if

    end subroutine output_diagnostics

    !> Per-update-cycle sanity checks on the global walker population.
    !>
    !> * Real builds with tPositiveHFSign: flip the sign of the whole ensemble
    !>   of a particle type whose reference population AllNoAtHF is negative.
    !> * Root checks whether all particles of a run have died and aborts via
    !>   stop_all if so; otherwise tRestart is cleared. tRestart is then
    !>   broadcast and, if set, the calculation is re-initialised.
    !> * With tDynamicAvMCEx, AvMCExcits is updated from the valid/invalid
    !>   excitation counters.
    subroutine iter_diagnostics()

        character(*), parameter :: this_routine = 'iter_diagnostics'
        character(*), parameter :: t_r = this_routine
        integer :: run, part_type

#ifndef CMPLX_
        if (tPositiveHFSign) then
            do part_type = 1, lenof_sign
                if ((.not. tFillingStochRDMonFly) .or. (inum_runs == 1)) then
                    if (AllNoAtHF(part_type) < 0.0_dp) then
                        root_print 'No. at HF < 0 - flipping sign of entire ensemble &
                                   &of particles in simulation: ', part_type
                        root_print AllNoAtHF(part_type)

                        ! And do the flipping
                        call FlipSign(part_type)
                        AllNoatHF(part_type) = -AllNoatHF(part_type)
                        NoatHF(part_type) = -NoatHF(part_type)

                        if (tFillingStochRDMonFly) then
                            ! Want to flip all the averaged signs.
                            AvNoatHF = -AVNoatHF
                            InstNoatHF(part_type) = -InstNoatHF(part_type)
                        end if
                    end if
                end if
            end do
        end if
#endif

        if (iProcIndex == Root) then
            ! Have all of the particles died?
#ifdef CMPLX_
            tRestart = .false.
            do run = 1, inum_runs
                if (near_zero(sum(AllTotParts(min_part_type(run):max_part_type(run))))) then
                    call stop_all(t_r, "All particles have died. Aborting.")
                end if
            end do
#else
            if (near_zero(AllTotParts(1)) .or. near_zero(AllTotParts(inum_runs))) then
                call stop_all(t_r, "All particles have died. Aborting.")
            else
                tRestart = .false.
            end if
#endif
            !TODO CMO: Work out how to wipe the walkers on the second population if double run
        end if

        ! NOTE(review): in this routine root only ever clears tRestart (a dead
        ! population aborts instead), so the restart path below is reached
        ! only if tRestart was set elsewhere.
        call MPIBCast(tRestart)
        if (tRestart) then
            ! a restart not wanted in the real-time fciqmc..
            ! Initialise variables for calculation on each node
            call DeallocFCIMCMemPar()
            if (iProcIndex == Root) then
                ! NOTE(review): fcimcstats_unit itself is not closed here while
                ! fcimcstats_unit2 is - confirm the re-initialisation handles it.
                if (inum_runs == 2) close(fcimcstats_unit2)
                if (tTruncInitiator) close(initiatorstats_unit)
                if (tLogComplexPops) close(complexstats_unit)
                if (tLogEXLEVELStats) close(EXLEVELStats_unit)
            end if
            if (TDebug) close(11)
            call SetupParameters()
            call InitFCIMCCalcPar()

            ! Re-emit the header and initial line in whichever stats format
            ! is active.
            if (tFCIMCStats2) then
                call write_fcimcstats2(iter_data_fciqmc, initial=.true.)
                call write_fcimcstats2(iter_data_fciqmc)
            else
                call WriteFciMCStatsHeader()
                ! Prepend a # to the initial status line so analysis doesn't pick up
                ! repetitions in the FCIMCStats or INITIATORStats files from restarts.
                if (iProcIndex == root) then
                    write(fcimcstats_unit, '("#")', advance='no')
                    if (inum_runs == 2) &
                        write(fcimcstats_unit2, '("#")', advance='no')
                    write(initiatorstats_unit, '("#")', advance='no')
                end if
                call WriteFCIMCStats()
            end if
            Iter = 1
            if (iProcIndex == root .and. tLogEXLEVELStats) &
                write(EXLEVELStats_unit, '("#")', advance='no')
        end if

        ! Update the number of spawning attempts per walker.
        if (tDynamicAvMCEx) then
            if (allNValidExcits /= 0) then
                ! Aim for approximately one valid excitation per walker:
                ! attempts = (all generated) / (valid). The ratio is formed in
                ! real(dp) so it is not truncated by int64 integer division.
                AvMCExcits = real(allNValidExcits + allNInvalidExcits, dp) &
                             / real(allNValidExcits, dp)
                write(stdout, *) "Now spawning ", AvMCExcits, " times per walker"
            end if
        end if

    end subroutine iter_diagnostics

    !> Check whether another determinant has overtaken the reference and, if
    !> so, switch the reference (on the fly, or via ChangeRefDet).
    !>
    !> Does nothing unless tCheckHighestPop is set.
    !> 1. When RDMs are accumulated on the fly (tRDMOnFly without
    !>    tExplicitAllRDM), (re)allocate TempSpawnedParts so it can hold
    !>    temp_parts_buffer times the largest determinant population.
    !> 2. Find the largest determinant population and the rank holding it
    !>    (per run if the replica references differ, otherwise once).
    !> 3. If that population exceeds both FracLargerDet times the magnitude
    !>    of AllNoAtHF for the run and pop_change_min, then either
    !>    - tChangeProjEDet: update the run's reference on the fly and zero
    !>      all energy accumulators, or
    !>    - tRestartHighPop (and AllTotParts > iRestartWalkNum): update the
    !>      reference and, for run 1, call ChangeRefDet.
    subroutine population_check()

        use HPHFRandExcitMod, only: ReturnAlphaOpenDet

        character(*), parameter :: this_routine = 'population_check'
        ! Over-allocation factor for TempSpawnedParts. This routine is only
        ! called once per update cycle, so leave room for particle growth.
        real(dp), parameter :: temp_parts_buffer = 1.5_dp

        integer(int32) :: pop_highest(inum_runs), proc_highest(inum_runs)
        integer(int32) :: int_tmp(2)
        real(dp) :: pop_change
        integer :: ierr, run
        logical :: allocate_temp_parts, changed_any

        ! If we aren't doing this, then bail out...
        if (.not. tCheckHighestPop) return

        ! If we are accumulating RDMs, then a temporary spawning array is
        ! required of <~ the size of the largest occupied det.
        ! This memory holds walkers spawned from one determinant. This
        ! allows us to test if we are spawning onto the same Dj multiple
        ! times. If only using connections to the HF (tHF_Ref_Explicit)
        ! no stochastic RDM construction is done, and this is not
        ! necessary.
        if (tRDMOnFly .and. .not. tExplicitAllRDM) then

            ! Decide whether the buffer must be allocated or grown.
            allocate_temp_parts = .false.
            if (.not. allocated(TempSpawnedParts)) then
                allocate_temp_parts = .true.
                TempSpawnedPartsSize = 1000
            end if
            if (temp_parts_buffer * maxval(iHighestPop) > TempSpawnedPartsSize) then
                TempSpawnedPartsSize = int(maxval(iHighestPop) * temp_parts_buffer)
                allocate_temp_parts = .true.
            end if

            if (allocate_temp_parts) then
                ! Allocating an already-allocated array is an error, so the
                ! old (too small) buffer must be released first.
                if (allocated(TempSpawnedParts)) then
                    deallocate(TempSpawnedParts)
                    call LogMemDealloc(this_routine, TempSpawnedPartsTag)
                end if
                allocate(TempSpawnedParts(0:nifd, TempSpawnedPartsSize), &
                         stat=ierr, source=0_n_int)
                call LogMemAlloc('TempSpawnedParts', size(TempSpawnedParts, kind=int64), &
                                 size_per_element(TempSpawnedParts), &
                                 this_routine, TempSpawnedPartsTag, ierr)
                write(stdout, "(' Allocating temporary array for walkers spawned &
                              &from a particular Di.')")
                ! Form the byte count in real(dp) so the product cannot
                ! overflow a default integer for large buffers.
                write(stdout, "(a,f14.6,a)") " This requires ", &
                    real(nifd + 1, dp) * real(TempSpawnedPartsSize, dp) &
                    * real(size_n_int, dp) / 1048576.0_dp, " Mb/Processor"
            end if

        end if ! Allocating memory for RDMs

        ! Obtain the determinant (and its processor) with the highest pop
        ! in each of the runs.
        ! n.b. the use of int(iHighestPop) obviously introduces a small amount
        !      of error here, by ignoring the fractional part...
        ! [Werner Dobrautz 15.5.2017:]
        ! maybe this small error here is the cause of the failed test_suite
        ! runs..
        if (tReplicaReferencesDiffer) then
            ! Independent references: one MAXLOC reduction per run.
            do run = 1, inum_runs
                call MPIAllReduceDatatype( &
                    (/int(iHighestPop(run), int32), int(iProcIndex, int32)/), 1, &
                    MPI_MAXLOC, MPI_2INTEGER, int_tmp)
                pop_highest(run) = int_tmp(1)
                proc_highest(run) = int_tmp(2)
            end do
        else
            ! Shared reference: only run 1 matters; store its result for all.
            call MPIAllReduceDatatype( &
                (/int(iHighestPop(1), int32), int(iProcIndex, int32)/), 1, &
                MPI_MAXLOC, MPI_2INTEGER, int_tmp)
            pop_highest = int_tmp(1)
            proc_highest = int_tmp(2)
        end if

        changed_any = .false.
        do run = 1, inum_runs

            ! If using the same reference for all, then we don't consider the
            ! populations separately...
            if (run /= 1 .and. .not. tReplicaReferencesDiffer) exit

            ! What are the change conditions?
#ifdef CMPLX_
            if (tReplicaReferencesDiffer) then
                pop_change = FracLargerDet * abs_sign(AllNoAtHF(min_part_type(run):max_part_type(run)))
            else
                pop_change = FracLargerDet * abs_sign(AllNoAtHF(1:(lenof_sign / inum_runs)))
            end if
#else
            if (tReplicaReferencesDiffer) then
                pop_change = FracLargerDet * abs(AllNoAtHF(run))
            else
                pop_change = FracLargerDet * abs(AllNoAtHF(1))
            end if
#endif

            ! Do we need to do a change? pop_highest is an int32 (rounded in
            ! the reduction above) and is converted to real(dp) to compare.
            if (pop_change < real(pop_highest(run), dp) .and. &
                real(pop_highest(run), dp) > pop_change_min) then

                if (tChangeProjEDet) then
                    ! Here we are changing the reference det on the fly.
                    ! --> else block for restarting simulation.
                    changed_any = .true.
                    root_print 'Highest weighted determinant on run', run, &
                        'not reference det: ', pop_highest(run), &
                        abs_sign(AllNoAtHF(min_part_type(run):max_part_type(run)))

                    ! Communicate the new reference from the rank holding it.
                    ! The root rank is passed as a default integer rather than
                    ! being widened to n_int.
                    call MPIBcast(HighestPopDet(0:NIfTot, run), NIfTot + 1, &
                                  int(proc_highest(run)))

                    call update_run_reference(HighestPopDet(:, run), run)

                    ! Reset averages
                    SumENum = 0.0_dp
                    sum_proje_denominator = 0.0_dp
                    cyc_proje_denominator = 0.0_dp
                    SumNoatHF = 0.0_dp
                    VaryShiftCycles = 0
                    SumDiagSft = 0.0_dp
                    root_print 'Zeroing all energy estimators.'

                    ! Since we have a new reference, we must block only from after this point
                    iBlockingIter = Iter + PreviousCycles

                    ! Reset values introduced in soft_exit (CHANGEVARS)
                    if (tCheckHighestPopOnce) then
                        tChangeProjEDet = .false.
                        tCheckHighestPop = .false.
                        tCheckHighestPopOnce = .false.
                    end if

                ! Or are we restarting the calculation with the reference
                ! det switched?
#ifdef CMPLX_
                else if (tRestartHighPop .and. &
                         iRestartWalkNum < sum(AllTotParts(1:2))) then
#else
                else if (tRestartHighPop .and. &
                         iRestartWalkNum < AllTotParts(1)) then
#endif
                    ! Here we are restarting the simulation with a new
                    ! reference. See above block for doing it on the fly.

                    ! Broadcast the changed det to all processors
                    call MPIBcast(HighestPopDet(:, run), NIfTot + 1, &
                                  int(proc_highest(run)))

                    call update_run_reference(HighestPopDet(:, run), run)

                    ! Only update the global reference energies if they
                    ! correspond to run 1 (which is used for those)
                    if (run == 1) then
                        call ChangeRefDet(ProjEDet(:, 1))
                    end if

                    ! Reset values introduced in soft_exit (CHANGEVARS)
                    if (tCheckHighestPopOnce) then
                        tChangeProjEDet = .false.
                        tCheckHighestPop = .false.
                        tCheckHighestPopOnce = .false.
                    end if

                end if

            end if
        end do

    end subroutine population_check

    subroutine communicate_estimates(iter_data, tot_parts_new, tot_parts_new_all, t_output)

        ! This routine sums all estimators and stats over all processes.

        ! We want this to be done in as few MPI calls as possible. Therefore, all
        ! quantities are first placed into one of two arrays. There is one array
        ! for HElement_t(dp) estimates, and another array (real(dp)) for all other
        ! kinds and types. A single MPISumAll is then performed for both of these
        ! combined arrays. After communication, the summed results are copied to
        ! the appropriate final arrays, with the type and kind corrected when
        ! necessary.

        ! There are also a few separate MPI calls which reduce using MPI_MAX and
        ! MPI_MIN at the end.

        ! -- *IMPORTANT FOR DEVELOPERS* ---------------------------------------
        ! To add a new quantity to be communicated, you must give it a new entry
        ! in send_arr or send_arr_helem array. It is hopefully clear how to do
        ! this by analogy. You should also update the indices in the appropriate
        ! stop_all, so that it can be checked if enough memory has been assigned.
        use adi_data, only: nCoherentDoubles, nIncoherentDets, nConnection, &
                            AllCoherentDoubles, AllIncoherentDets, AllConnection
        type(fcimc_iter_data) :: iter_data
        real(dp), intent(in) :: tot_parts_new(lenof_sign)
        real(dp), intent(out) :: tot_parts_new_all(lenof_sign)

        ! RT_M_Merge: Added real-time statistics for the newer communication scheme
        integer, parameter :: real_arr_size = 2000
        integer, parameter :: hel_arr_size = 200
        ! RT_M_Merge: Doubled all array sizes since there are now two
        ! copies of most of the variables (necessary?)

        ! Allow room to send up to 1000 (2000 for rt) elements.
        real(dp) :: send_arr(real_arr_size)
        ! Allow room to receive up to 2000 elements.
        real(dp) :: recv_arr(real_arr_size)
        logical, intent(in) :: t_output
        logical :: t_comm_trial
        ! Equivalent arrays for HElement_t variables.
        integer, parameter :: arr_helem_size = 300
        HElement_t(dp) :: send_arr_helem(arr_helem_size)
        HElement_t(dp) :: recv_arr_helem(arr_helem_size)
        ! Equivalent arrays for EXLEVELStats (of exactly required size).
        real(dp) :: send_arr_WNorm(3 * (NEl + 1) * inum_runs), &
                    recv_arr_WNorm(3 * (NEl + 1) * inum_runs)
        ! Allow room for 100 different arrays to be communicated.
        integer, parameter :: size_arr_size = 100
        integer :: sizes(size_arr_size)
        integer :: low, upp, run

        integer(int64) :: TotWalkersTemp
        ! [W.D.12.12.2017]
        ! allow for triples now:
        ! Todo: make that more flexible in the future!
        real(dp) :: bloom_sz_tmp(0:3)
        real(dp) :: RealAllHFCyc(max(lenof_sign, inum_runs))
        real(dp) :: RealAllHFOut(max(lenof_sign, inum_runs))
        real(dp) :: all_norm_semistoch_squared(inum_runs)
        integer :: NoArrs
        character(len=*), parameter :: t_r = 'communicate_estimates'
        integer :: cnt

        ! Remove the holes in the main list when wanting the number of uniquely
        ! occupied determinants.
        TotWalkersTemp = TotWalkers - HolesInList

        ! The trial wavefunction is communicated before output and only if the option is on
        t_comm_trial = t_output .and. tTrialWavefunction

        sizes = 0

        ! low will represent the lower bound of an array slice.
        low = 0
        ! upp will represent the upper bound of an array slice.
        upp = 0

        sizes(1) = size(SpawnFromSing)
        sizes(2) = size(iter_data%update_growth)
        sizes(3) = size(NoBorn)
        sizes(4) = size(NoDied)
        sizes(5) = size(HFCyc)
        sizes(6) = size(NoAtDoubs)
        sizes(7) = size(Annihilated)
        if (tTruncInitiator) then
            sizes(8) = size(NoAddedInitiators)
            sizes(9) = size(NoInitDets)
            sizes(10) = size(NoNonInitDets)
            sizes(11) = size(NoExtraInitDoubs)
            sizes(12) = size(InitRemoved)
            sizes(13) = size(NoAborted)
            sizes(14) = size(NoRemoved)
            sizes(15) = size(NoNonInitWalk)
            sizes(16) = size(NoInitWalk)
        end if
        sizes(17) = 1 ! TotWalkersTemp (single int, not an array)
        sizes(18) = size(norm_psi_squared)
        sizes(19) = size(norm_semistoch_squared)
        sizes(20) = size(TotParts)
        sizes(21) = size(tot_parts_new)
        sizes(22) = size(SumNoAtHF)
        sizes(23) = size(bloom_count)
        sizes(24) = size(NoAtHF)
        sizes(25) = size(SumWalkersCyc)
        sizes(26) = 1 ! nspawned (single int, not an array)

        sizes(27) = 1 ! inst_double_occ
        if (tTruncInitiator) sizes(28) = 1 ! doubleSpawns
        ! communicate the coherence numbers for SI
        sizes(29) = 1

        sizes(30) = 1
        ! Perturbation correction
        sizes(31) = 1

        ! communicate the instant spin diff.. although i am not sure if this
        ! gets too big..
        if (t_spin_measurements) then
            sizes(32) = nBasis / 2
            sizes(33) = nBasis / 2
        end if
        ! truncated weight
        sizes(34) = 1
        ! inits per ex lvl
        sizes(35) = size(initsPerExLvl)
        ! number of successful/invalid excits
        sizes(36) = 1
        sizes(37) = 1

        if (t_measure_local_spin) then
            sizes(38) = nBasis / 2
        end if
        ! en pert space size
        if (tEN2) sizes(39) = 1
        ! Output variable
        if (t_output) then
            sizes(40) = size(HFOut)
            sizes(41) = size(Acceptances)
            sizes(42) = size(SumWalkersOut)
            sizes(43) = 1
        end if
        if (t_real_time_fciqmc) then
            sizes(44) = size(popSnapShot)
            NoArrs = 44
            NoArrs = 43
        end if

        send_arr = 0.0_dp

        if (sum(sizes(1:NoArrs)) > real_arr_size) call stop_all(t_r, &
             "No space left in arrays for communication of estimates. Please increase &
             & the size of the send_arr and recv_arr arrays in the source code.")

        low = upp + 1; upp = low + sizes(1) - 1; send_arr(low:upp) = SpawnFromSing;
        low = upp + 1; upp = low + sizes(2) - 1; send_arr(low:upp) = iter_data%update_growth;
        low = upp + 1; upp = low + sizes(3) - 1; send_arr(low:upp) = NoBorn;
        low = upp + 1; upp = low + sizes(4) - 1; send_arr(low:upp) = NoDied;
        low = upp + 1; upp = low + sizes(5) - 1; send_arr(low:upp) = HFCyc;
        low = upp + 1; upp = low + sizes(6) - 1; send_arr(low:upp) = NoAtDoubs;
        low = upp + 1; upp = low + sizes(7) - 1; send_arr(low:upp) = Annihilated;
        if (tTruncInitiator) then
            low = upp + 1; upp = low + sizes(8) - 1; send_arr(low:upp) = NoAddedInitiators;
            low = upp + 1; upp = low + sizes(9) - 1; send_arr(low:upp) = NoInitDets;
            low = upp + 1; upp = low + sizes(10) - 1; send_arr(low:upp) = NoNonInitDets;
            low = upp + 1; upp = low + sizes(11) - 1; send_arr(low:upp) = NoExtraInitDoubs;
            low = upp + 1; upp = low + sizes(12) - 1; send_arr(low:upp) = InitRemoved;
            low = upp + 1; upp = low + sizes(13) - 1; send_arr(low:upp) = NoAborted;
            low = upp + 1; upp = low + sizes(14) - 1; send_arr(low:upp) = NoRemoved;
            low = upp + 1; upp = low + sizes(15) - 1; send_arr(low:upp) = NoNonInitWalk;
            low = upp + 1; upp = low + sizes(16) - 1; send_arr(low:upp) = NoInitWalk;
        end if

        low = upp + 1; upp = low + sizes(17) - 1; send_arr(low:upp) = TotWalkersTemp;
        low = upp + 1; upp = low + sizes(18) - 1; send_arr(low:upp) = norm_psi_squared;
        low = upp + 1; upp = low + sizes(19) - 1; send_arr(low:upp) = norm_semistoch_squared;
        low = upp + 1; upp = low + sizes(20) - 1; send_arr(low:upp) = TotParts;
        low = upp + 1; upp = low + sizes(21) - 1; send_arr(low:upp) = tot_parts_new;
        low = upp + 1; upp = low + sizes(22) - 1; send_arr(low:upp) = SumNoAtHf;
        low = upp + 1; upp = low + sizes(23) - 1; send_arr(low:upp) = bloom_count;
        low = upp + 1; upp = low + sizes(24) - 1; send_arr(low:upp) = NoAtHF;
        low = upp + 1; upp = low + sizes(25) - 1; send_arr(low:upp) = SumWalkersCyc;
        low = upp + 1; upp = low + sizes(26) - 1; send_arr(low:upp) = nspawned;
        ! double occ change:
        low = upp + 1; upp = low + sizes(27) - 1; send_arr(low:upp) = inst_double_occ

        if (tTruncInitiator) then
            low = upp + 1; upp = low + sizes(28) - 1; send_arr(low:upp) = doubleSpawns;
        end if
        low = upp + 1; upp = low + sizes(29) - 1; send_arr(low:upp) = nCoherentDoubles
        low = upp + 1; upp = low + sizes(30) - 1; send_arr(low:upp) = nIncoherentDets
        low = upp + 1; upp = low + sizes(31) - 1; send_arr(low:upp) = nConnection

        if (t_spin_measurements) then
            low = upp + 1; upp = low + sizes(32) - 1; send_arr(low:upp) = inst_spin_diff
            low = upp + 1; upp = low + sizes(33) - 1; send_arr(low:upp) = inst_spatial_doub_occ
        end if

        ! truncated weight
        low = upp + 1; upp = low + sizes(34) - 1; send_arr(low:upp) = truncatedWeight;
        ! initiators per excitation level
        low = upp + 1; upp = low + sizes(35) - 1; send_arr(low:upp) = initsPerExLvl;
        ! excitation number trackers

        low = upp + 1; upp = low + sizes(36) - 1; send_arr(low:upp) = nInvalidExcits;
        low = upp + 1; upp = low + sizes(37) - 1; send_arr(low:upp) = nValidExcits;
        ! local spin
        if (t_measure_local_spin) then
            low = upp + 1; upp = low + sizes(38) - 1; send_arr(low:upp) = inst_local_spin;
        end if

        ! en pert space size
        if (tEN2) then
            low = upp + 1; upp = low + sizes(39) - 1; send_arr(low:upp) = en_pert_main%ndets;
        end if

        if (t_output) then
            low = upp + 1; upp = low + sizes(40) - 1; send_arr(low:upp) = HFOut
            low = upp + 1; upp = low + sizes(41) - 1; send_arr(low:upp) = Acceptances
            low = upp + 1; upp = low + sizes(42) - 1; send_arr(low:upp) = SumWalkersOut
            low = upp + 1; upp = low + sizes(43) - 1; send_arr(low:upp) = n_core_non_init
            if (t_real_time_fciqmc) then
                low = upp + 1; upp = low + sizes(44) - 1; send_arr(low:upp) = popSnapShot;
            end if
        end if

        ! Perform the communication.
        call MPISumAll(send_arr(1:upp), recv_arr(1:upp))

        ! Now we just need each result to be extracted to the correct array, with
        ! the correct type.

        low = 0; upp = 0

        low = upp + 1; upp = low + sizes(1) - 1; AllSpawnFromSing = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(2) - 1; iter_data%update_growth_tot = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(3) - 1; AllNoBorn = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(4) - 1; AllNoDied = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(5) - 1; RealAllHFCyc = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(6) - 1; AllNoAtDoubs = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(7) - 1; AllAnnihilated = recv_arr(low:upp);
        if (tTruncInitiator) then
            low = upp + 1; upp = low + sizes(8) - 1; AllNoAddedInitiators = nint(recv_arr(low:upp), int64);
            low = upp + 1; upp = low + sizes(9) - 1; AllNoInitDets = nint(recv_arr(low:upp), int64);
            low = upp + 1; upp = low + sizes(10) - 1; AllNoNonInitDets = nint(recv_arr(low:upp), int64);
            low = upp + 1; upp = low + sizes(11) - 1; AllNoExtraInitDoubs = nint(recv_arr(low:upp), int64);
            low = upp + 1; upp = low + sizes(12) - 1; AllInitRemoved = nint(recv_arr(low:upp), int64);
            low = upp + 1; upp = low + sizes(13) - 1; AllNoAborted = recv_arr(low:upp);
            low = upp + 1; upp = low + sizes(14) - 1; AllNoRemoved = recv_arr(low:upp);
            low = upp + 1; upp = low + sizes(15) - 1; AllNoNonInitWalk = recv_arr(low:upp);
            low = upp + 1; upp = low + sizes(16) - 1; AllNoInitWalk = recv_arr(low:upp);
        end if
        low = upp + 1; upp = low + sizes(17) - 1; AllTotWalkers = nint(recv_arr(low), int64);
        low = upp + 1; upp = low + sizes(18) - 1; all_norm_psi_squared = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(19) - 1; all_norm_semistoch_squared = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(20) - 1; AllTotParts = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(21) - 1; tot_parts_new_all = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(22) - 1; AllSumNoAtHF = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(23) - 1; all_bloom_count = nint(recv_arr(low:upp));
        low = upp + 1; upp = low + sizes(24) - 1; AllNoAtHf = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(25) - 1; AllSumWalkersCyc = recv_arr(low:upp);
        low = upp + 1; upp = low + sizes(26) - 1; nspawned_tot = nint(recv_arr(low), int64);
        ! double occ:
        low = upp + 1; upp = low + sizes(27) - 1; all_inst_double_occ = recv_arr(low);
        if (tTruncInitiator) then
            low = upp + 1; upp = low + sizes(28) - 1; allDoubleSpawns = nint(recv_arr(low));
            doubleSpawns = 0
        end if
        low = upp + 1; upp = low + sizes(29) - 1; AllCoherentDoubles = nint(recv_arr(low));
        low = upp + 1; upp = low + sizes(30) - 1; AllIncoherentDets = nint(recv_arr(low));
        low = upp + 1; upp = low + sizes(31) - 1; AllConnection = nint(recv_arr(low));
        if (t_spin_measurements) then
            low = upp + 1; upp = low + sizes(32) - 1; all_inst_spin_diff = recv_arr(low:upp)
            low = upp + 1; upp = low + sizes(33) - 1; all_inst_spatial_doub_occ = recv_arr(low:upp)
        end if

        ! truncated weight
        low = upp + 1; upp = low + sizes(34) - 1; AllTruncatedWeight = recv_arr(low);
        ! initiators per excitation level
        low = upp + 1; upp = low + sizes(35) - 1; AllInitsPerExLvl = nint(recv_arr(low:upp));
        ! excitation number trackers
        low = upp + 1; upp = low + sizes(36) - 1; allNInvalidExcits = nint(recv_arr(low), int64);
        low = upp + 1; upp = low + sizes(37) - 1; allNValidExcits = nint(recv_arr(low), int64);

        ! local spin
        if (t_measure_local_spin) then
            low = upp + 1; upp = low + sizes(38) - 1; all_local_spin = recv_arr(low:upp);
        end if

        ! en_pert space size
        if (tEN2) then
            low = upp + 1; upp = low + sizes(39) - 1; en_pert_main%ndets_all = nint(recv_arr(low));
        end if
        ! Output variables
        if (t_output) then
            low = upp + 1; upp = low + sizes(40) - 1; RealAllHFOut = recv_arr(low:upp)
            low = upp + 1; upp = low + sizes(41) - 1; AllAcceptances = recv_arr(low:upp)
            low = upp + 1; upp = low + sizes(42) - 1; AllSumWalkersOut = recv_arr(low:upp)
            low = upp + 1; upp = low + sizes(43) - 1; all_n_core_non_init = nint(recv_arr(low))
            if (t_real_time_fciqmc) then
                low = upp + 1; upp = low + sizes(44) - 1; allPopSnapShot = recv_arr(low:upp);
            end if
        end if
        ! Communicate HElement_t variables:

        low = 0; upp = 0;
        sizes(1) = size(ENumCyc)
        sizes(2) = size(SumENum)
        sizes(3) = size(ENumCycAbs)
        sizes(4) = size(cyc_proje_denominator)
        sizes(5) = size(sum_proje_denominator)
        if (t_comm_trial) then
            sizes(6) = size(trial_numerator)
            sizes(7) = size(trial_denom)
            sizes(8) = size(trial_num_inst)
            sizes(9) = size(trial_denom_inst)
            sizes(10) = size(init_trial_numerator)
            sizes(11) = size(init_trial_denom)
        end if
        sizes(12) = size(InitsEnumCyc)
        sizes(13) = size(ENumOut)

        if (sum(sizes(1:14)) > arr_helem_size) call stop_all(t_r, "No space left in arrays for communication of estimates. Please &
                                                        & increase the size of the send_arr_helem and recv_arr_helem &
                                                        & arrays in the source code.")

        low = upp + 1; upp = low + sizes(1) - 1; send_arr_helem(low:upp) = ENumCyc;
        low = upp + 1; upp = low + sizes(2) - 1; send_arr_helem(low:upp) = SumENum;
        low = upp + 1; upp = low + sizes(3) - 1; send_arr_helem(low:upp) = ENumCycAbs;
        low = upp + 1; upp = low + sizes(4) - 1; send_arr_helem(low:upp) = cyc_proje_denominator;
        low = upp + 1; upp = low + sizes(5) - 1; send_arr_helem(low:upp) = sum_proje_denominator;
        if (t_comm_trial) then
            low = upp + 1; upp = low + sizes(6) - 1; send_arr_helem(low:upp) = trial_numerator;
            low = upp + 1; upp = low + sizes(7) - 1; send_arr_helem(low:upp) = trial_denom;
            low = upp + 1; upp = low + sizes(8) - 1; send_arr_helem(low:upp) = trial_num_inst;
            low = upp + 1; upp = low + sizes(9) - 1; send_arr_helem(low:upp) = trial_denom_inst;
            low = upp + 1; upp = low + sizes(10) - 1; send_arr_helem(low:upp) = init_trial_numerator;
            low = upp + 1; upp = low + sizes(11) - 1; send_arr_helem(low:upp) = init_trial_denom;
        end if
        low = upp + 1; upp = low + sizes(12) - 1; send_arr_helem(low:upp) = InitsENumCyc;
        if (t_output) then
            low = upp + 1; upp = low + sizes(13) - 1; send_arr_helem(low:upp) = ENumOut;
        end if

        call MPISumAll(send_arr_helem(1:upp), recv_arr_helem(1:upp))

        low = 0; upp = 0;
        low = upp + 1; upp = low + sizes(1) - 1; AllENumCyc = recv_arr_helem(low:upp);
        low = upp + 1; upp = low + sizes(2) - 1; AllSumENum = recv_arr_helem(low:upp);
        low = upp + 1; upp = low + sizes(3) - 1; AllENumCycAbs = recv_arr_helem(low:upp);
        low = upp + 1; upp = low + sizes(4) - 1; all_cyc_proje_denominator = recv_arr_helem(low:upp);
        low = upp + 1; upp = low + sizes(5) - 1; all_sum_proje_denominator = recv_arr_helem(low:upp);
        if (t_comm_trial) then
            low = upp + 1; upp = low + sizes(6) - 1; tot_trial_numerator = recv_arr_helem(low:upp);
            low = upp + 1; upp = low + sizes(7) - 1; tot_trial_denom = recv_arr_helem(low:upp);
            low = upp + 1; upp = low + sizes(8) - 1; tot_trial_num_inst = recv_arr_helem(low:upp);
            low = upp + 1; upp = low + sizes(9) - 1; tot_trial_denom_inst = recv_arr_helem(low:upp);
            low = upp + 1; upp = low + sizes(10) - 1; tot_init_trial_numerator = recv_arr_helem(low:upp);
            low = upp + 1; upp = low + sizes(11) - 1; tot_init_trial_denom = recv_arr_helem(low:upp);
        end if
        low = upp + 1; upp = low + sizes(12) - 1; AllInitsENumCyc = recv_arr_helem(low:upp);
        if (t_output) then
            low = upp + 1; upp = low + sizes(13) - 1; AllEnumOut = recv_arr_helem(low:upp);
        end if

        ! Optionally communicate EXLEVEL_WNorm.
        if (tLogEXLEVELStats) then
            upp = size(EXLEVEL_WNorm)
            send_arr_WNorm(1:upp) = reshape(EXLEVEL_WNorm, (/upp/))
            call MPISumAll(send_arr_WNorm(1:upp), recv_arr_WNorm(1:upp))
            AllEXLEVEL_WNorm = reshape(recv_arr_WNorm(1:upp), &
            ! Apply square root for L2 norm.
            AllEXLEVEL_WNorm(2, :, :) = sqrt(AllEXLEVEL_WNorm(2, :, :))
        end if

        ! Do some processing of the received data.

        ! Convert real array into HElement_t one.
        do run = 1, inum_runs
            AllHFCyc(run) = ARR_RE_OR_CPLX(RealAllHFCyc, run)
            AllHFOut(run) = ARR_RE_OR_CPLX(RealAllHFOut, run)
        end do

#ifdef CMPLX_
        norm_psi = sqrt(sum(all_norm_psi_squared))
        norm_semistoch = sqrt(sum(all_norm_semistoch_squared))
        norm_psi = sqrt(all_norm_psi_squared)
        norm_semistoch = sqrt(all_norm_semistoch_squared)

        ! These require a different type of reduce operation, so are communicated
        ! separately to the above communication.
        call MPIAllReduce(bloom_sizes(1:2), MPI_MAX, bloom_sz_tmp(1:2))
        bloom_sizes(1:2) = bloom_sz_tmp(1:2)

        ! Arrays for checking load balancing.
        call MPIReduce(TotWalkersTemp, MPI_MAX, MaxWalkersProc)
        call MPIReduce(TotWalkersTemp, MPI_MIN, MinWalkersProc)
        call MPIReduce(max_cyc_spawn, MPI_MAX, all_max_cyc_spawn)

    end subroutine communicate_estimates

    !> Combine per-process iteration data that has to agree globally, and
    !> update the time step.
    !>
    !> * If scaling tau to death is enabled and no tau search is active, the
    !>   local scale_tau_to_death_triggered flags are OR-ed over all processes.
    !> * Unless stochastic RDMs are being filled on the fly, tau is updated by
    !>   the active tau-search method (conventional or histogramming), or
    !>   scaled to death if that has been triggered.
    !> * With t_calc_double_occ_av, the run-averaged squared norm and the
    !>   instantaneous double occupancy are accumulated.
    !> * In DEBUG_ builds, the root process checks the recorded walker growth
    !>   against the change in the total particle number.
    !>
    !> @param[in] iter_data  iteration statistics; only read in DEBUG_ builds
    subroutine collate_iter_data(iter_data)

        type(fcimc_iter_data), intent(in) :: iter_data
        logical :: ltmp
        character(len=*), parameter :: this_routine = 'collate_iter_data'

        ! Make the scale-tau-to-death trigger consistent over all processes:
        ! if any process has triggered it, every process takes it as triggered.
        if (t_scale_tau_to_death .and. tau_search_method == possible_tau_search_methods%OFF) then
            call MPIAllLORLogical(scale_tau_to_death_triggered, ltmp)
            scale_tau_to_death_triggered = ltmp
        end if

        ! for now with the new tau-search also update tau in variable shift
        ! mode..
        if (.not. tFillingStochRDMOnFly) then
            if (tau_search_method == possible_tau_search_methods%CONVENTIONAL) then
                call update_tau()
            else if (tau_search_method == possible_tau_search_methods%HISTOGRAMMING) then
                call update_tau_hist()
            else if (scale_tau_to_death_triggered) then
                ASSERT(tau_search_method == possible_tau_search_methods%OFF)
                call scale_tau_to_death()
            end if
        end if

        ! quick fix for the double occupancy:
        if (t_calc_double_occ_av) then
            ! sum up the squared norm after shift has set in TODO
            ! and use the mean value if multiple runs are used
            ! still thinking about if i only want to calc it after
            ! equilibration
            sum_norm_psi_squared = sum_norm_psi_squared + &
                                   sum(all_norm_psi_squared) / real(inum_runs, dp)

            ! and also sum up the double occupancy:
            sum_double_occ = sum_double_occ + all_inst_double_occ
        end if

#ifdef DEBUG_
        if (.not. tfirst_cycle) then
            ! realtime case is handled seperately with the check_update_growth function
            ! as each RK step has to be monitored separately

            ! Write this 'ASSERTROOT' out explicitly to avoid line lengths problems.
            ! NOTE(review): this fires only when *every* component differs by
            ! more than 1.0e-5, while the message describes an all(.eq.) check
            ! (which fails as soon as *any* component differs). Behaviour is
            ! kept unchanged; confirm which is intended.
            if ((iProcIndex == root) .and. .not. tSpinProject .and. .not. tTrialShift .and. &
                all(abs(iter_data%update_growth_tot - (AllTotParts - AllTotPartsOld)) > 1.0e-5)) then
                write(stderr, *) "update_growth: ", iter_data%update_growth_tot
                write(stderr, *) "AllTotParts: ", AllTotParts
                write(stderr, *) "AllTotPartsOld: ", AllTotPartsOld
                call stop_all(this_routine, &
                              "Assertation failed: all(iter_data%update_growth_tot.eq.AllTotParts-AllTotPartsOld)")
            end if
        end if
#endif

    end subroutine collate_iter_data

    !> Return the trial-estimator numerator measured relative to the HF
    !> (reference) energy instead of the trial energy.
    !>
    !> Unless qmc_trial_wf is set, tt_numerator/tt_denom is an energy relative
    !> to the trial energy, so tt_denom * E_trial is added back:
    !>   * ntrial_excits == 1: trial_energies(1) is used for every run;
    !>   * replica_pairs:      runs 2k-1 and 2k both use trial_energies(k);
    !>   * otherwise:          run i uses trial_energies(i).
    !> With qmc_trial_wf set, tt_numerator is returned unchanged.
    !>
    !> @param[in] tt_numerator   trial numerator of each run
    !> @param[in] tt_denom       trial denominator of each run
    !> @param[in] replica_pairs  whether runs are grouped in replica pairs
    !> @return rel_tot_trial_numerator  numerator relative to the HF energy
    function relative_trial_numerator(tt_numerator, tt_denom, replica_pairs) &
        result(rel_tot_trial_numerator)
        implicit none
        HElement_t(dp), intent(in) :: tt_numerator(inum_runs), tt_denom(inum_runs)
        logical, intent(in) :: replica_pairs
        HElement_t(dp) :: rel_tot_trial_numerator(inum_runs)
        integer :: run

        if (.not. qmc_trial_wf) then
            ! Because tt_numerator/tt_denom is the energy relative to the
            ! trial energy, add on this contribution to make it relative to
            ! the HF energy.
            if (ntrial_excits == 1) then
                ! A single trial energy shared by all runs.
                rel_tot_trial_numerator = tt_numerator + (tt_denom * trial_energies(1))
            else if (replica_pairs) then
                ! Runs (run-1, run) form a pair sharing trial_energies(run/2).
                do run = 2, inum_runs, 2
                    rel_tot_trial_numerator(run - 1:run) = tt_numerator(run - 1:run) + &
                                                           tt_denom(run - 1:run) * trial_energies(run / 2)
                end do
            else
                ! One trial energy per run.
                rel_tot_trial_numerator = tt_numerator + (tt_denom * trial_energies)
            end if
        else
            ! Nothing to add: return the numerator unchanged.
            rel_tot_trial_numerator = tt_numerator
        end if

    end function relative_trial_numerator

    !> Update the diagonal shift of every run from the statistics of the last
    !> update cycle, then broadcast the shift state to all processes.
    !>
    !> The decisions are taken on the root process only, because the collated
    !> values used here (AllGrowRate, AllHFCyc, AllTotParts, ...) are only valid
    !> there. For each run:
    !>   * fixed-N0 mode (tFixedN0): once the reference population reaches
    !>     N0_Target the shift follows the projected energy; before that it is
    !>     held at InputDiagSft;
    !>   * fixed-trial mode (tFixTrial): the shift follows the trial energy;
    !>   * otherwise the run enters/leaves the single-particle (constant shift)
    !>     phase and the shift follows the damped growth-rate update.
    !> The shift average is accumulated once nShiftEquilSteps have passed, and
    !> the instantaneous HF shift and the projected energies are recomputed.
    !> Finally the tau search is stopped if its end criterion has been reached.
    !>
    !> @param[in] iter_data      iteration statistics of the update cycle
    !> @param[in] replica_pairs  whether runs are grouped in replica pairs
    !>                           (passed on to relative_trial_numerator)
    subroutine update_shift(iter_data, replica_pairs)

        use CalcData, only: tInstGrowthRate, tL2GrowRate

        type(fcimc_iter_data), intent(in) :: iter_data
        logical, intent(in) :: replica_pairs
        integer(int64) :: tot_walkers
        logical, dimension(inum_runs) :: tReZeroShift
        real(dp), dimension(inum_runs) :: AllGrowRateRe, AllGrowRateIm
        real(dp), dimension(inum_runs) :: AllHFGrowRate, AllWalkers
        real(dp), dimension(inum_runs) :: rel_tot_trial_numerator
        integer :: run, lb, ub
        logical, dimension(inum_runs) :: defer_update
        logical :: start_varying_shift
        character(*), parameter :: this_routine = 'update_shift'

        ! Normally we allow the shift to vary depending on the conditions
        ! tested. Sometimes we want to defer this to the next cycle...
        defer_update(:) = .false.

        ! collate_iter_data --> The values used are only valid on Root
        i_am_root: if (iProcIndex == Root) then

            if (tL2GrowRate) then
                ! use the L2 norm to determine the growrate
                AllGrowRate(:) = norm_psi(:) / old_norm_psi(:)
                AllWalkers(:) = norm_psi(:)

            else if (tInstGrowthRate) then

                ! Calculate the growth rate simply using the two points at
                ! the beginning and the end of the update cycle.
                do run = 1, inum_runs
                    lb = min_part_type(run)
                    ub = max_part_type(run)
                    AllGrowRate(run) = (sum(iter_data%update_growth_tot(lb:ub) &
                                            + iter_data%tot_parts_old(lb:ub))) &
                                       / real(sum(iter_data%tot_parts_old(lb:ub)), dp)
                    AllWalkers(run) = (sum(iter_data%update_growth_tot(lb:ub) &
                                           + iter_data%tot_parts_old(lb:ub)))
                end do

            else

                ! Instead attempt to calculate the average growth over every
                ! iteration over the update cycle
                if (all(.not. near_zero(OldAllAvWalkersCyc))) then
                    AllGrowRate(:) = AllSumWalkersCyc(:) / real(StepsSft, dp) &
                                     / OldAllAvWalkersCyc(:)
                end if
                AllWalkers(:) = AllSumWalkersCyc(:) / real(StepsSft, dp)

            end if

            ! For complex case, obtain both Re and Im parts.
            ! NOTE(review): AllGrowRateRe/Im stay unset when the corresponding
            ! old population is (near) zero, but are still used below.
#ifdef CMPLX_
            do run = 1, inum_runs
                lb = min_part_type(run)
                ub = max_part_type(run)
                if (.not. near_zero(iter_data%tot_parts_old(lb))) then
                    AllGrowRateRe(run) = &
                        (iter_data%update_growth_tot(lb) + iter_data%tot_parts_old(lb)) &
                        / iter_data%tot_parts_old(lb)
                end if
                if (.not. near_zero(iter_data%tot_parts_old(ub))) then
                    AllGrowRateIm(run) = &
                        (iter_data%update_growth_tot(ub) + iter_data%tot_parts_old(ub)) &
                        / iter_data%tot_parts_old(ub)
                end if
            end do
#endif
            ! If any run uses the fixtrial option, we need to add the offset to the
            ! trial numerator
            if (tTrialWavefunction .and. tTrialShift) &
                rel_tot_trial_numerator = real(relative_trial_numerator( &
                                               tot_trial_numerator, tot_trial_denom, replica_pairs), dp)

            ! Exit the single particle phase if the number of walkers exceeds
            ! the value in the input file. If particle no has fallen, re-enter
            ! it.
            tReZeroShift = .false.
            do run = 1, inum_runs
                lb = min_part_type(run)
                ub = max_part_type(run)

                if (tTrialShift .and. .not. tFixTrial(run) .and. tTrialWavefunction &
                    .and. abs(tot_trial_denom(run)) >= TrialTarget) then
                    !When reaching target overlap with trial wavefunction, set flag to keep it fixed.
                    tFixTrial(run) = .True.

                    write(stdout, '(a,i13,a,i1)') 'Exiting the varaible shift phase on iteration: ' &
                        , iter + PreviousCycles, ' - overlap with trial wavefunction of the following run is now fixed: ', run
                end if

                if (tFixedN0) then
                    if (.not. tSkipRef(run) .and. abs(AllHFCyc(run)) >= N0_Target) then
                        !When reaching target N0, set flag to keep the population of reference det fixed.
                        tSkipRef(run) = .True.

                        write(stdout, '(a,i13,a,i1)') 'Exiting the fixed shift phase on iteration: ' &
                            , iter + PreviousCycles, ' - reference population of the following run is now fixed: ', run
                        !Set these parameters because other parts of the code depends on them
                        VaryShiftIter(run) = Iter
                        iBlockingIter(run) = Iter + PreviousCycles
                        tSinglePartPhase(run) = .false.
                    end if

                    if (tSkipRef(run)) then
                        !Use the projected energy as the shift to fix the
                        !population of the reference det and thus reduce the
                        !fluctuations of the projected energy.

                        !ToDo: Make DiafSft complex
                        DiagSft(run) = real((AllENumCyc(run)) &
                                            / (AllHFCyc(run)) + proje_ref_energy_offsets(run), dp)

                        call accumulate_shift_average(run)
                    else
                        !Keep shift equal to input till target reference population is reached.
                        DiagSft(run) = InputDiagSft(run)
                    end if

                else if (tFixTrial(run)) then
                    !Use the trial energy as the shift to fix the
                    !overlap with trial wavefunction and thus reduce the
                    !fluctuations of the trial energy.

                    !ToDo: Make DiafSft complex
                    DiagSft(run) = real((rel_tot_trial_numerator(run) / tot_trial_denom(run)) - Hii, dp)

                    call accumulate_shift_average(run)

                else !not Fixed-N0 and not Trial-Shift
                    tot_walkers = int(InitWalkers, int64) * int(nNodes, int64)
                    single_part_phase: if (tSinglePartPhase(run)) then

#ifdef CMPLX_
                        if ((sum(AllTotParts(lb:ub)) > tot_walkers) .or. &
                            (abs_sign(AllNoatHF(lb:ub)) > MaxNoatHF)) then
                            write(stdout, '(a,i13,a)') 'Exiting the single particle growth phase on iteration: ', &
                                iter + PreviousCycles, ' - Shift can now change'
                            VaryShiftIter(run) = Iter
                            iBlockingIter(run) = Iter + PreviousCycles
                            tSinglePartPhase(run) = .false.
                            ! Only this run leaves the single particle phase,
                            ! so only its target growth rate is reset.
                            if (abs(TargetGrowRate(run)) > EPS) then
                                write(stdout, "(A)") "Setting target growth rate to 1."
                                TargetGrowRate(run) = 0.0_dp
                            end if

                            ! If enabled, jump the shift to the value predicted by
                            ! the projected energy, but only if that value is finite.
                            if (tJumpShift) then
                                if (.not. (isnan(real(proje_iter(run), dp)) &
                                           .or. is_inf(real(proje_iter(run), dp)))) then
                                    DiagSft(run) = real(proje_iter(run), dp)
                                    defer_update(run) = .true.
                                end if
                            end if
                        end if
#else
                        start_varying_shift = .false.
                        if (tLetInitialPopDie) then
                            if (AllTotParts(run) < tot_walkers) start_varying_shift = .true.
                        else if (tTargetShiftdamp) then
                            start_varying_shift = .true.
                        else
                            if ((AllTotParts(run) > tot_walkers) .or. &
                                (abs(AllNoatHF(run)) > MaxNoatHF)) start_varying_shift = .true.
                        end if

                        if (start_varying_shift) then
                            write(stdout, '(a,i13,a,i1)') 'Exiting the single particle growth phase on iteration: ' &
                                , iter + PreviousCycles, ' - Shift can now change for population', run
                            VaryShiftIter(run) = Iter
                            iBlockingIter(run) = Iter + PreviousCycles
                            tSinglePartPhase(run) = .false.
                            ! [W.D. 15.5.2017]
                            ! change equal 0 comps
                            if (abs(TargetGrowRate(run)) > EPS) then
                                write(stdout, "(A)") "Setting target growth rate to 1."
                                TargetGrowRate(run) = 0.0_dp
                            end if

                            ! If enabled, jump the shift to the value predicted by the
                            ! projected energy!
                            if (tJumpShift) then
                                DiagSft(run) = real(proje_iter(run), dp)
                                defer_update(run) = .true.
                            end if
                        end if
#endif
                    else ! .not.tSinglePartPhase(run)

#ifdef CMPLX_
                        if (abs_sign(AllNoatHF(lb:ub)) < MaxNoatHF - HFPopThresh) then
#else
                        if (abs(AllNoatHF(run)) < MaxNoatHF - HFPopThresh) then
#endif
                            write(stdout, '(a,i13,a)') 'No at HF has fallen too low - reentering the &
                                         &single particle growth phase on iteration', iter + PreviousCycles, ' - particle number &
                                         &may grow again.'
                            tSinglePartPhase(run) = .true.
                            tReZeroShift(run) = .true.
                        end if

                    end if single_part_phase

                    ! How should the shift change for the entire ensemble of walkers
                    ! over all processors.

                    if (.not. (tSinglePartPhase(run) &
                               .and. near_zero(TargetGrowRate(run)) &
                               .or. defer_update(run))) then

                        !In case we want to continue growing, TargetGrowRate > 0.0_dp
                        ! New shift value
                        ! [W.D. 15.5.2017]
                        if (abs(TargetGrowRate(run)) > EPS) then
#ifdef CMPLX_
                            if (sum(AllTotParts(lb:ub)) > TargetGrowRateWalk(run)) then
#else
                            if (AllTotParts(run) > TargetGrowRateWalk(run)) then
#endif
                                if (tTargetShiftdamp) then
                                    call stop_all(this_routine, &
                                                  "Target-shiftdamp not compatible with targetgrowrate!")
                                end if
                                !Only allow targetgrowrate to kick in once we have > TargetGrowRateWalk walkers.
                                DiagSft(run) = DiagSft(run) - (log(AllGrowRate(run) - TargetGrowRate(run)) * SftDamp) / &
                                               (Tau * StepsSft)
                                ! Same for the info shifts for complex walkers
#ifdef CMPLX_
                                DiagSftRe(run) = DiagSftRe(run) - (log(AllGrowRateRe(run) - TargetGrowRate(run)) * SftDamp) / &
                                                 (Tau * StepsSft)
                                DiagSftIm(run) = DiagSftIm(run) - (log(AllGrowRateIm(run) - TargetGrowRate(run)) * SftDamp) / &
                                                 (Tau * StepsSft)
#endif
                            end if
                        else
                            if (tShiftonHFPop) then
                                !Calculate the shift required to keep the HF population constant

                                AllHFGrowRate(run) = abs(AllHFCyc(run) / real(StepsSft, dp)) / abs(OldAllHFCyc(run))
                                if (.not. near_zero(AllHFGrowRate(run))) then
                                    DiagSft(run) = DiagSft(run) - (log(AllHFGrowRate(run)) * SftDamp) / &
                                                   (Tau * StepsSft)
                                else
                                    call stop_all(this_routine, "Shift undefined because HF growth rate is zero. Aborting.")
                                end if
                            else if (tTargetShiftdamp) then
                                ! Damp towards both zero growth and the target walker number.
                                if (.not. near_zero(AllGrowRate(run)) .and. .not. near_zero(AllWalkers(run))) then
                                    DiagSft(run) = DiagSft(run) - (log(AllGrowRate(run)) * SftDamp + &
                                                                   log(AllWalkers(run) / tot_walkers) * SftDamp2) / &
                                                   (Tau * StepsSft)
                                else
                                    call stop_all(this_routine, "Shift undefined because walker growth rate is zero. Aborting.")
                                end if
                            else
                                if (.not. near_zero(AllGrowRate(run))) then
                                    DiagSft(run) = DiagSft(run) - (log(AllGrowRate(run)) * SftDamp) / &
                                                   (Tau * StepsSft)
                                else
                                    call stop_all(this_routine, "Shift undefined because walker growth rate is zero. Aborting.")
                                end if
                            end if
                        end if

                        call accumulate_shift_average(run)

                    end if
                end if !tFixedN0 or not

                ! only update the shift this way if possible
                if (abs_sign(AllNoatHF(lb:ub)) > EPS) then
#ifdef CMPLX_
                    ! Calculate the instantaneous 'shift' from the HF population
                    HFShift(run) = -1.0_dp / abs_sign(AllNoatHF(lb:ub)) * &
                                   (abs_sign(AllNoatHF(lb:ub)) - abs_sign(OldAllNoatHF(lb:ub)) / &
                                    (Tau * real(StepsSft, dp)))
                    InstShift(run) = -1.0_dp / sum(AllTotParts(lb:ub)) * &
                                     ((sum(AllTotParts(lb:ub)) - sum(AllTotPartsOld(lb:ub))) / &
                                      (Tau * real(StepsSft, dp)))
#else
                    ! Calculate the instantaneous 'shift' from the HF population
                    HFShift(run) = -1.0_dp / abs(AllNoatHF(run)) * &
                                   (abs(AllNoatHF(run)) - abs(OldAllNoatHF(run)) / &
                                    (Tau * real(StepsSft, dp)))
                    InstShift(run) = -1.0_dp / AllTotParts(run) * &
                                     ((AllTotParts(run) - AllTotPartsOld(run)) / &
                                      (Tau * real(StepsSft, dp)))
#endif
                end if

                ! When using a linear combination, the denominator is summed
                ! directly.
                all_sum_proje_denominator(run) = ARR_RE_OR_CPLX(AllSumNoatHF, run)
                all_cyc_proje_denominator(run) = AllHFCyc(run)

                ! Calculate the projected energy.

                if (.not. near_zero(AllSumNoatHF(run))) then
                    ProjectionE(run) = (AllSumENum(run)) / (all_sum_proje_denominator(run)) &
                                       + proje_ref_energy_offsets(run)
                end if
                if (abs(AllHFCyc(run)) > EPS) then
                    proje_iter(run) = (AllENumCyc(run)) / (all_cyc_proje_denominator(run)) &
                                      + proje_ref_energy_offsets(run)
                    AbsProjE(run) = (AllENumCycAbs(run)) / (all_cyc_proje_denominator(run)) &
                                    + proje_ref_energy_offsets(run)
                    inits_proje_iter(run) = (AllInitsENumCyc(run)) / (all_cyc_proje_denominator(run)) &
                                            + proje_ref_energy_offsets(run)
                end if
                ! If we are re-zeroing the shift
                if (tReZeroShift(run)) then
                    DiagSft(run) = 0.0_dp
                    VaryShiftCycles(run) = 0
                    SumDiagSft(run) = 0.0_dp
                    AvDiagSft(run) = 0.0_dp
                end if
            end do

            ! Get some totalled values
            if (abs(sum(all_sum_proje_denominator(1:inum_runs))) > EPS) then
                projectionE_tot = sum(AllSumENum(1:inum_runs)) &
                                  / sum(all_sum_proje_denominator(1:inum_runs))
            end if
            if (abs(sum(all_cyc_proje_denominator(1:inum_runs))) > EPS) then
                proje_iter_tot = sum(AllENumCyc(1:inum_runs)) &
                                 / sum(all_cyc_proje_denominator(1:inum_runs))
                inits_proje_iter_tot = sum(AllInitsENumCyc(1:inum_runs)) &
                                       / sum(all_cyc_proje_denominator(1:inum_runs))
            end if

        end if i_am_root

        ! Broadcast the shift from root to all the other processors
        call MPIBcast(tSinglePartPhase)
        call MPIBcast(VaryShiftIter)
        call MPIBcast(DiagSft)
        call MPIBcast(tSkipRef)
        call MPIBcast(tFixTrial)
        call MPIBcast(VaryShiftCycles)
        call MPIBcast(SumDiagSft)
        call MPIBcast(AvDiagSft)

        do run = 1, inum_runs
            if (.not. tSinglePartPhase(run)) then
                TargetGrowRate(run) = 0.0_dp
            end if
        end do

        if (tau_search_method /= possible_tau_search_methods%off) then
            if (end_of_search_reached(tau_search_method, tau_stop_method)) then
                call stop_tau_search(tau_stop_method)
            end if
        end if

    contains

        !> Add DiagSft(irun) to the running shift average of run irun once
        !> nShiftEquilSteps iterations have passed since the shift began to vary.
        subroutine accumulate_shift_average(irun)
            integer, intent(in) :: irun

            if ((iter - VaryShiftIter(irun)) >= nShiftEquilSteps) then
                if ((iter - VaryShiftIter(irun) - nShiftEquilSteps) < StepsSft) &
                    write(stdout, '(a,i14)') 'Beginning to average shift value on iteration: ', iter + PreviousCycles
                VaryShiftCycles(irun) = VaryShiftCycles(irun) + 1
                SumDiagSft(irun) = SumDiagSft(irun) + DiagSft(irun)
                AvDiagSft(irun) = SumDiagSft(irun) / real(VaryShiftCycles(irun), dp)
            end if
        end subroutine accumulate_shift_average

    end subroutine update_shift

    !> Reset the accumulators that only feed the periodic output, so that the
    !> next output block starts from zero, and record the current total
    !> particle count in AllTotPartsLastOutput.
    subroutine rezero_output_stats()
        ! Spawning / death / annihilation counters and timing for this block.
        NoBorn = 0.0_dp
        NoDied = 0.0_dp
        Annihilated = 0.0_dp
        Acceptances = 0.0_dp
        max_cyc_spawn = 0.0_dp
        IterTime = 0.0_sp

        ! Trial-wavefunction estimator accumulators.
        trial_denom = 0.0_dp
        trial_numerator = 0.0_dp

        ! Output-only estimators, plus the particle count at this output.
        HFOut = 0.0_dp
        ENumOut = 0.0_dp
        SumWalkersOut = 0.0_dp
        AllTotPartsLastOutput = AllTotParts

        ! Diagnostics: truncated weight, initiators per excitation level,
        ! valid/invalid excitation counters and core-space non-initiators.
        truncatedWeight = 0.0_dp
        initsPerExLvl = 0
        nValidExcits = 0
        nInvalidExcits = 0
        n_core_non_init = 0
    end subroutine rezero_output_stats

    !> Prepare for the next shift-update cycle. This records the end-of-cycle
    !> values that the next update compares against, and then zeroes the
    !> per-cycle accumulators.
    !>
    !> @param[inout] iter_data          iteration statistics to reset
    !> @param[in]    tot_parts_new_all  particle totals at the end of this
    !>                                  cycle; stored as iter_data%tot_parts_old
    subroutine rezero_iter_stats_update_cycle(iter_data, tot_parts_new_all)

        type(fcimc_iter_data), intent(inout) :: iter_data
        real(dp), dimension(lenof_sign), intent(in) :: tot_parts_new_all

        ! --- Snapshot the values the next cycle measures its change against ---
        TotWalkersOld = TotWalkers
        TotPartsOld = TotParts
        AllTotWalkersOld = AllTotWalkers
        AllTotPartsOld = AllTotParts
        AllNoAbortedOld = AllNoAborted
        ! Population on the reference, needed for HFShift
        OldAllNoatHF = AllNoatHF
        ! Per-iteration averages over the cycle that just ended
        ! TODO CMO: are these summed across real/complex?
        OldAllHFCyc = AllHFCyc / real(StepsSft, dp)
        OldAllAvWalkersCyc = AllSumWalkersCyc / real(StepsSft, dp)
        old_norm_psi = norm_psi
        iter_data%tot_parts_old = tot_parts_new_all

        ! --- Zero everything accumulated over the iterations of a cycle ---
        SumWalkersCyc(:) = 0.0_dp
        SpawnFromSing = 0.0_dp
        ENumCyc = 0.0_dp
        InitsENumCyc = 0.0_dp
        ENumCycAbs = 0.0_dp
        HFCyc = 0.0_dp
        cyc_proje_denominator = 0.0_dp
        cont_spawn_attempts = 0
        cont_spawn_success = 0

        ! Growth counters of the global FCIQMC iteration data and of the
        ! passed-in iteration data (these may be the same object).
        iter_data_fciqmc%update_growth = 0.0_dp
        iter_data_fciqmc%update_iters = 0
        iter_data%update_growth = 0.0_dp
        iter_data%update_iters = 0

        tfirst_cycle = .false.

        ! Optional observables
        if (t_calc_double_occ) call rezero_double_occ_stats()
        if (t_calc_double_occ .and. t_spin_measurements) call rezero_spin_diff()
        if (t_measure_local_spin) call rezero_local_spin_stats()

    end subroutine rezero_iter_stats_update_cycle

    !> Communicate the cycle estimates (unless switched off) and, when data
    !> tables are printed, write the current output cycle's statistics.
    !>
    !> @param[inout] iter_data      iteration statistics, passed to communicate_estimates
    !> @param[in]    tot_parts_new  local particle totals per sign component
    !> @param[in]    replica_pairs  forwarded to relative_trial_numerator
    !> @param[in]    t_comm_req     optional; if .false., communicate_estimates is
    !>                              skipped (default: .true.)
    subroutine iteration_output_wrapper(iter_data, tot_parts_new, &
                                        replica_pairs, t_comm_req)
        type(fcimc_iter_data), intent(inout) :: iter_data
        real(dp), dimension(lenof_sign), intent(in) :: tot_parts_new
        real(dp), dimension(lenof_sign) :: tot_parts_new_all
        logical, intent(in) :: replica_pairs
        logical, intent(in), optional :: t_comm_req
        logical :: t_do_comm

        ! The comm can be switched off; it is on unless explicitly disabled.
        ! (Without the else branch t_do_comm would be undefined when the
        ! argument is absent, and always .true. when it is present.)
        if (present(t_comm_req)) then
            t_do_comm = t_comm_req
        else
            t_do_comm = .true.
        end if

        ! NOTE(review): the final .true. appears to request the output-step
        ! (trial wavefunction) part of the communication - confirm against
        ! communicate_estimates.
        if (t_do_comm) &
            call communicate_estimates(iter_data, tot_parts_new, tot_parts_new_all, .true.)
        if (tPrintDataTables) &
            call write_to_stats()

    contains

        !> Write the current output cycle's stats to the FCIMCStats/fciqmc_stats
        !> output file + stdout, then reset the accumulated output variables.
        subroutine write_to_stats()

            ! adjust the trial numerator for output (add in the offset)
            if (tTrialWavefunction) then
                tot_trial_numerator = relative_trial_numerator( &
                                      tot_trial_numerator, tot_trial_denom, replica_pairs)
                if (tTruncInitiator) &
                    tot_init_trial_numerator = relative_trial_numerator( &
                                               tot_init_trial_numerator, tot_init_trial_denom, replica_pairs)
            end if
            call output_diagnostics()

            if (tFCIMCStats2) then
                call write_fcimcstats2(iter_data_fciqmc)
                call WriteFCIMCStats()
            end if
            ! reset accumulated output variables
            call rezero_output_stats()
        end subroutine write_to_stats
    end subroutine iteration_output_wrapper

    !> Communicate the estimates (unless switched off), update the shift and
    !> rezero the per-update-cycle statistics.
    !>
    !> @param[inout] iter_data      iteration statistics of the current update cycle
    !> @param[in]    tot_parts_new  local particle totals per sign component
    !> @param[in]    replica_pairs  forwarded to update_shift
    !> @param[in]    t_comm_req     optional; if .false., communicate_estimates is
    !>                              skipped. The shift update and the rezeroing
    !>                              are still performed (default: .true.)
    subroutine calculate_new_shift_wrapper(iter_data, tot_parts_new, replica_pairs, t_comm_req)

        type(fcimc_iter_data), intent(inout) :: iter_data
        real(dp), dimension(lenof_sign), intent(in) :: tot_parts_new
        real(dp), dimension(lenof_sign) :: tot_parts_new_all
        logical, intent(in) :: replica_pairs
        logical, intent(in), optional :: t_comm_req
        logical :: t_do_comm

        ! TODO: use def_default once available
        if (present(t_comm_req)) then
            t_do_comm = t_comm_req
        else
            t_do_comm = .true.
        end if

        ! communication of trial wf properties is only done for output steps
        if (t_do_comm) then
            call communicate_estimates(iter_data, tot_parts_new, tot_parts_new_all, tCoupleCycleOutput)
        else
            ! No new global totals are available without communication. Keep the
            ! previous ones so that rezero_iter_stats_update_cycle does not copy
            ! an undefined value into iter_data%tot_parts_old.
            tot_parts_new_all = iter_data%tot_parts_old
        end if

        ! update the shift and rezero the cycle data
        call shift_update()
        call rezero_iter_stats_update_cycle(iter_data, tot_parts_new_all)

    contains

        !> This is what defines the update of the shift: collate the iteration
        !> data and run the diagnostics. If tRestart is set afterwards, return
        !> early; otherwise check the population and update the shift.
        subroutine shift_update()
            call collate_iter_data(iter_data)
            call iter_diagnostics()
            if (tRestart) return
            call population_check()
            call update_shift(iter_data, replica_pairs)
        end subroutine shift_update

    end subroutine calculate_new_shift_wrapper

    !> Fold one iteration into the update-cycle accumulators of iter_data.
    !>
    !> The net growth of this iteration (born - died - annihilated - aborted -
    !> removed) is added to update_growth, and update_iters is incremented.
    !>
    !> @param[inout] iter_data  iteration statistics to accumulate into
    subroutine update_iter_data(iter_data)

        type(fcimc_iter_data), intent(inout) :: iter_data

        associate(stats => iter_data)
            ! one more iteration contributes to this update cycle
            stats%update_iters = stats%update_iters + 1
            ! same left-to-right evaluation order as a single flat sum
            stats%update_growth = stats%update_growth + stats%nborn - stats%ndied &
                                  - stats%nannihil - stats%naborted - stats%nremoved
        end associate

    end subroutine update_iter_data

    !> Fix the overlap with the trial wavefunction by enforcing the value of a
    !> random determinant of the trial space. As long as the shift equals the
    !> trial energy, this should still give the right dynamics.
    !>
    !> For every run with tFixTrial set, the global overlap is pulled back to
    !> tot_trial_denom by changing the sign of a single occupied trial
    !> determinant. The determinant is chosen in two stages:
    !>   1. a processor, with probability proportional to the summed
    !>      |current_trial_amps(1, :)| of its occupied trial determinants;
    !>   2. a determinant on that processor, with probability proportional to
    !>      its own |current_trial_amps(1, :)|.
    !> After the change, nborn/ndied, TotParts, the norms, the highest-population
    !> bookkeeping and the HF population are corrected.
    !> Not supported for complex wavefunctions (stops).
    !>
    !> @param[inout] iter_data  iteration statistics; nborn and ndied are updated
    subroutine fix_trial_overlap(iter_data)
        use util_mod, only: binary_search_first_ge
        type(fcimc_iter_data), intent(inout) :: iter_data

        HElement_t(dp), dimension(inum_runs) :: new_trial_denom, new_tot_trial_denom
        real(dp), dimension(lenof_sign) :: trial_delta, SignCurr, newSignCurr
        integer :: j, det_idx, proc_idx, run, part_type, lbnd, ubnd, err, amp_idx
        integer :: trial_count, trial_indices(tot_trial_space_size)
        real(dp) :: amps(tot_trial_space_size), total_amp, total_amps(nProcessors)
        real(dp) :: trial_amp
        logical :: tIsStateDeterm

#ifdef CMPLX_
        call stop_all("fix_trial_overlap", "Complex wavefunction is not supported yet!")
#else
        ! Calculate the new (local) overlap and collect the occupied trial determinants
        new_trial_denom = 0.0_dp
        new_tot_trial_denom = 0.0_dp

        trial_count = 0
        total_amp = 0.0_dp
        do j = 1, int(TotWalkers)
            call extract_sign(CurrentDets(:, j), SignCurr)
            if (.not. IsUnoccDet(SignCurr) .and. test_flag(CurrentDets(:, j), flag_trial)) then
                trial_count = trial_count + 1
                trial_indices(trial_count) = j
                amps(trial_count) = abs(current_trial_amps(1, j))
                total_amp = total_amp + amps(trial_count)
                ! Update the overlap
                if (ntrial_excits == 1) then
                    new_trial_denom = new_trial_denom + current_trial_amps(1, j) * SignCurr
                else if (tReplicaReferencesDiffer .and. tPairedReplicas) then
                    do run = 2, inum_runs, 2
                        new_trial_denom(run - 1:run) = new_trial_denom(run - 1:run) &
                                                       + current_trial_amps(run / 2, j) * SignCurr(run - 1:run)
                    end do
                else if (ntrial_excits == lenof_sign) then
                    new_trial_denom = new_trial_denom + current_trial_amps(:, j) * SignCurr
                end if
            end if
        end do

        ! Collect overlaps from all processors
        call MPIAllReduce(new_trial_denom, MPI_SUM, new_tot_trial_denom)

        ! Choose a random processor proportional to the sum of amplitudes of its
        ! trial space. proc_idx = -1 means no processor is chosen: there is no
        ! trial amplitude anywhere, so nothing can be fixed.
        call MPIGather(total_amp, total_amps, err)
        proc_idx = -1
        if (iProcIndex == root) then
            do j = 2, nProcessors
                total_amps(j) = total_amps(j) + total_amps(j - 1)
            end do
            if (total_amps(nProcessors) > 0.0_dp) then
                proc_idx = binary_search_first_ge(total_amps, &
                                                  genrand_real2_dSFMT() * total_amps(nProcessors)) - 1
            end if
        end if
        call MPIBCast(proc_idx)

        ! Enforce an update of a random determinant on the chosen processor.
        ! Guard against trial_count == 0: a rank with zero weight can still be hit
        ! when the random number is exactly zero.
        if (iProcIndex == proc_idx .and. trial_count > 0) then
            ! Choose a random determinant, weighted by |trial amplitude|
            do j = 2, trial_count
                amps(j) = amps(j) + amps(j - 1)
            end do
            det_idx = trial_indices(binary_search_first_ge(amps(1:trial_count), &
                                                           genrand_real2_dSFMT() * amps(trial_count)))
            do part_type = 1, lenof_sign
                run = part_type_to_run(part_type)
                ! Use the same trial-amplitude index as in the overlap accumulation
                ! above; current_trial_amps(part_type, :) would be out of range
                ! when ntrial_excits < lenof_sign.
                if (ntrial_excits == 1) then
                    amp_idx = 1
                else if (tReplicaReferencesDiffer .and. tPairedReplicas) then
                    amp_idx = (run + 1) / 2
                else
                    amp_idx = part_type
                end if
                trial_amp = current_trial_amps(amp_idx, det_idx)
                ! Only shift runs that request it, and never divide by a zero amplitude
                if (tFixTrial(run) .and. abs(trial_amp) > 0.0_dp) then
                    trial_delta(part_type) = (tot_trial_denom(run) - new_tot_trial_denom(run)) / trial_amp
                else
                    trial_delta(part_type) = 0.0_dp
                end if
            end do

            call extract_sign(CurrentDets(:, det_idx), SignCurr)
            newSignCurr = SignCurr + trial_delta
            call encode_sign(CurrentDets(:, det_idx), newSignCurr)

            ! Correct statistics filled by CalcHashTableStats
            iter_data%ndied = iter_data%ndied + abs(SignCurr)
            iter_data%nborn = iter_data%nborn + abs(newSignCurr)
            TotParts = TotParts + abs(newSignCurr) - abs(SignCurr)

            tIsStateDeterm = .false.
            if (tSemiStochastic) tIsStateDeterm = test_flag_multi(CurrentDets(:, det_idx), flag_deterministic)

            norm_psi_squared = norm_psi_squared + newSignCurr**2 - SignCurr**2
            if (tIsStateDeterm) norm_semistoch_squared = norm_semistoch_squared + newSignCurr**2 - SignCurr**2

            if (tCheckHighestPop) then
                do run = 1, inum_runs
                    lbnd = min_part_type(run)
                    ubnd = max_part_type(run)
                    if (abs_sign(newSignCurr(lbnd:ubnd)) > iHighestPop(run)) then
                        iHighestPop(run) = int(abs_sign(newSignCurr(lbnd:ubnd)))
                        HighestPopDet(:, run) = CurrentDets(:, det_idx)
                    end if
                end do
            end if
            if (tFillingStochRDMonFly) then
                if (IsUnoccDet(newSignCurr) .and. (.not. tIsStateDeterm)) then
                    if (DetBitEQ(CurrentDets(:, det_idx), iLutHF_True, nifd)) then
                        AvNoAtHF = 0.0_dp
                        IterRDM_HF = Iter + 1
                    end if
                end if
            end if

            if (DetBitEQ(CurrentDets(:, det_idx), iLutHF_True, nifd)) then
                InstNoAtHF = newSignCurr
            end if
        end if
#endif
    end subroutine fix_trial_overlap

end module fcimc_iter_utils