#include "macros.h"

!> Typed containers for orbital indices (spin orbitals and spatial orbitals)
!> together with a small type representing spin projections.
!>
!> The module provides constructors, element-wise comparison operators,
!> lexicographic comparisons, conversion to and from the bit representation
!> and formatted output for the index containers.
module orb_idx_mod
    use constants, only: n_int, stdout
    use fortran_strings, only: str
    use bit_rep_data, only: nIfTot, nIfD
    use bit_reps, only: decode_bit_det
    use DetBitOps, only: EncodeBitDet
    use util_mod, only: ilex_leq => lex_leq, ilex_geq => lex_geq, stop_all, &
        operator(.div.)
    implicit none
    private
    public :: OrbIdx_t, SpinOrbIdx_t, SpatOrbIdx_t, size, &
        SpinProj_t, calc_spin, calc_spin_raw, alpha, beta, &
        operator(==), operator(/=), &
        sum, to_ilut, lex_leq, lex_geq, write_det, get_spat

    !> Common base of all orbital index containers.
    type, abstract :: OrbIdx_t
        integer, allocatable :: idx(:)
    end type OrbIdx_t

    !> Indices of spin orbitals.
    !> We assume order [beta_1, alpha_1, beta_2, alpha_2, ...]
    type, extends(OrbIdx_t) :: SpinOrbIdx_t
    contains
        procedure, nopass :: from_ilut => from_ilut_SpinOrbIdx_t
    end type SpinOrbIdx_t

    !> Constructors for SpinOrbIdx_t, either from spatial orbital indices
    !> or from a plain integer array of spin orbital indices.
    interface SpinOrbIdx_t
        module procedure SpinOrbIdx_t_from_SpatOrbIdx_t
        module procedure construction_from_array_SpinOrbIdx_t
    end interface

    !> Indices of spatial orbitals.
    type, extends(OrbIdx_t) :: SpatOrbIdx_t
    end type SpatOrbIdx_t

    !> Spin projection with integer arithmetic and comparison operators.
    type :: SpinProj_t
        integer :: val
            !! Twice the spin projection as integer.
            !! \( \text{val} = 2 \cdot S_z \)
    contains
        private
        procedure :: eq_SpinProj_t_SpinProj_t
        generic, public :: operator(==) => eq_SpinProj_t_SpinProj_t
        procedure :: neq_SpinProj_t_SpinProj_t
        generic, public :: operator(/=) => neq_SpinProj_t_SpinProj_t
        procedure :: add_SpinProj_t_SpinProj_t
        generic, public :: operator(+) => add_SpinProj_t_SpinProj_t
        procedure :: mult_SpinProj_t_int
        generic, public :: operator(*) => mult_SpinProj_t_int
        procedure :: sub_SpinProj_t_SpinProj_t
        procedure :: neg_SpinProj_t
        generic, public :: operator(-) => sub_SpinProj_t_SpinProj_t, neg_SpinProj_t
    end type SpinProj_t

    !> The two possible spin projections of a single electron (val = -1 and +1).
    type(SpinProj_t), parameter :: beta = SpinProj_t(-1), alpha = SpinProj_t(1)

    !> Number of stored indices.
    interface size
        module procedure size_SpinOrbIdx_t
        module procedure size_SpatOrbIdx_t
    end interface

    !> Formatted output of an index container.
    interface write_det
        module procedure write_det_SpinOrbIdx_t
        module procedure write_det_SpatOrbIdx_t
    end interface

    !> Element-wise equality of two index containers.
    interface operator(==)
        module procedure eq_SpinOrbIdx_t
        module procedure eq_SpatOrbIdx_t
    end interface

    !> Element-wise inequality of two index containers.
    interface operator(/=)
        module procedure neq_SpinOrbIdx_t
        module procedure neq_SpatOrbIdx_t
    end interface

    !> Lexicographic "less than or equal" on the stored indices.
    interface lex_leq
        module procedure lex_leq_SpinOrbIdx_t
        module procedure lex_leq_SpatOrbIdx_t
    end interface

    !> Lexicographic "greater than or equal" on the stored indices.
    interface lex_geq
        module procedure lex_geq_SpinOrbIdx_t
        module procedure lex_geq_SpatOrbIdx_t
    end interface

    !> Sum over an array of spin projections.
    interface sum
        module procedure sum_SpinProj_t
    end interface

