!> Restore the overlap with the trial wavefunction by adjusting one walker.
!>
!> The routine
!>   1. recomputes the overlap accumulated over all occupied determinants in
!>      CurrentDets that carry flag_trial, and sums it over all processors;
!>   2. picks one processor with probability proportional to the summed
!>      absolute trial amplitudes (first amplitude component) it holds;
!>   3. on that processor picks one trial determinant with probability
!>      proportional to the absolute value of its first trial amplitude;
!>   4. shifts the sign of that determinant so that, for every run with
!>      tFixTrial set, the recomputed total overlap equals tot_trial_denom
!>      (presumably the reference overlap stored elsewhere -- confirm);
!>   5. patches the walker statistics (born/died counts, TotParts, norms,
!>      highest-population and HF bookkeeping) for the modified determinant.
!>
!> If the selected processor holds no occupied trial determinant, or the
!> relevant trial amplitude of the chosen determinant is exactly zero, the
!> corresponding correction is skipped instead of reading out of bounds or
!> dividing by zero.
!>
!> Complex wavefunctions (CMPLX_) are not supported: the routine aborts via
!> stop_all in that build.
!>
!> @param[in,out] iter_data  Iteration statistics; nborn and ndied are
!>                           updated for the modified determinant.
subroutine fix_trial_overlap(iter_data)
    use util_mod, only: binary_search_first_ge
    type(fcimc_iter_data), intent(inout) :: iter_data
    HElement_t(dp), dimension(inum_runs) :: new_trial_denom, new_tot_trial_denom
    real(dp), dimension(lenof_sign) :: trial_delta, SignCurr, newSignCurr
    integer :: j, det_idx, proc_idx, run, part_type, lbnd, ubnd, err
    integer :: amp_idx
    integer :: trial_count, trial_indices(tot_trial_space_size)
    real(dp) :: amps(tot_trial_space_size), total_amp, total_amps(nProcessors)
    logical :: tIsStateDeterm
#ifdef CMPLX_
    unused_var(iter_data)
    call stop_all("fix_trial_overlap", "Complex wavefunction is not supported yet!")
