! matrix_util.F90 source file


! Source code

#include "macros.h"


module matrix_util
    use constants, only: sp, dp, EPS
    use util_mod, only: near_zero, get_free_unit, stop_all, isclose
    use sort_mod, only: sort
    use blas_interface_mod, only: dgetrf, dgetri, dgemm, dsyev, zheev, dgeev, zgemm
    use util_mod, only: custom_findloc, operator(.isclose.)
    use display_matrices, only: write_matrix
    implicit none
    private
    public :: eig, print_matrix, matrix_exponential, det, blas_matmul, linspace, &
        calc_eigenvalues, check_symmetric, find_degeneracies, eig_sym, norm, &
        store_hf_coeff, matrix_inverse, print_vec, gram_schmidt, my_minval, my_minloc, &
        canonicalize, is_orthonormal


    ! linspace: equally spaced points between two end values, in single
    ! (linspace_sp) or double (linspace_dp) precision.
    interface linspace
            module procedure linspace_sp
            module procedure linspace_dp
    end interface

    ! eig: eigenvalues (and optionally eigenvectors) of a square matrix;
    ! real matrices go through dgeev, complex ones are treated as Hermitian
    ! and go through zheev.
    interface eig
        module procedure eig_cmplx
        module procedure eig_real
    end interface

    ! calc_eigenvalues: sorted eigenvalues only (no eigenvectors) of a real
    ! (dgeev) or complex (zheev) matrix.
    interface calc_eigenvalues
        module procedure calc_eigenvalues_real
        module procedure calc_eigenvalues_cmplx
    end interface

    ! norm: Lp norm (p defaults to 2, p = -1 selects the maximum norm) of a
    ! real or complex, sp or dp vector; the result is always real(dp).
    interface norm
        module procedure norm_real_sp
        module procedure norm_real_dp
        module procedure norm_complex_sp
        module procedure norm_complex_dp
    end interface

    ! is_orthonormal: orthonormality check for real/complex, sp/dp inputs.
    ! NOTE(review): the specific procedures are not visible in this section.
    interface is_orthonormal
        module procedure is_orthonormal_real_sp
        module procedure is_orthonormal_real_dp
        module procedure is_orthonormal_complex_sp
        module procedure is_orthonormal_complex_dp
    end interface