contains

    !> Build a SpinOrbIdx_t from a plain array of spin orbital indices.
    !>
    !> If `m_s` is present, only the indices whose spin (as given by
    !> `calc_spin_raw`) equals `m_s` are kept; otherwise all of `idx` is copied.
    pure function construction_from_array_SpinOrbIdx_t(idx, m_s) result(res)
        integer, intent(in) :: idx(:)
        type(SpinProj_t), intent(in), optional :: m_s
        type(SpinOrbIdx_t) :: res

        type(SpinProj_t) :: spins(size(idx))
        character(*), parameter :: this_routine = 'construction_from_array_SpinOrbIdx_t'

        if (present(m_s)) then
#ifdef DEBUG_
            if (.not. any(m_s == [alpha, beta])) then
                call stop_all(this_routine, "Assert fail: any(m_s == [alpha, beta])")
            end if
#endif
            spins = calc_spin_raw(idx)
            res%idx = pack(idx, spins == m_s)
        else
            res%idx = idx
        end if
    end function construction_from_array_SpinOrbIdx_t

    !> Number of spin orbital indices stored in `orbs`.
    pure function size_SpinOrbIdx_t(orbs) result(res)
        type(SpinOrbIdx_t), intent(in) :: orbs
        integer :: res
        res = size(orbs%idx)
    end function size_SpinOrbIdx_t

    !> Number of spatial orbital indices stored in `orbs`.
    pure function size_SpatOrbIdx_t(orbs) result(res)
        type(SpatOrbIdx_t), intent(in) :: orbs
        integer :: res
        res = size(orbs%idx)
    end function size_SpatOrbIdx_t

    !> Element-wise equality; the result has the size of `lhs`.
    !> Both operands are expected to hold the same number of indices
    !> (this is not checked here).
    pure function eq_SpinOrbIdx_t(lhs, rhs) result(res)
        type(SpinOrbIdx_t), intent(in) :: lhs, rhs
        logical :: res(size(lhs))
        res = lhs%idx == rhs%idx
    end function eq_SpinOrbIdx_t

    !> Element-wise equality; the result has the size of `lhs`.
    !> Both operands are expected to hold the same number of indices
    !> (this is not checked here).
    pure function eq_SpatOrbIdx_t(lhs, rhs) result(res)
        type(SpatOrbIdx_t), intent(in) :: lhs, rhs
        logical :: res(size(lhs))
        res = lhs%idx == rhs%idx
    end function eq_SpatOrbIdx_t

    !> Element-wise inequality; the result has the size of `lhs`.
    !> Both operands are expected to hold the same number of indices
    !> (this is not checked here).
    pure function neq_SpinOrbIdx_t(lhs, rhs) result(res)
        type(SpinOrbIdx_t), intent(in) :: lhs, rhs
        logical :: res(size(lhs))
        res = lhs%idx /= rhs%idx
    end function neq_SpinOrbIdx_t

    !> Element-wise inequality; the result has the size of `lhs`.
    !> Both operands are expected to hold the same number of indices
    !> (this is not checked here).
    pure function neq_SpatOrbIdx_t(lhs, rhs) result(res)
        type(SpatOrbIdx_t), intent(in) :: lhs, rhs
        logical :: res(size(lhs))
        res = lhs%idx /= rhs%idx
    end function neq_SpatOrbIdx_t

    !> Convert spatial orbital indices into spin orbital indices.
    !>
    !> If `m_s` is present, every spatial orbital yields exactly one spin
    !> orbital of that spin. Otherwise every spatial orbital yields two
    !> entries, the beta one at the odd and the alpha one at the even
    !> positions of the result.
    pure function SpinOrbIdx_t_from_SpatOrbIdx_t(spat_orbs, m_s) result(res)
        type(SpatOrbIdx_t), intent(in) :: spat_orbs
        type(SpinProj_t), intent(in), optional :: m_s
        type(SpinOrbIdx_t) :: res

        character(*), parameter :: this_routine = 'SpinOrbIdx_t_from_SpatOrbIdx_t'

        if (present(m_s)) then