#else
    ! Calculate the new overlap held by this processor and, on the way, record
    ! every occupied trial determinant together with its selection weight.
    new_trial_denom = 0.0_dp
    new_tot_trial_denom = 0.0_dp
    trial_count = 0
    total_amp = 0.0_dp
    do j = 1, int(TotWalkers)
        call extract_sign(CurrentDets(:, j), SignCurr)
        if (.not. IsUnoccDet(SignCurr) .and. test_flag(CurrentDets(:, j), flag_trial)) then
            trial_count = trial_count + 1
            trial_indices(trial_count) = j
            amps(trial_count) = abs(current_trial_amps(1, j))
            total_amp = total_amp + amps(trial_count)
            ! Update the overlap. Which amplitude component belongs to which
            ! sign component depends on the trial-space setup.
            if (ntrial_excits == 1) then
                new_trial_denom = new_trial_denom + current_trial_amps(1, j) * SignCurr
            else if (tReplicaReferencesDiffer .and. tPairedReplicas) then
                ! Each pair of replicas (run - 1, run) shares amplitude run / 2.
                do run = 2, inum_runs, 2
                    new_trial_denom(run - 1:run) = new_trial_denom(run - 1:run) &
                        + current_trial_amps(run / 2, j) * SignCurr(run - 1:run)
                end do
            else if (ntrial_excits == lenof_sign) then
                new_trial_denom = new_trial_denom + current_trial_amps(:, j) * SignCurr
            end if
        end if
    end do

    ! Collect overlaps from all processors.
    call MPIAllReduce(new_trial_denom, MPI_SUM, new_tot_trial_denom)

    ! Choose a random processor proportional to the sum of amplitudes of its
    ! part of the trial space. Only root evaluates the gathered totals.
    ! NOTE(review): err returned by MPIGather is not inspected.
    call MPIGather(total_amp, total_amps, err)
    if (iProcIndex == root) then
        ! Turn the per-processor totals into a cumulative distribution.
        do j = 2, nProcessors
            total_amps(j) = total_amps(j) + total_amps(j - 1)
        end do
        ! The search yields a 1-based position; subtracting one gives the value
        ! compared against iProcIndex below (presumably a zero-based rank).
        proc_idx = binary_search_first_ge(total_amps, &
                                          genrand_real2_dSFMT() * total_amps(nProcessors)) - 1
    end if
    call MPIBCast(proc_idx)

    ! Enforce an update of a random determinant on the chosen processor.
    ! All collective calls are already done, so skipping the update on a
    ! processor without occupied trial determinants cannot cause a deadlock.
    if (iProcIndex == proc_idx .and. trial_count > 0) then
        ! Choose a random determinant from the cumulative weights.
        do j = 2, trial_count
            amps(j) = amps(j) + amps(j - 1)
        end do
        det_idx = trial_indices(binary_search_first_ge(amps(1:trial_count), &
                                                       genrand_real2_dSFMT() * amps(trial_count)))

        do part_type = 1, lenof_sign
            run = part_type_to_run(part_type)
            ! Use the same amplitude component that entered the overlap above,
            ! so that adding trial_delta shifts the overlap by exactly the
            ! missing amount.
            if (ntrial_excits == 1) then
                amp_idx = 1
            else if (tReplicaReferencesDiffer .and. tPairedReplicas) then
                amp_idx = (part_type + 1) / 2
            else
                amp_idx = part_type
            end if
            ! A vanishing amplitude cannot change the overlap; leave the sign
            ! untouched rather than producing an infinite population.
            if (tFixTrial(run) .and. abs(current_trial_amps(amp_idx, det_idx)) > 0.0_dp) then
                trial_delta(part_type) = (tot_trial_denom(run) - new_tot_trial_denom(run)) &
                                         / current_trial_amps(amp_idx, det_idx)
            else
                trial_delta(part_type) = 0.0_dp
            end if
        end do

        call extract_sign(CurrentDets(:, det_idx), SignCurr)
        newSignCurr = SignCurr + trial_delta
        call encode_sign(CurrentDets(:, det_idx), newSignCurr)

        ! Correct statistics filled by CalcHashTableStats: the old population is
        ! booked as died and the new one as born.
        iter_data%ndied = iter_data%ndied + abs(SignCurr)
        iter_data%nborn = iter_data%nborn + abs(newSignCurr)
        TotParts = TotParts + abs(newSignCurr) - abs(SignCurr)

        tIsStateDeterm = .false.
        if (tSemiStochastic) then
            tIsStateDeterm = test_flag_multi(CurrentDets(:, det_idx), flag_deterministic)
        end if

        norm_psi_squared = norm_psi_squared + (newSignCurr)**2 - SignCurr**2
        if (tIsStateDeterm) then
            norm_semistoch_squared = norm_semistoch_squared + (newSignCurr)**2 - SignCurr**2
        end if

        if (tCheckHighestPop) then
            do run = 1, inum_runs
                lbnd = min_part_type(run)
                ubnd = max_part_type(run)
                if (abs_sign(newSignCurr(lbnd:ubnd)) > iHighestPop(run)) then
                    iHighestPop(run) = int(abs_sign(newSignCurr(lbnd:ubnd)))
                    HighestPopDet(:, run) = CurrentDets(:, det_idx)
                end if
            end do
        end if

        if (tFillingStochRDMonFly) then
            ! If the modified determinant is the HF determinant and has become
            ! unoccupied (and is not deterministic), reset the HF averaging.
            if (IsUnoccDet(newSignCurr) .and. (.not. tIsStateDeterm)) then
                if (DetBitEQ(CurrentDets(:, det_idx), iLutHF_True, nifd)) then
                    AvNoAtHF = 0.0_dp
                    IterRDM_HF = Iter + 1
                end if
            end if
        end if

        if (DetBitEQ(CurrentDets(:, det_idx), iLutHF_True, nifd)) then
            InstNoAtHF = newSignCurr
        end if
    end if
#endif
end subroutine fix_trial_overlap