contains

    subroutine print_vec(vec, filename, t_index, t_zero)
        !! Print an integer or real(dp) vector, one element per line.
        !!
        !! Without `filename` every element is written to standard output with
        !! list-directed `print`; the `t_index` / `t_zero` flags are then ignored.
        !! With `filename` the file is replaced and
        !!   * t_zero  (default .false.): a leading zero line is written first
        !!             ("0 0.0" if t_index is also set, else "0.0"),
        !!   * t_index (default .false.): each element is preceded by its
        !!             1-based position.
        !! Vectors of any other dynamic type are silently ignored.
        class(*), intent(in) :: vec(:)
        character(*), intent(in), optional :: filename
        logical, intent(in), optional :: t_index, t_zero

        logical :: t_index_, t_zero_
        integer :: u, k
        def_default(t_index_, t_index, .false.)
        def_default(t_zero_, t_zero, .false.)

        select type (vec)
        type is (integer)
            if (.not. present(filename)) then
                do k = 1, size(vec)
                    print *, vec(k)
                end do
                return
            end if

            call open_output()
            do k = 1, size(vec)
                if (t_index_) then
                    write(u, *) k, vec(k)
                else
                    write(u, *) vec(k)
                end if
            end do
            close(u)

        type is (real(dp))
            if (.not. present(filename)) then
                do k = 1, size(vec)
                    print *, vec(k)
                end do
                return
            end if

            call open_output()
            do k = 1, size(vec)
                if (t_index_) then
                    write(u, *) k, vec(k)
                else
                    write(u, *) vec(k)
                end if
            end do
            close(u)
        end select

    contains

        subroutine open_output()
            !! Open (replacing) `filename` on a fresh unit `u` and write the
            !! optional leading zero line.
            u = get_free_unit()
            open(u, file = filename, status = 'replace', action = 'write')

            if (t_zero_) then
                if (t_index_) then
                    write(u, *) 0, 0.0_dp
                else
                    write(u, *) 0.0_dp
                end if
            end if
        end subroutine open_output

    end subroutine print_vec

    subroutine eig_real(matrix, e_values, e_vectors, t_left_ev)
        !! Eigen-decomposition of a general real square matrix (LAPACK dgeev).
        !!
        !! e_values  : real parts of the eigenvalues, sorted ascending. The
        !!             imaginary parts reported by dgeev are discarded, i.e.
        !!             the routine assumes a real spectrum.
        !! e_vectors : if present, the eigenvectors matching the sorted
        !!             e_values, stored column-wise. Right eigenvectors by
        !!             default, left eigenvectors when t_left_ev is .true.
        !! Without e_vectors the call is delegated to calc_eigenvalues.
        !! Stops via stop_all if dgeev reports an error.
        real(dp), intent(in) :: matrix(:,:)
        real(dp), intent(out) :: e_values(size(matrix,1))
        real(dp), intent(out), optional :: e_vectors(size(matrix,1),size(matrix,1))
        logical, intent(in), optional :: t_left_ev
        character(*), parameter :: this_routine = 'eig_real'

        integer :: n, i, info
        logical :: want_left
        ! dgeev needs LWORK >= max(1, 4*n) when eigenvectors are requested.
        real(dp) :: work(max(1, 4*size(matrix,1))), tmp_matrix(size(matrix,1),size(matrix,2))
        real(dp) :: left_ev(size(matrix,1),size(matrix,1)), dummy_eval(size(matrix,1))
        real(dp) :: right_ev(size(matrix,1),size(matrix,1))
        integer :: sort_ind(size(matrix,1))
        character(len=1) :: left, right

        if (.not. present(e_vectors)) then
            e_values = calc_eigenvalues(matrix)
            return
        end if

        n = size(matrix, 1)
        ! LAPACK rejects a leading dimension of 0, so handle n == 0 here.
        if (n == 0) return

        want_left = .false.
        if (present(t_left_ev)) want_left = t_left_ev

        ! Only request the eigenvector set that is actually returned; the
        ! other one is not referenced by dgeev.
        if (want_left) then
            left = 'V'
            right = 'N'
        else
            left = 'N'
            right = 'V'
        end if

        tmp_matrix = matrix
        call dgeev(left, right, n, tmp_matrix, n, e_values, dummy_eval, &
            left_ev, n, right_ev, n, work, size(work), info)
        if (info /= 0) call stop_all(this_routine, 'Failed in BLAS call.')

        sort_ind = [(i, i = 1, n)]
        ! Assumed: sort applies the same permutation to sort_ind as to
        ! e_values, giving the column order of the eigenvectors.
        call sort(e_values, sort_ind)

        if (want_left) then
            e_vectors = left_ev(:, sort_ind)
        else
            e_vectors = right_ev(:, sort_ind)
        end if

    end subroutine eig_real


    subroutine eig_cmplx(matrix, e_values, e_vectors, t_left_ev)
        !! Eigen-decomposition of a Hermitian complex matrix (LAPACK zheev).
        !!
        !! Only the upper triangle of `matrix` is referenced.
        !! e_values  : the (real) eigenvalues, sorted ascending.
        !! e_vectors : if present, the orthonormal eigenvectors matching the
        !!             sorted e_values, stored column-wise.
        !! t_left_ev : accepted for interface compatibility with eig_real. For
        !!             a Hermitian matrix the left and right eigenvectors
        !!             coincide, so it does not change the result.
        !! Without e_vectors the call is delegated to calc_eigenvalues.
        !! Stops via stop_all if zheev reports an error.
        complex(dp), intent(in) :: matrix(:,:)
        real(dp), intent(out) :: e_values(size(matrix,1))
        complex(dp), intent(out), optional :: e_vectors(size(matrix,1),size(matrix,1))
        logical, intent(in), optional :: t_left_ev
        character(*), parameter :: this_routine = 'eig_cmplx'

        integer :: n, i, info
        complex(dp) :: work(max(1, 4*size(matrix,1))), tmp_matrix(size(matrix,1),size(matrix,2))
        real(dp), allocatable :: rwork(:)
        integer :: sort_ind(size(matrix,1))

        if (.not. present(e_vectors)) then
            e_values = calc_eigenvalues(matrix)
            return
        end if

        n = size(matrix, 1)
        ! LAPACK rejects a leading dimension of 0, so handle n == 0 here.
        if (n == 0) return

        tmp_matrix = matrix

        ! JOBZ = 'V' (eigenvectors wanted), UPLO = 'U' (use upper triangle).
        ! zheev overwrites tmp_matrix with the eigenvectors.
        allocate(rwork(max(1, 3*n - 2)))
        call zheev('V', 'U', n, tmp_matrix, n, e_values, work, size(work), rwork, info)
        if (info /= 0) call stop_all(this_routine, 'Failed in BLAS call.')
        deallocate(rwork)

        sort_ind = [(i, i = 1, n)]
        ! Assumed: sort applies the same permutation to sort_ind as to
        ! e_values, giving the column order of the eigenvectors.
        call sort(e_values, sort_ind)

        e_vectors = tmp_matrix(:, sort_ind)

    end subroutine eig_cmplx

    function calc_eigenvalues_real(matrix) result(e_values)
        !! Real parts of the eigenvalues of a general real square matrix,
        !! sorted ascending (LAPACK dgeev without eigenvectors).
        !! The imaginary parts computed by dgeev are discarded.
        !! Stops via stop_all if dgeev reports an error.
        real(dp), intent(in) :: matrix(:,:)
        real(dp) :: e_values(size(matrix,1))
        character(*), parameter :: this_routine = 'calc_eigenvalues_real'

        integer :: n, info
        ! dgeev needs LWORK >= max(1, 3*n) when no eigenvectors are computed.
        real(dp) :: work(max(1, 3*size(matrix,1)))
        real(dp) :: tmp_matrix(size(matrix,1),size(matrix,2)), dummy_val(size(matrix,1))
        real(dp) :: dummy_vec_1(1,size(matrix,1)), dummy_vec_2(1,size(matrix,1))

        n = size(matrix,1)
        ! LAPACK rejects a leading dimension of 0; an empty matrix has no
        ! eigenvalues.
        if (n == 0) return

        tmp_matrix = matrix
        call dgeev('N', 'N', n, tmp_matrix, n, e_values, &
            dummy_val, dummy_vec_1, 1, dummy_vec_2, 1, work, size(work), info)
        if (info /= 0) call stop_all(this_routine, 'Failed in BLAS call.')
        call sort(e_values)

    end function calc_eigenvalues_real

    function calc_eigenvalues_cmplx(matrix) result(e_values)
        !! Eigenvalues of a Hermitian complex matrix, sorted ascending
        !! (LAPACK zheev without eigenvectors).
        !! Only the upper triangle of `matrix` is referenced.
        !! Stops via stop_all if zheev reports an error.
        complex(dp), intent(in) :: matrix(:,:)
        real(dp) :: e_values(size(matrix,1))
        character(*), parameter :: this_routine = 'calc_eigenvalues_cmplx'

        integer :: n, info
        complex(dp) :: work(max(1, 3*size(matrix,1)))
        complex(dp) :: tmp_matrix(size(matrix,1),size(matrix,2))
        real(dp), allocatable :: rwork(:)

        n = size(matrix,1)
        ! LAPACK rejects a leading dimension of 0; an empty matrix has no
        ! eigenvalues.
        if (n == 0) return

        tmp_matrix = matrix
        allocate(rwork(max(1, 3*n - 2)))
        ! JOBZ = 'N' (eigenvalues only), UPLO = 'U'. The second argument is
        ! the triangle selector and must be 'U' or 'L'.
        call zheev('N', 'U', n, tmp_matrix, n, e_values, work, size(work), rwork, info)
        if (info /= 0) call stop_all(this_routine, 'Failed in BLAS call.')
        deallocate(rwork)
        call sort(e_values)

    end function calc_eigenvalues_cmplx

    subroutine eig_sym(matrix, e_values, e_vectors)
        !! Eigenvalues and, optionally, eigenvectors of a real symmetric
        !! matrix (LAPACK dsyev).
        !!
        !! Only the upper triangle of `matrix` is referenced.
        !! e_values  : eigenvalues in the order returned by dsyev.
        !! e_vectors : if present, the eigenvectors stored column-wise.
        !! Stops via stop_all if dsyev reports an error.
        real(dp), intent(in) :: matrix(:,:)
        real(dp), intent(out) :: e_values(size(matrix,1))
        real(dp), intent(out), optional :: e_vectors(size(matrix,1),size(matrix,2))
        character(*), parameter :: this_routine = 'eig_sym'
        integer :: n, info, lwork
        character(1) :: jobz
        real(dp) :: tmp_matrix(size(matrix,1),size(matrix,2))
        real(dp) :: work(max(1, 3*size(matrix,1)))

        n = size(matrix,1)
        ! LAPACK rejects a leading dimension of 0; nothing to do for n == 0.
        if (n == 0) return
        lwork = size(work)

        tmp_matrix = matrix

        if (present(e_vectors)) then
            jobz = 'V'
        else
            jobz = 'N'
        end if

        call dsyev(jobz, 'U', n, tmp_matrix, n, e_values, work, lwork, info)
        if (info /= 0) call stop_all(this_routine, 'Failed in BLAS call.')

        ! With jobz = 'V', dsyev leaves the eigenvectors in tmp_matrix.
        if (present(e_vectors)) then
            e_vectors = tmp_matrix
        end if

    end subroutine eig_sym

    pure logical function check_symmetric(matrix)
        !! .true. if the square real `matrix` equals its transpose, i.e. the
        !! summed absolute entries of (A - A^T) are judged zero by near_zero.
        real(dp), intent(in) :: matrix(:, :)
        debug_function_name("check_symmetric")
        ASSERT(size(matrix, 1) == size(matrix, 2))
        associate(asymmetry => sum(abs(matrix - transpose(matrix))))
            check_symmetric = near_zero(asymmetry)
        end associate
    end function check_symmetric

    subroutine print_matrix(matrix, iunit)
        !! Print a 2-D integer, real(dp) or complex(dp) matrix row by row.
        !!
        !! Output goes to `iunit` if present, otherwise to standard output.
        !!   * integer     : list-directed, one record per row.
        !!   * real(dp)    : list-directed on standard output; with `iunit`,
        !!                   one record per row with every entry in G25.17.
        !!   * complex(dp) : "re+imi" per entry in ES16.8 scientific format,
        !!                   entries separated by a blank, one record per row.
        !! Other dynamic types are silently ignored.
        use, intrinsic :: iso_fortran_env, only: output_unit
        class(*), intent(in) :: matrix(:,:)
        integer, intent(in), optional :: iunit

        integer :: i, j, tmp_unit

        select type (matrix)
        type is (integer)
            if (present(iunit)) then
                do i = lbound(matrix,1), ubound(matrix,1)
                    write(iunit,*) matrix(i,:)
                end do
            else
                do i = lbound(matrix,1), ubound(matrix,1)
                    print *, matrix(i,:)
                end do
            end if
        type is (real(dp))
            do i = lbound(matrix,1), ubound(matrix,1)
                if (present(iunit)) then
                    ! Unlimited repeat puts the whole row into one record and
                    ! never indexes past the last column (safe for 0 columns).
                    write(iunit, '(*(G25.17))') matrix(i,:)
                else
                    print *, matrix(i,:)
                end if
            end do
        type is (complex(dp))
            if (present(iunit)) then
                tmp_unit = iunit
            else
                tmp_unit = output_unit
            end if
            ! ES16.8 leaves room for a sign, a leading digit and the exponent.
            ! A fixed F10.8 field overflows to asterisks for signed values with
            ! |x| >= 1, and SP makes every imaginary part signed.
            do i = lbound(matrix,1), ubound(matrix,1)
                do j = lbound(matrix,2), ubound(matrix,2)
                    if (j < ubound(matrix,2)) then
                        write(tmp_unit, fmt = '(ES16.8,SP,ES16.8,"i",1x)', advance = 'no') matrix(i,j)
                    else
                        write(tmp_unit, fmt = '(ES16.8,SP,ES16.8,"i")', advance = 'yes') matrix(i,j)
                    end if
                end do
            end do
        end select

    end subroutine print_matrix


    real(dp) function det(matrix)
        !! Determinant of a square real matrix.
        !!
        !! The matrix is LU-factorised with LAPACK dgetrf (P A = L U with unit
        !! lower-triangular L). det(A) is the product of U's diagonal, with the
        !! sign flipped once for every row interchange in the pivot vector.
        real(dp), intent(in) :: matrix(:,:)

        real(dp), allocatable :: lu(:,:)
        integer, allocatable :: pivots(:)
        real(dp) :: parity
        integer :: order, k, info
        debug_function_name("det")

        ASSERT(size(matrix, 1) == size(matrix, 2))

        order = size(matrix, 1)
        allocate(lu(order, order), source = matrix)
        allocate(pivots(order), source = 0)

        ! info is not inspected here.
        call dgetrf(order, order, lu, order, pivots, info)

        det = 1.0_dp
        parity = 1.0_dp
        do k = 1, order
            det = det * lu(k, k)
            if (pivots(k) /= k) parity = -parity
        end do

        det = parity * det

    end function det


    function blas_matmul(A, B) result(C)
        !! C = A B computed with BLAS (zgemm in a CMPLX_ build, dgemm otherwise).
        !!
        !! A is (m x k) and B must be (k x p); the result is (m x p).
        !! For square inputs this is the same shape as before.
        HElement_t(dp), intent(in) :: A(:,:), B(:,:)
        HElement_t(dp) :: C(size(A,1),size(B,2))

        integer :: m, k, p
        debug_function_name("blas_matmul")

        m = size(A, 1)
        k = size(A, 2)
        p = size(B, 2)
        ASSERT(size(B, 1) == k)

        ! beta = 0: C is output only, so it need not be initialised on entry.
        ! Leading dimensions must be >= 1 even for empty operands.
