fix_trial_overlap Subroutine

public subroutine fix_trial_overlap(iter_data)

Uses

Arguments

Type Intent Optional Attributes Name
type(fcimc_iter_data), intent(inout) :: iter_data

Contents

Source Code


Source Code

    !> Restore the overlap of the walker population with the trial wavefunction.
    !>
    !> Recomputes the overlap (trial denominator) of the current walkers with the
    !> trial wavefunction, summed over all processors. For every run with
    !> tFixTrial(run) set, it then changes the signed population of ONE randomly
    !> chosen occupied trial-space determinant, so that the overlap becomes
    !> tot_trial_denom(run) again.
    !>
    !> Selection weights are |current_trial_amps(1, j)|:
    !>   - the processor is chosen with probability proportional to the sum of
    !>     the weights of its occupied trial determinants;
    !>   - the determinant is then chosen on that processor in the same way.
    !>
    !> Statistics that depend on the population (iter_data%ndied/nborn,
    !> TotParts, norm_psi_squared, norm_semistoch_squared, iHighestPop,
    !> InstNoAtHF) are corrected for the change.
    !>
    !> Every processor must call this routine: it performs MPIAllReduce,
    !> MPIGather and MPIBCast.
    !> Complex wavefunctions are not supported (the routine stops with an error).
    !>
    !> @param[inout] iter_data  iteration statistics; ndied/nborn are updated
    subroutine fix_trial_overlap(iter_data)
        use util_mod, only: binary_search_first_ge
        type(fcimc_iter_data), intent(inout) :: iter_data

        HElement_t(dp), dimension(inum_runs) :: new_trial_denom, new_tot_trial_denom
        real(dp), dimension(lenof_sign) :: trial_delta, SignCurr, newSignCurr
        real(dp) :: trial_amp
        integer :: j, det_idx, proc_idx, run, part_type, lbnd, ubnd, err, amp_idx
        integer :: trial_count, trial_indices(tot_trial_space_size)
        real(dp) :: amps(tot_trial_space_size), total_amp, total_amps(nProcessors)
        logical :: tIsStateDeterm

#ifdef CMPLX_
        unused_var(iter_data)
        call stop_all("fix_trial_overlap", "Complex wavefunction is not supported yet!")