#ifdef DEBUG_
            if (.not. any(m_s == [alpha, beta])) then
                call stop_all(this_routine, "Assert fail: any(m_s == [alpha, beta])")
            end if
#endif
            res%idx = f(spat_orbs%idx(:), m_s)
        else
            allocate(res%idx(2 * size(spat_orbs)))
            res%idx(1::2) = f(spat_orbs%idx(:), beta)
            res%idx(2::2) = f(spat_orbs%idx(:), alpha)
        end if

    contains

        !> Spin orbital index of spatial orbital `spat_orb_idx` with spin `spin`.
        !> Alpha maps to the even index 2*i, anything else to the odd index 2*i - 1.
        elemental function f(spat_orb_idx, spin) result(spin_orb_idx)
            integer, intent(in) :: spat_orb_idx
            type(SpinProj_t), intent(in) :: spin
            integer :: spin_orb_idx
            if (spin == alpha) then
                spin_orb_idx = 2 * spat_orb_idx
            else
                spin_orb_idx = 2 * spat_orb_idx - 1
            end if
        end function f
    end function SpinOrbIdx_t_from_SpatOrbIdx_t

    !> Encode the spin orbital indices of `det_I` into a bit representation
    !> by delegating to `EncodeBitDet`.
    pure function to_ilut(det_I) result(ilut)
        type(SpinOrbIdx_t), intent(in) :: det_I
        integer(kind=n_int) :: ilut(0:nIfTot)
        call EncodeBitDet(det_I%idx, ilut)
    end function to_ilut

    !> Spin projection of every spin orbital stored in `orbs`.
    pure function calc_spin(orbs) result(res)
        type(SpinOrbIdx_t), intent(in) :: orbs
        type(SpinProj_t) :: res(size(orbs))
        res = calc_spin_raw(orbs%idx)
    end function calc_spin

    !> Sum of two spin projections.
    elemental function add_SpinProj_t_SpinProj_t(lhs, rhs) result(res)
        class(SpinProj_t), intent(in) :: lhs, rhs
        type(SpinProj_t) :: res
        res%val = lhs%val + rhs%val
    end function add_SpinProj_t_SpinProj_t

    !> Spin projection scaled by an integer factor.
    elemental function mult_SpinProj_t_int(lhs, rhs) result(res)
        class(SpinProj_t), intent(in) :: lhs
        integer, intent(in) :: rhs
        type(SpinProj_t) :: res
        res%val = lhs%val * rhs
    end function mult_SpinProj_t_int

    !> Total spin projection of the array `V`; zero for an empty array.
    pure function sum_SpinProj_t(V) result(res)
        type(SpinProj_t), intent(in) :: V(:)
        type(SpinProj_t) :: res
        ! `V%val` is a plain integer array, hence this is the intrinsic sum.
        res = SpinProj_t(sum(V%val))
    end function sum_SpinProj_t

    !> Difference of two spin projections.
    elemental function sub_SpinProj_t_SpinProj_t(lhs, rhs) result(res)
        class(SpinProj_t), intent(in) :: lhs, rhs
        type(SpinProj_t) :: res
        res%val = lhs%val - rhs%val
    end function sub_SpinProj_t_SpinProj_t

    !> Negated spin projection.
    elemental function neg_SpinProj_t(m_s) result(res)
        class(SpinProj_t), intent(in) :: m_s
        type(SpinProj_t) :: res
        res%val = -m_s%val
    end function neg_SpinProj_t

    !> True if both spin projections have the same value.
    elemental function eq_SpinProj_t_SpinProj_t(lhs, rhs) result(res)
        class(SpinProj_t), intent(in) :: lhs, rhs
        logical :: res
        res = lhs%val == rhs%val
    end function eq_SpinProj_t_SpinProj_t

    !> True if the spin projections have different values.
    elemental function neq_SpinProj_t_SpinProj_t(lhs, rhs) result(res)
        class(SpinProj_t), intent(in) :: lhs, rhs
        logical :: res
        res = lhs%val /= rhs%val
    end function neq_SpinProj_t_SpinProj_t

    !> Spin of the spin orbital with index `orb_idx`:
    !> even indices are alpha, odd indices are beta.
    elemental function calc_spin_raw(orb_idx) result(res)
        integer, intent(in) :: orb_idx
        type(SpinProj_t) :: res
        res = merge(alpha, beta, mod(orb_idx, 2) == 0)
    end function calc_spin_raw

    !> Lexicographic `lhs <= rhs` on the stored indices (delegates to
    !> `lex_leq` of `util_mod`). Debug builds assert equal sizes.
    pure function lex_leq_SpinOrbIdx_t(lhs, rhs) result(res)
        type(SpinOrbIdx_t), intent(in) :: lhs, rhs
        logical :: res
        character(*), parameter :: this_routine = 'lex_leq_SpinOrbIdx_t'