#ifdef CMPLX_
        call zgemm('N', 'N', m, p, k, cmplx(1.0_dp, 0.0_dp, kind=dp), A, max(1, m), &
            B, max(1, k), cmplx(0.0_dp, 0.0_dp, kind=dp), C, max(1, m))
#else
        call dgemm('N', 'N', m, p, k, 1.0_dp, A, max(1, m), B, max(1, k), 0.0_dp, C, max(1, m))
#endif
    end function blas_matmul

        pure function linspace_sp(start_val, end_val, n_opt) result(vec)
            !! `n_opt` equally spaced single-precision values from `start_val`
            !! to `end_val` (both included).
            !! n_opt == 1 yields [start_val]; n_opt <= 0 yields an empty vector.
            real(sp), intent(in) :: start_val, end_val
            integer, intent(in) :: n_opt
            real(sp), allocatable :: vec(:)

            integer :: i
            real(sp) :: dist

            if (n_opt <= 1) then
                ! The step below would divide by zero for a single point.
                vec = [real(sp) :: (start_val, i = 1, n_opt)]
                return
            end if

            dist = (end_val - start_val) / real(n_opt - 1, sp)
            vec = [(start_val + real(i, sp) * dist, i = 0, n_opt - 1)]
        end function linspace_sp
        pure function linspace_dp(start_val, end_val, n_opt) result(vec)
            !! `n_opt` equally spaced double-precision values from `start_val`
            !! to `end_val` (both included).
            !! n_opt == 1 yields [start_val]; n_opt <= 0 yields an empty vector.
            real(dp), intent(in) :: start_val, end_val
            integer, intent(in) :: n_opt
            real(dp), allocatable :: vec(:)

            integer :: i
            real(dp) :: dist

            if (n_opt <= 1) then
                ! The step below would divide by zero for a single point.
                vec = [real(dp) :: (start_val, i = 1, n_opt)]
                return
            end if

            dist = (end_val - start_val) / real(n_opt - 1, dp)
            vec = [(start_val + real(i, dp) * dist, i = 0, n_opt - 1)]
        end function linspace_dp


    function matrix_exponential(matrix) result(exp_matrix)
        !! Matrix exponential of a real symmetric (or, in a CMPLX_ build,
        !! Hermitian) matrix.
        !!
        !! With A = U diag(lambda) U^H from dsyev/zheev (U has orthonormal
        !! columns, so U^-1 = U^H):
        !!     exp(A) = U diag(exp(lambda)) U^H
        !! Only the upper triangle of `matrix` is referenced.
        !! Stops via stop_all if the eigensolver reports an error.
        HElement_t(dp), intent(in) :: matrix(:,:)
        HElement_t(dp) :: exp_matrix(size(matrix,1),size(matrix,2))
        character(*), parameter :: this_routine = 'matrix_exponential'

        HElement_t(dp) :: vectors(size(matrix,1),size(matrix,2))
        HElement_t(dp) :: scaled(size(matrix,1),size(matrix,2))
        HElement_t(dp) :: work(max(1, 3*size(matrix,1)-1))
        real(dp) :: values(size(matrix,1))
        integer :: info, n, j

        n = size(matrix,1)
        ! LAPACK rejects a leading dimension of 0; exp of an empty matrix is empty.
        if (n == 0) return

        ! Diagonalise: on exit `vectors` holds the eigenvectors column-wise.
        vectors = matrix