#else

        ! Calculate the new local overlap with the trial wavefunction.
        ! At the same time, record the occupied trial determinants and their
        ! selection weights |current_trial_amps(1, j)|.
        new_trial_denom = 0.0_dp
        new_tot_trial_denom = 0.0_dp

        trial_count = 0
        total_amp = 0.0_dp
        do j = 1, int(TotWalkers)
            call extract_sign(CurrentDets(:, j), SignCurr)
            if (.not. IsUnoccDet(SignCurr) .and. test_flag(CurrentDets(:, j), flag_trial)) then
                trial_count = trial_count + 1
                trial_indices(trial_count) = j
                amps(trial_count) = abs(current_trial_amps(1, j))
                total_amp = total_amp + amps(trial_count)
                ! Update the overlap. The row of current_trial_amps used here
                ! must match the one used for the correction further below.
                if (ntrial_excits == 1) then
                    new_trial_denom = new_trial_denom + current_trial_amps(1, j) * SignCurr
                else if (tReplicaReferencesDiffer .and. tPairedReplicas) then
                    do run = 2, inum_runs, 2
                        new_trial_denom(run - 1:run) = new_trial_denom(run - 1:run) &
                            + current_trial_amps(run / 2, j) * SignCurr(run - 1:run)
                    end do
                else if (ntrial_excits == lenof_sign) then
                    new_trial_denom = new_trial_denom + current_trial_amps(:, j) * SignCurr
                end if
            end if
        end do

        ! Collect the overlaps from all processors.
        call MPIAllReduce(new_trial_denom, MPI_SUM, new_tot_trial_denom)

        ! Choose a random processor with probability proportional to the summed
        ! selection weights of its occupied trial determinants.
        ! The gathered totals are only used on root.
        call MPIGather(total_amp, total_amps, err)
        if (iProcIndex == root) then
            ! Turn the per-processor weights into a cumulative distribution.
            do j = 2, nProcessors
                total_amps(j) = total_amps(j) + total_amps(j - 1)
            end do
            if (total_amps(nProcessors) > 0.0_dp) then
                ! Shift the 1-based search result to match iProcIndex.
                proc_idx = binary_search_first_ge(total_amps, &
                                                  genrand_real2_dSFMT() * total_amps(nProcessors)) - 1
            else
                ! No trial determinant is occupied on any processor, so there
                ! is nothing to adjust. Select no processor rather than index
                ! into an empty list of trial determinants.
                proc_idx = -1
            end if
        end if
        call MPIBCast(proc_idx)

        ! Adjust the population of one random trial determinant on the chosen processor.
        if (iProcIndex == proc_idx .and. trial_count > 0) then
            ! Choose a random determinant, weighted by its selection weight.
            do j = 2, trial_count
                amps(j) = amps(j) + amps(j - 1)
            end do
            det_idx = trial_indices(binary_search_first_ge(amps(1:trial_count), &
                                                           genrand_real2_dSFMT() * amps(trial_count)))
            do part_type = 1, lenof_sign
                run = part_type_to_run(part_type)
                ! Use the same row of current_trial_amps that produced this
                ! run's overlap in the loop above. Indexing by part_type
                ! directly would go out of bounds whenever ntrial_excits < lenof_sign.
                if (ntrial_excits == 1) then
                    amp_idx = 1
                else if (tReplicaReferencesDiffer .and. tPairedReplicas) then
                    amp_idx = (run + 1) / 2
                else
                    amp_idx = part_type
                end if
                trial_amp = current_trial_amps(amp_idx, det_idx)
                ! A zero trial amplitude cannot change the overlap. Dividing by
                ! it would write Inf/NaN into the walker sign, so skip it.
                if (tFixTrial(run) .and. abs(trial_amp) > 0.0_dp) then
                    trial_delta(part_type) = (tot_trial_denom(run) - new_tot_trial_denom(run)) / trial_amp
                else
                    trial_delta(part_type) = 0.0_dp
                end if
            end do

            call extract_sign(CurrentDets(:, det_idx), SignCurr)
            newSignCurr = SignCurr + trial_delta
            call encode_sign(CurrentDets(:, det_idx), newSignCurr)

            ! Correct the statistics already filled by CalcHashTableStats.
            ! The old population is counted as died and the new one as born.
            iter_data%ndied = iter_data%ndied + abs(SignCurr)
            iter_data%nborn = iter_data%nborn + abs(newSignCurr)
            TotParts = TotParts + abs(newSignCurr) - abs(SignCurr)

            tIsStateDeterm = .false.
            if (tSemiStochastic) tIsStateDeterm = test_flag_multi(CurrentDets(:, det_idx), flag_deterministic)

            norm_psi_squared = norm_psi_squared + (newSignCurr)**2 - SignCurr**2
            if (tIsStateDeterm) norm_semistoch_squared = norm_semistoch_squared + (newSignCurr)**2 - SignCurr**2

            ! The adjusted determinant may now be the most populated one in its run.
            if (tCheckHighestPop) then
                do run = 1, inum_runs
                    lbnd = min_part_type(run)
                    ubnd = max_part_type(run)
                    if (abs_sign(newSignCurr(lbnd:ubnd)) > iHighestPop(run)) then
                        iHighestPop(run) = int(abs_sign(newSignCurr(lbnd:ubnd)))
                        HighestPopDet(:, run) = CurrentDets(:, det_idx)
                    end if
                end do
            end if
            ! If the HF determinant became unoccupied (and is not deterministic)
            ! while RDMs are being filled, reset AvNoAtHF and set
            ! IterRDM_HF to the next iteration.
            if (tFillingStochRDMonFly) then
                if (IsUnoccDet(newSignCurr) .and. (.not. tIsStateDeterm)) then
                    if (DetBitEQ(CurrentDets(:, det_idx), iLutHF_True, nifd)) then
                        AvNoAtHF = 0.0_dp
                        IterRDM_HF = Iter + 1
                    end if
                end if
            end if

            ! Keep the instantaneous HF population in sync with the new sign.
            if (DetBitEQ(CurrentDets(:, det_idx), iLutHF_True, nifd)) then
                InstNoAtHF = newSignCurr
            end if
        end if
#endif
    end subroutine fix_trial_overlap