#ifdef DEBUG_
        if (.not. (size(lhs) == size(rhs))) then
            call stop_all(this_routine, "Assert fail: size(lhs) == size(rhs)")
        end if
#endif
        res = ilex_leq(lhs%idx, rhs%idx)
    end function lex_leq_SpinOrbIdx_t

    !> Lexicographic `lhs >= rhs` on the stored indices (delegates to
    !> `lex_geq` of `util_mod`). Debug builds assert equal sizes.
    pure function lex_geq_SpinOrbIdx_t(lhs, rhs) result(res)
        type(SpinOrbIdx_t), intent(in) :: lhs, rhs
        logical :: res
        character(*), parameter :: this_routine = 'lex_geq_SpinOrbIdx_t'
#ifdef DEBUG_
        if (.not. (size(lhs) == size(rhs))) then
            call stop_all(this_routine, "Assert fail: size(lhs) == size(rhs)")
        end if
#endif
        res = ilex_geq(lhs%idx, rhs%idx)
    end function lex_geq_SpinOrbIdx_t

    !> Lexicographic `lhs <= rhs` on the stored indices (delegates to
    !> `lex_leq` of `util_mod`). Debug builds assert equal sizes.
    pure function lex_leq_SpatOrbIdx_t(lhs, rhs) result(res)
        type(SpatOrbIdx_t), intent(in) :: lhs, rhs
        logical :: res
        character(*), parameter :: this_routine = 'lex_leq_SpatOrbIdx_t'
#ifdef DEBUG_
        if (.not. (size(lhs) == size(rhs))) then
            call stop_all(this_routine, "Assert fail: size(lhs) == size(rhs)")
        end if
#endif
        res = ilex_leq(lhs%idx, rhs%idx)
    end function lex_leq_SpatOrbIdx_t

    !> Lexicographic `lhs >= rhs` on the stored indices (delegates to
    !> `lex_geq` of `util_mod`). Debug builds assert equal sizes.
    pure function lex_geq_SpatOrbIdx_t(lhs, rhs) result(res)
        type(SpatOrbIdx_t), intent(in) :: lhs, rhs
        logical :: res
        character(*), parameter :: this_routine = 'lex_geq_SpatOrbIdx_t'