#ifdef CMPLX_
        block
            real(dp), allocatable :: rwork(:)
            allocate(rwork(max(1, 3*n - 2)))
            call zheev('V', 'U', n, vectors, n, values, work, size(work), rwork, info)
            deallocate(rwork)
        end block
#else
        call dsyev('V', 'U', n, vectors, n, values, work, size(work), info)
#endif
        if (info /= 0) call stop_all(this_routine, 'Failed in BLAS call.')

        ! U * diag(exp(lambda)) is U with column j scaled by exp(lambda_j);
        ! no need to form the diagonal matrix and multiply by it.
        do j = 1, n
            scaled(:, j) = vectors(:, j) * exp(values(j))
        end do

        ! Multiply by U^-1 = U^H (a plain transpose is only valid for real U).
#ifdef CMPLX_
        exp_matrix = blas_matmul(scaled, conjg(transpose(vectors)))
#else
        exp_matrix = blas_matmul(scaled, transpose(vectors))
#endif

    end function matrix_exponential

    function matrix_diag(vector) result(diag)
        !! Square matrix carrying `vector` on its main diagonal, zero elsewhere.
        real(dp), intent(in) :: vector(:)
        HElement_t(dp) :: diag(size(vector),size(vector))

        integer :: row, col

        ! Column-major traversal: the inner loop runs over the first index.
        do col = 1, size(vector)
            do row = 1, size(vector)
                if (row == col) then
                    diag(row, col) = vector(col)
                else
                    diag(row, col) = 0.0_dp
                end if
            end do
        end do

    end function matrix_diag

    function matrix_inverse(matrix) result(inverse)
        !! Inverse of a square real matrix via LU factorisation
        !! (LAPACK dgetrf followed by dgetri).
        !! Stops with "matrix singular!" if dgetrf reports a non-zero info and
        !! with "matrix inversion failed!" if dgetri does.
        !! Adapted from the fortran-wiki (search there for "matrix+inversion").
        real(dp), intent(in) :: matrix(:,:)
        real(dp) :: inverse(size(matrix,1),size(matrix,2))
        character(*), parameter :: this_routine = "matrix_inverse"

        ! dgetri needs LWORK >= max(1, n).
        real(dp) :: work(max(1, size(matrix,1)))
        integer :: ipiv(size(matrix,1))
        integer :: n, info

        inverse = matrix
        n = size(matrix,1)
        ! LAPACK rejects a leading dimension of 0; an empty matrix is its own inverse.
        if (n == 0) return

        call dgetrf(n, n, inverse, n, ipiv, info)

        if (info /= 0) call stop_all(this_routine, "matrix singular!")

        call dgetri(n, inverse, n, ipiv, work, size(work), info)

        if (info /= 0) call stop_all(this_routine, "matrix inversion failed!")

    end function matrix_inverse

    subroutine store_hf_coeff(e_values, e_vecs, target_state, hf_coeff, hf_ind, gs_ind)
        !! Locate the target eigenstate and its dominant basis component.
        !!
        !! gs_ind   : index of the `target_state`-th lowest eigenvalue
        !!            (the lowest one if target_state is absent)
        !! hf_ind   : row of the largest |coefficient| in eigenvector gs_ind
        !! hf_coeff : that largest absolute coefficient
        real(dp), intent(in) :: e_values(:), e_vecs(:,:)
        integer, intent(in), optional :: target_state
        real(dp), intent(out) :: hf_coeff
        integer, intent(out) :: hf_ind, gs_ind

        real(dp) :: abs_coeffs(size(e_values))

        ! The optional argument is forwarded as-is; my_minloc falls back to
        ! the plain minimum when it is absent.
        gs_ind = my_minloc(e_values, target_state)

        abs_coeffs = abs(e_vecs(:, gs_ind))
        hf_ind = maxloc(abs_coeffs, dim = 1)
        hf_coeff = abs_coeffs(hf_ind)

    end subroutine store_hf_coeff

    pure real(dp) function my_minval(vec, target_state)
        !! Value of the `target_state`-th smallest entry of `vec`
        !! (the plain minimum when `target_state` is absent).
        real(dp), intent(in) :: vec(:)
        integer, intent(in), optional :: target_state

        if (.not. present(target_state)) then
            my_minval = minval(vec)
            return
        end if

        my_minval = vec(my_minloc(vec, target_state))

    end function my_minval


    pure integer function my_minloc(vec, target_state)
        !! Index of the `target_state`-th smallest entry of `vec`.
        !!
        !! Without `target_state` this is minloc(vec, dim=1). Ties are resolved
        !! by position (earliest index first), as in minloc.
        !! If `target_state` is < 1 or > size(vec), 0 is returned, following
        !! minloc's convention for "no such element".
        real(dp), intent(in) :: vec(:)
        integer, intent(in), optional :: target_state

        logical :: flag(size(vec))
        integer :: i

        if (.not. present(target_state)) then
            my_minloc = minloc(vec, dim = 1)
            return
        end if

        ! Out of range: previously the result was undefined (target_state < 1)
        ! or flag(0) was written (target_state > size(vec)).
        if (target_state < 1 .or. target_state > size(vec)) then
            my_minloc = 0
            return
        end if

        ! Repeatedly take the smallest remaining entry and mask it out.
        flag = .true.
        do i = 1, target_state
            my_minloc = minloc(vec, dim = 1, mask = flag)
            flag(my_minloc) = .false.
        end do

    end function my_minloc

        pure function norm_real_sp(vec, p) result(lp_norm)
            !! Lp norm of a single-precision real vector, returned as real(dp).
            !! p defaults to 2; p = -1 selects the maximum (infinity) norm.
            !! Any other p < 1 stops with 'invalid p'.
            !! The sums/maxima are formed in the kind of `vec` before being
            !! converted to real(dp).
            real(sp), intent(in) :: vec(:)
            integer, intent(in), optional :: p
            real(dp) :: lp_norm
            routine_name("norm_real_sp")
            integer :: p_

            def_default(p_, p, 2)

            if (p_ == -1) then
                lp_norm = maxval(abs(vec))
            else if (p_ == 1) then
                lp_norm = sum(abs(vec))
            else if (p_ == 2) then
                lp_norm = sqrt(sum(abs(vec)**2))
            else if (p_ >= 3) then
                lp_norm = sum(abs(vec)**p_)**(1.0_dp / real(p_, dp))
            else
                call stop_all(this_routine, 'invalid p')
            end if
        end function norm_real_sp
        pure function norm_real_dp(vec, p) result(lp_norm)
            !! Lp norm of a double-precision real vector.
            !! p defaults to 2; p = -1 selects the maximum (infinity) norm.
            !! Any other p < 1 stops with 'invalid p'.
            real(dp), intent(in) :: vec(:)
            integer, intent(in), optional :: p
            real(dp) :: lp_norm
            routine_name("norm_real_dp")
            integer :: p_

            def_default(p_, p, 2)

            if (p_ == -1) then
                lp_norm = maxval(abs(vec))
            else if (p_ == 1) then
                lp_norm = sum(abs(vec))
            else if (p_ == 2) then
                lp_norm = sqrt(sum(abs(vec)**2))
            else if (p_ >= 3) then
                lp_norm = sum(abs(vec)**p_)**(1.0_dp / real(p_, dp))
            else
                call stop_all(this_routine, 'invalid p')
            end if
        end function norm_real_dp
        pure function norm_complex_sp(vec, p) result(lp_norm)
            !! Lp norm of a single-precision complex vector (on the moduli
            !! |z_i|), returned as real(dp).
            !! p defaults to 2; p = -1 selects the maximum (infinity) norm.
            !! Any other p < 1 stops with 'invalid p'.
            !! The sums/maxima are formed in single precision before being
            !! converted to real(dp).
            complex(sp), intent(in) :: vec(:)
            integer, intent(in), optional :: p
            real(dp) :: lp_norm
            routine_name("norm_complex_sp")
            integer :: p_

            def_default(p_, p, 2)

            if (p_ == -1) then
                lp_norm = maxval(abs(vec))
            else if (p_ == 1) then
                lp_norm = sum(abs(vec))
            else if (p_ == 2) then
                lp_norm = sqrt(sum(abs(vec)**2))
            else if (p_ >= 3) then
                lp_norm = sum(abs(vec)**p_)**(1.0_dp / real(p_, dp))
            else
                call stop_all(this_routine, 'invalid p')
            end if
        end function norm_complex_sp
        pure function norm_complex_dp(vec, p) result(lp_norm)
            !! Lp norm of a double-precision complex vector (on the moduli |z_i|).
            !! p defaults to 2; p = -1 selects the maximum (infinity) norm.
            !! Any other p < 1 stops with 'invalid p'.
            complex(dp), intent(in) :: vec(:)
            integer, intent(in), optional :: p
            real(dp) :: lp_norm
            routine_name("norm_complex_dp")
            integer :: p_

            def_default(p_, p, 2)

            if (p_ == -1) then
                lp_norm = maxval(abs(vec))
            else if (p_ == 1) then
                lp_norm = sum(abs(vec))
            else if (p_ == 2) then
                lp_norm = sqrt(sum(abs(vec)**2))
            else if (p_ >= 3) then
                lp_norm = sum(abs(vec)**p_)**(1.0_dp / real(p_, dp))
            else
                call stop_all(this_routine, 'invalid p')
            end if
        end function norm_complex_dp


    subroutine find_degeneracies(e_values, ind, pairs)
        !! Group sorted eigenvalues into degenerate sets.
        !!
        !! e_values : eigenvalues, assumed sorted so that degenerate values are
        !!            adjacent. A value joins the current group if it differs
        !!            from the group's first value by less than `deg_thresh`.
        !! ind      : one row per group (every distinct eigenvalue, including a
        !!            trailing non-degenerate one). Column 1 holds the group
        !!            size; columns 2: hold the member indices (zero padded).
        !!            The number of columns is (largest group size) + 1.
        !! pairs    : size(e_values) rows; row k lists the other members of
        !!            k's group (zero padded). Without any degeneracy it is a
        !!            single all-zero column.
        real(dp), intent(in) :: e_values(:)
        integer, intent(out), allocatable :: ind(:,:), pairs(:,:)

        ! Historical threshold, kept as the default-real literal 10.e-8 (= 1e-7).
        real, parameter :: deg_thresh = 10.e-8
        integer, allocatable :: tmp_ind(:,:)
        integer :: i, j, e_ind, n_vals, n_groups, group_size, max_val

        n_vals = size(e_values)
        if (n_vals == 0) then
            allocate(ind(0, 2), pairs(0, 1))
            return
        end if

        ! Heap buffer: n x (n+1) integers can be too large for the stack.
        allocate(tmp_ind(n_vals, n_vals + 1), source = 0)

        n_groups = 0
        e_ind = 1
        do while (e_ind <= n_vals)
            n_groups = n_groups + 1
            ! The first member always belongs to its own group (this also
            ! prevents an endless loop if the value is NaN).
            tmp_ind(n_groups, 2) = e_ind
            j = 1
            do while (e_ind + j <= n_vals)
                if (.not. (abs(e_values(e_ind) - e_values(e_ind + j)) < deg_thresh)) exit
                tmp_ind(n_groups, j + 2) = e_ind + j
                j = j + 1
            end do
            tmp_ind(n_groups, 1) = j
            e_ind = e_ind + j
        end do

        max_val = maxval(tmp_ind(1:n_groups, 1)) + 1
        allocate(ind(n_groups, max_val), source = tmp_ind(1:n_groups, 1:max_val))

        if (max_val == 2) then
            ! no degeneracies
            allocate(pairs(n_vals, 1), source = 0)
            return
        end if

        allocate(pairs(n_vals, max_val - 2), source = 0)
        do i = 1, n_groups
            group_size = ind(i, 1)
            if (group_size < 2) cycle
            associate(members => ind(i, 2 : group_size + 1))
                do j = 1, group_size
                    ! Groups smaller than the largest one fill only the
                    ! leading part of the row.
                    pairs(members(j), 1 : group_size - 1) = pack(members, members /= members(j))
                end do
            end associate
        end do

    end subroutine find_degeneracies

    pure subroutine canonicalize(V, lambda)
        !! This routine canonicalizes the eigenvector matrix V and the
        !! corresponding eigenvalues lambda
        !!
        !! This procedure is particularly useful for tests,
        !! where degenerate Eigenspaces might result in different
        !! Eigenvectors.
        real(kind=dp), intent(inout) :: V(:, :), lambda(:)
        real(kind=dp), allocatable :: norms_projections(:)
            !! The norms of the projections
        integer :: low, i, j, d
        integer, allocatable :: idx(:), dimensions(:)
        real(kind=dp), allocatable :: projections(:, :)
        debug_function_name('canonicalize')

        ASSERT(size(V, 1) == size(V, 2) .and. size(V, 1) == size(lambda))

        allocate(norms_projections(size(V, 2)))

        idx = [(i, i = 1, size(lambda, 1))]
        call sort(lambda, idx)
        V(:, :) = V(:, idx)

        call determine_eigenspaces(lambda, dimensions)

        low = 1
        do j = 1, size(dimensions)
            d = dimensions(j)
            if (d == 1) then
                V(:, low) = V(:, low) / norm(V(:, low))
                low = low + d
                cycle
            end if
            ! Ensure that the basis of each subspace is orthonormal
            V(:, low : low + d - 1) = gram_schmidt(V(:, low : low + d - 1))

            projections = project_canonical_unit_vectors(V(:, low : low + d - 1))

            do i = 1, size(norms_projections)
                norms_projections(i) = norm(projections(:, i))
            end do

            idx = [(i, i = 1, size(norms_projections))]
    

        ! merge_sort
        block
            integer, dimension(:), allocatable :: tmp
            integer :: current_size, left, mid, right
            integer, parameter :: along = 1, &
                bubblesort_size = 20

            associate(X => idx)
                ! Determine good number of starting splits
                block
                    integer :: n_splits
                    n_splits = 1
                    do while (size(X, along) / n_splits + merge(0, 1, mod(size(X, along), n_splits) == 0) > bubblesort_size)
                        n_splits = n_splits + 1
                    end do
                    current_size = size(X, along) / n_splits + merge(0, 1, mod(size(X, along), n_splits) == 0)
                end block

                ! Reuse this variable as tmp for swap in bubble_sort
                ! and for merge in merge_sort.
                block
                    integer :: max_buffer_size, n_merges
                    n_merges = ceiling(log(real(size(X, along)) / real(current_size)) / log(2.0))
                    max_buffer_size = current_size * merge(2**(n_merges - 1), 1, n_merges >= 1)
                    allocate(tmp(max_buffer_size))
                end block

                ! Bubble sort bins of size `bubblesort_size`.
                do left = lbound(X, along), ubound(X, along), current_size
                    right = min(left + bubblesort_size - 1, ubound(X, along))
                        ! bubblesort
    block
      integer :: n, i
      associate(V => X(left : right))
        do n = ubound(V, 1), lbound(V, 1) + 1, -1
          do i = lbound(V, 1), ubound(V, 1) - 1
            if (.not.   comp_idx(V(i), V(i + 1))) then
                  ! swap
    block
        associate(tmp => tmp(1))
            tmp = V(i)
            V(i) = V(i + 1)
            V(i + 1) = tmp
        end associate
    end block
            end if
          end do
        end do
      end associate
    end block
                end do

                do while (current_size < size(X, along))
                    do left = lbound(X, along), ubound(X, along), 2 * current_size
                        right = min(left + 2 * current_size - 1, ubound(X, along))
                        mid = min(left + current_size, right) - 1
                        tmp(: mid - left + 1) = X(left : mid)
                            ! merge
    block
        integer :: i, j, k

        associate(A => tmp(: mid - left + 1), B => X(mid + 1 : right), C => X(left : right))

            if (size(A) + size(B) > size(C)) then
                error stop
            end if

            i = lbound(A, 1)
            j = lbound(B, 1)
            do k = lbound(C, 1), ubound(C, 1)
                if (i <= ubound(A, 1) .and. j <= ubound(B, 1)) then
                    if (  comp_idx(A(i), B(j))) then
                        C(k) = A(i)
                        i = i + 1
                    else
                        C(k) = B(j)
                        j = j + 1
                    end if
                else if (i <= ubound(A, 1)) then
                    C(k) = A(i)
                    i = i + 1
                else if (j <= ubound(B, 1)) then
                    C(k) = B(j)
                    j = j + 1
                end if
            end do
        end associate
    end block
                    end do
                    current_size = 2 * current_size
                end do
            end associate
        end block
    

        ! merge_sort
        block
            integer, dimension(:), allocatable :: tmp
            integer :: current_size, left, mid, right
            integer, parameter :: along = 1, &
                bubblesort_size = 20

            associate(X => idx(: d))
                ! Determine good number of starting splits
                block
                    integer :: n_splits
                    n_splits = 1
                    do while (size(X, along) / n_splits + merge(0, 1, mod(size(X, along), n_splits) == 0) > bubblesort_size)
                        n_splits = n_splits + 1
                    end do
                    current_size = size(X, along) / n_splits + merge(0, 1, mod(size(X, along), n_splits) == 0)
                end block

                ! Reuse this variable as tmp for swap in bubble_sort
                ! and for merge in merge_sort.
                block
                    integer :: max_buffer_size, n_merges
                    n_merges = ceiling(log(real(size(X, along)) / real(current_size)) / log(2.0))
                    max_buffer_size = current_size * merge(2**(n_merges - 1), 1, n_merges >= 1)
                    allocate(tmp(max_buffer_size))
                end block

                ! Bubble sort bins of size `bubblesort_size`.
                do left = lbound(X, along), ubound(X, along), current_size
                    right = min(left + bubblesort_size - 1, ubound(X, along))
                        ! bubblesort
    block
      integer :: n, i
      associate(V => X(left : right))
        do n = ubound(V, 1), lbound(V, 1) + 1, -1
          do i = lbound(V, 1), ubound(V, 1) - 1
            if (.not.   V(i) <= V(i + 1)) then
                  ! swap
    block
        associate(tmp => tmp(1))
            tmp = V(i)
            V(i) = V(i + 1)
            V(i + 1) = tmp
        end associate
    end block
            end if
          end do
        end do
      end associate
    end block
                end do

                do while (current_size < size(X, along))
                    do left = lbound(X, along), ubound(X, along), 2 * current_size
                        right = min(left + 2 * current_size - 1, ubound(X, along))
                        mid = min(left + current_size, right) - 1
                        tmp(: mid - left + 1) = X(left : mid)
                            ! merge
    block
        integer :: i, j, k

        associate(A => tmp(: mid - left + 1), B => X(mid + 1 : right), C => X(left : right))

            if (size(A) + size(B) > size(C)) then
                error stop
            end if

            i = lbound(A, 1)
            j = lbound(B, 1)
            do k = lbound(C, 1), ubound(C, 1)
                if (i <= ubound(A, 1) .and. j <= ubound(B, 1)) then
                    if (  A(i) <= B(j)) then
                        C(k) = A(i)
                        i = i + 1
                    else
                        C(k) = B(j)
                        j = j + 1
                    end if
                else if (i <= ubound(A, 1)) then
                    C(k) = A(i)
                    i = i + 1
                else if (j <= ubound(B, 1)) then
                    C(k) = B(j)
                    j = j + 1
                end if
            end do
        end associate
    end block
                    end do
                    current_size = 2 * current_size
                end do
            end associate
        end block

            ! reorthogonalize
            V(:, low : low + d - 1) = gram_schmidt(projections(:, idx(: d)))

            where (near_zero(V(:, low : low + d - 1)))
                V(:, low : low + d - 1) = 0._dp
            end where
            low = low + d
        end do

        contains

            pure logical function comp_idx(i, j)
                integer, intent(in) :: i, j
                comp_idx = norms_projections(i) >= norms_projections(j)
            end function
    end subroutine canonicalize


    pure function gram_schmidt(V) result(U)
        !! Orthonormalise the columns of V with the (modified) Gram-Schmidt
        !! process.
        !!
        !! U(:, j) is V(:, j) with its components along all previously
        !! orthonormalised columns U(:, 1 : j - 1) removed, scaled to unit
        !! norm. Hence span(U(:, 1 : j)) == span(V(:, 1 : j)) for every j.
        !!
        !! Calls `stop_all` if a column of V is (numerically) linearly
        !! dependent on the preceding ones. This includes a zero first column.
        !! A V with zero columns yields a U with zero columns.
        real(dp), intent(in) :: V(:, :)
        real(dp) :: U(size(V, 1), size(V, 2)), norm_factor
        routine_name("gram_schmidt")
        integer :: j, k

        do j = 1, size(V, 2)
            U(:, j) = V(:, j)
            ! Subtract the projections one after another from the running
            ! vector, so that every previous basis vector is removed. This
            ! is modified Gram-Schmidt: projecting the partially
            ! orthogonalised U(:, j) instead of V(:, j) is numerically more
            ! stable than the classical variant.
            do k = 1, j - 1
                U(:, j) = U(:, j) - dot_product(U(:, j), U(:, k)) * U(:, k)
            end do
            norm_factor = norm(U(:, j))
            if (near_zero(norm_factor)) then
                call stop_all(this_routine, 'linear dependent vectors')
            end if
            U(:, j) = U(:, j) / norm_factor
        end do
    end function

    pure function project_canonical_unit_vectors(M) result(P)
        !! Project each canonical unit vector e_row onto the columns of M.
        !!
        !! Column `row` of the result is
        !!     P(:, row) = sum_vec < e_row | M(:, vec) > M(:, vec)
        !!               = sum_vec M(row, vec) * M(:, vec),
        !! i.e. P = M M^T. When the columns of M are orthonormal (e.g.
        !! eigenvectors), this is the orthogonal projection of e_row onto
        !! their span.
        !!
        !! M may be rectangular; P is always size(M, 1) x size(M, 1).
        real(dp), intent(in) :: M(:, :)
        real(dp) :: P(size(M, 1), size(M, 1))
        integer :: row, vec

        P(:, :) = 0._dp
        do row = 1, size(M, 1)
            do vec = 1, size(M, 2)
                ! < e_row | v > is simply the row-th component of v.
                P(:, row) = P(:, row) + M(row, vec) * M(:, vec)
            end do
        end do
    end function project_canonical_unit_vectors

    pure subroutine determine_eigenspaces(lambda,dimensions)
        !! Partition ascending eigenvalues into (near-)degenerate eigenspaces.
        !!
        !! Neighbouring entries of `lambda` for which `isclose` holds belong
        !! to the same eigenspace. Every eigenvalue of an eigenspace is then
        !! replaced by the mean of that eigenspace.
        !!
        !! lambda      Eigenvalues, sorted ascending (asserted). On return,
        !!             each degenerate group holds its common mean value.
        !! dimensions  On return, dimensions(k) is the number of eigenvalues
        !!             in the k-th eigenspace (in ascending order), so
        !!             sum(dimensions) == size(lambda). Zero-sized when
        !!             `lambda` is empty.
        real(dp), intent(inout) :: lambda(:)
        integer, allocatable, intent(out) :: dimensions(:)
        integer, allocatable :: d_buffer(:)
        integer :: i, low
        integer :: n_spaces
        debug_function_name("determine_eigenspaces")

        ASSERT(all(lambda(: size(lambda) - 1) <= lambda(2 : )))

        ! An empty spectrum has no eigenspaces. Without this guard,
        ! n_spaces = 1 below would slice past the end of a zero-sized
        ! d_buffer.
        if (size(lambda) == 0) then
            allocate(dimensions(0))
            return
        end if

        allocate(d_buffer(size(lambda)), source=0)

        low = 1
        n_spaces = 1
        do i = 1, size(lambda)
            d_buffer(n_spaces) = d_buffer(n_spaces) + 1
            if (i + 1 <= size(lambda)) then
                ! Chained comparison: each eigenvalue is compared only with
                ! its direct neighbour, not with the first one of the group.
                if (.not. isclose(lambda(i), lambda(i + 1), &
                                   epsilon(lambda) * 1.0e3_dp, 1.0e-8_dp)) then
                    n_spaces = n_spaces + 1
                    lambda(low : i) = mean(lambda(low : i))
                    low = i + 1
                end if
            else
                ! The last eigenvalue closes the final eigenspace.
                lambda(low : i) = mean(lambda(low : i))
            end if
        end do

        dimensions = d_buffer(: n_spaces)

        contains

            real(dp) pure function mean(X)
                real(dp), intent(in) :: X(:)
                mean = sum(X) / size(X)
            end function mean
    end subroutine determine_eigenspaces

    pure function is_orthonormal_real_sp(M) result(res)
        !! Return `.true.` iff the columns of M form an orthonormal set.
        !! Every column must have unit norm (`.isclose.` to 1), and
        !! < M(:, col) | M(:, other) > must be `near_zero` for col /= other.
        !!
        !! Note that M does not have to be square. An M without columns
        !! is trivially orthonormal.
        real(sp), intent(in) :: M(:, :)
        logical :: res

        integer :: col, other

        ! Assume failure; only reaching the end of all checks proves success.
        res = .false.
        do col = 1, size(M, 2)
            if (.not. (norm(M(:, col)) .isclose. 1._dp)) return
            do other = col + 1, size(M, 2)
                if (.not. near_zero(dot_product(M(:, col), M(:, other)))) return
            end do
        end do
        res = .true.
    end function is_orthonormal_real_sp
    pure function is_orthonormal_real_dp(M) result(res)
        !! Return `.true.` iff the columns of M form an orthonormal set.
        !! Every column must have unit norm (`.isclose.` to 1), and
        !! < M(:, col) | M(:, other) > must be `near_zero` for col /= other.
        !!
        !! Note that M does not have to be square. An M without columns
        !! is trivially orthonormal.
        real(dp), intent(in) :: M(:, :)
        logical :: res

        integer :: col, other

        ! Assume failure; only reaching the end of all checks proves success.
        res = .false.
        do col = 1, size(M, 2)
            if (.not. (norm(M(:, col)) .isclose. 1._dp)) return
            do other = col + 1, size(M, 2)
                if (.not. near_zero(dot_product(M(:, col), M(:, other)))) return
            end do
        end do
        res = .true.
    end function is_orthonormal_real_dp
    pure function is_orthonormal_complex_sp(M) result(res)
        !! Return `.true.` iff the columns of M form an orthonormal set.
        !! Every column must have unit norm (`.isclose.` to 1), and
        !! < M(:, col) | M(:, other) > must be `near_zero` for col /= other.
        !! `dot_product` conjugates its first (complex) argument, so this is
        !! the Hermitian inner product.
        !!
        !! Note that M does not have to be square. An M without columns
        !! is trivially orthonormal.
        complex(sp), intent(in) :: M(:, :)
        logical :: res

        integer :: col, other

        ! Assume failure; only reaching the end of all checks proves success.
        res = .false.
        do col = 1, size(M, 2)
            if (.not. (norm(M(:, col)) .isclose. 1._dp)) return
            do other = col + 1, size(M, 2)
                if (.not. near_zero(dot_product(M(:, col), M(:, other)))) return
            end do
        end do
        res = .true.
    end function is_orthonormal_complex_sp
    pure function is_orthonormal_complex_dp(M) result(res)
        !! Return `.true.` iff the columns of M form an orthonormal set.
        !! Every column must have unit norm (`.isclose.` to 1), and
        !! < M(:, col) | M(:, other) > must be `near_zero` for col /= other.
        !! `dot_product` conjugates its first (complex) argument, so this is
        !! the Hermitian inner product.
        !!
        !! Note that M does not have to be square. An M without columns
        !! is trivially orthonormal.
        complex(dp), intent(in) :: M(:, :)
        logical :: res

        integer :: col, other

        ! Assume failure; only reaching the end of all checks proves success.
        res = .false.
        do col = 1, size(M, 2)
            if (.not. (norm(M(:, col)) .isclose. 1._dp)) return
            do other = col + 1, size(M, 2)
                if (.not. near_zero(dot_product(M(:, col), M(:, other)))) return
            end do
        end do
        res = .true.
    end function is_orthonormal_complex_dp

end module