#ifdef DEBUG_
        if (.not. (size(lhs) == size(rhs))) then
            call stop_all(this_routine, "Assert fail: size(lhs) == size(rhs)")
        end if
#endif
        res = ilex_geq(lhs%idx, rhs%idx)
    end function lex_geq_SpatOrbIdx_t

    !> Write `det_I` as `SpinOrbIdx_t([i1, i2, ...])`.
    !>
    !> @param[in] det_I   The indices to print.
    !> @param[in] i_unit  Output unit; defaults to `stdout`.
    !> @param[in] advance Whether to end the record afterwards; defaults to true.
    subroutine write_det_SpinOrbIdx_t(det_I, i_unit, advance)
        type(SpinOrbIdx_t), intent(in) :: det_I
        integer, intent(in), optional :: i_unit
        logical, intent(in), optional :: advance
        ! Absent optionals are forwarded as absent.
        call write_orb_idx('SpinOrbIdx_t', det_I%idx, i_unit, advance)
    end subroutine write_det_SpinOrbIdx_t

    !> Write `det_I` as `SpatOrbIdx_t([i1, i2, ...])`.
    !>
    !> @param[in] det_I   The indices to print.
    !> @param[in] i_unit  Output unit; defaults to `stdout`.
    !> @param[in] advance Whether to end the record afterwards; defaults to true.
    subroutine write_det_SpatOrbIdx_t(det_I, i_unit, advance)
        type(SpatOrbIdx_t), intent(in) :: det_I
        integer, intent(in), optional :: i_unit
        logical, intent(in), optional :: advance
        ! Absent optionals are forwarded as absent.
        call write_orb_idx('SpatOrbIdx_t', det_I%idx, i_unit, advance)
    end subroutine write_det_SpatOrbIdx_t

    !> Shared implementation of the `write_det` routines: writes
    !> `type_name([i1, i2, ...])` to the requested unit.
    subroutine write_orb_idx(type_name, idx, i_unit, advance)
        character(*), intent(in) :: type_name
        integer, intent(in) :: idx(:)
        integer, intent(in), optional :: i_unit
        logical, intent(in), optional :: advance

        integer :: i, n, i_unit_
        character(:), allocatable :: advance_str, fmt

        i_unit_ = stdout
        if (present(i_unit)) i_unit_ = i_unit

        advance_str = 'yes'
        if (present(advance)) then
            if (.not. advance) advance_str = 'no'
        end if

        n = size(idx)
        write(i_unit_, "(a)", advance='no') type_name//'(['
        if (n == 0) then
            write(i_unit_, "(a)", advance=advance_str) '])'
        else
            ! Field width: number of digits of the largest index plus one
            ! separating blank. The largest index is clamped to 1 because
            ! log10 is undefined for non-positive arguments.
            fmt = "(I"//str(int(log10(real(max(1, maxval(idx))))) + 2)//", a)"
            do i = 1, n - 1
                write(i_unit_, fmt, advance='no') idx(i), ','
            end do
            write(i_unit_, fmt, advance=advance_str) idx(n), '])'
        end if
    end subroutine write_orb_idx

    !> Build a SpinOrbIdx_t from a bit representation.
    !>
    !> The number of indices is the number of set bits in `ilut`; the
    !> indices themselves are filled in by `decode_bit_det`.
    pure function from_ilut_SpinOrbIdx_t(ilut) result(res)
        integer(n_int), intent(in) :: ilut(0:nIfD)
        type(SpinOrbIdx_t) :: res
        integer :: n_el
        n_el = sum(popCnt(ilut))
        allocate(res%idx(n_el))
        call decode_bit_det(res%idx, ilut)
    end function from_ilut_SpinOrbIdx_t

    !> Return the spatial orbital of iorb
    elemental function get_spat(iorb) result(res)
        integer, intent(in) :: iorb
        integer :: res
        res = (iorb + 1) .div. 2
    end function get_spat

end module orb_idx_mod