!===============================================================================
!
! Routines:
!
! (1) input()   Originally By ?         Last Modified 10/5/2009 (gsm)
!
! Sets up various data structures. Reads in and distributes wavefunctions.
! Stores wavefunctions on disk in temporary INT_CWFN_* and INT_WFNK_* files
! (comm_disk) or in structures distributed in memory (comm_mpi).
!
!===============================================================================

#include "f_defs.h"

subroutine input(crys,gvec,syms,kp,wpg,sig,wfnk,itpc,itpk,fnc,fnk,wfnkqmpi,wfnkmpi)

! Set up a Sigma run: read sigma.inp, the WFN_inner header and G-vectors,
! optionally VXC and RHO, distribute bands and diagonal/off-diagonal matrix
! elements over processors and pools, and read the wavefunctions.
!
! Arguments:
!   crys, gvec, syms, kp  (out) filled from the WFN_inner header; gvec%k is
!                         re-sorted by kinetic energy computed with crys%bdot.
!   wpg                   (out) only when sig%freq_dep == 1: rho, rhoave and
!                         wp per spin (sqrt(wp) is reported in eV below).
!   sig                   (out) read by inread, then updated here (indkn, vxc,
!                         efermi, freqevalmin, qgridsym, ...).
!   wfnk                  (out) data for the k-point(s) where Sigma is computed;
!                         wfnk%zk stays allocated on return only if sig%nkn == 1.
!   itpc, itpk            unit numbers of the temporary files named fnc, fnk;
!                         used only when sig%iwriteint == 0.
!   wfnkqmpi, wfnkmpi     (out) in-memory storage used when sig%iwriteint /= 0;
!                         wfnkmpi is only allocated/filled when sig%nkn > 1.

  use global_m
  use eqpcor_m
  use fftw_m
  use input_utils_m
  use misc_m
  use sort_m
  use wfn_rho_vxc_io_m
  implicit none

  type (crystal), intent(out) :: crys
  type (gspace), intent(out) :: gvec
  type (symmetry), intent(out) :: syms
  type (kpoints), intent(out) :: kp
  type (wpgen), intent(out) :: wpg
  type (siginfo), intent(out) :: sig
  type (wfnkstates), intent(out) :: wfnk
  integer, intent(in) :: itpc, itpk
  character*20, intent(in) :: fnc, fnk
  type (wfnkqmpiinfo), intent(out) :: wfnkqmpi
  type (wfnkmpiinfo), intent(out) :: wfnkmpi

  character :: fncor*32
  character :: tmpstr*100,tmpstr1*16,tmpstr2*16
  integer :: i,ierr,ii,j,jj,k,kk
  integer :: ikn,irk
  integer :: istore,nstore,inum,g1,g2,iknstore
  integer :: nq,ndvmax,ndvmax_l
  integer :: ncoul,ncoulb,ncouls,nmtx,nmtx_l,nmtx_col
  integer, allocatable :: isort(:),index(:),revindex(:)
  integer :: ipe
  integer :: Nrod,Nplane,Nfft(3),dNfft(3),dkmax(3),nmpinode
  integer :: ntband,npools,ndiag_max,noffdiag_max,ntband_max
  real(DP) :: scalarsize,qk(3),vcell,tsec(2),qtot,omega_plasma
  real(DP) :: mem,fmem,rmem,rmem2,scale,dscale
  SCALAR, allocatable :: zc(:,:)
  integer, allocatable :: gvecktmp(:,:)
  logical :: do_ch_sum, dont_read

  character(len=3) :: sheader
  integer :: iflavor
  type(gspace) :: gvec_dummy, gvec_kpt
  type(crystal) :: crys_dummy
  type(symmetry) :: syms_dummy
  type(kpoints) :: kp_dummy

  PUSH_SUB(input)

!---------------
! Read sigma.inp

  call inread(sig)

  if(sig%nkn .lt. 1) then
    if(peinf%inode.eq.0) write(0,*) 'sig%nkn < 1'
    call die('sig%nkn')
  endif
  
!---------------------------------
! Determine the available memory

  call procmem(mem,nmpinode)

  if(peinf%inode.eq.0) then
    write(6,998) mem/1024.0d0**2
    write(7,998) mem/1024.0d0**2
  endif
998 format(1x,'Memory available:',f10.1,1x,'MB per PE')
  
  fmem=mem/8.0d0

!-------------------------------
! (gsm) Estimate the required memory

  if(peinf%inode.eq.0) then

! Defaults, so that the estimate below never reads an undefined value when
! sig%freq_dep is not one of the cases handled here, or when there are no
! off-diagonal elements.
    nq=1
    noffdiag_max=0

! determine number of frequencies and number of q-points
    if (sig%freq_dep.eq.-1) then
      nq=sig%nq
      sig%nFreq=1
    endif
    if (sig%freq_dep.eq.0.or.sig%freq_dep.eq.1.or.sig%freq_dep.eq.2) then
      call open_file(10,file='eps0mat',form='unformatted',status='old')
      read(10)
      read(10)i,sig%nFreq
      call close_file(10)
      call open_file(11,file='epsmat',form='unformatted',status='old',iostat=ierr)
      if (ierr.eq.0) then
        read(11)
        read(11)
        read(11)
        read(11)
        read(11)
        read(11)
        read(11)
        read(11) nq
        nq=nq+1
        call close_file(11)
      else
        nq=1
      endif
    endif
  endif

! read wavefunction parameters
  if(peinf%inode == 0) call open_file(25,file='WFN_inner',form='unformatted',status='old')

  sheader = 'WFN'
  iflavor = 0
  call read_binary_header_type(25, sheader, iflavor, kp, gvec, syms, crys)

  if(sig%ntband > kp%mnband) then
    call die("number_bands is larger than are available in WFN_inner", only_root_writes = .true.)
  endif

! Shift the energies of the bands used here by the average potential (in Ry).
  do k=1,kp%nspin
    do i=1,kp%nrk
      do j=1,sig%ntband
        kp%el(sig%band_index(j), i, k) = kp%el(sig%band_index(j), i, k) - sig%avgpot / ryd
      enddo
    enddo
  enddo

  if(peinf%inode == 0) call logit('Reading WFN_inner -- gvec info')
  SAFE_ALLOCATE(gvec%k, (3, gvec%ng))
  call read_binary_gvectors(25, gvec%ng, gvec%ng, gvec%k)

!-----------------------
! sort G-vectors with respect to their kinetic energy
  SAFE_ALLOCATE(index, (gvec%ng))
  SAFE_ALLOCATE(revindex, (gvec%ng))
  SAFE_ALLOCATE(gvec%ekin, (gvec%ng))

  do i=1,gvec%ng
    gvec%ekin(i)=DOT_PRODUCT(dble(gvec%k(:, i)), MATMUL(crys%bdot, dble(gvec%k(:, i))))
  enddo
  call sortrx_D(gvec%ng, gvec%ekin, index, gvec = gvec%k)
  ncouls = gcutoff(gvec, index, sig%ecuts)
  ncoulb = gcutoff(gvec, index, sig%ecutb)
  SAFE_DEALLOCATE_P(gvec%ekin)
    
! Apply the permutation: sorted position i holds original vector index(i);
! revindex maps original positions back to sorted ones (used for VXC/RHO).
  SAFE_ALLOCATE(gvecktmp, (3,gvec%ng))
  gvecktmp(:,:)=gvec%k(:,:)
  do i=1,gvec%ng
    gvec%k(:,i)=gvecktmp(:,index(i))
    revindex(index(i)) = i
  enddo
  SAFE_DEALLOCATE(gvecktmp)
    
  SAFE_ALLOCATE(isort, (gvec%ng))

  call gvec_index(gvec)
    
! estimate for nmtx

  if(peinf%inode == 0) then
    ncoul=max(ncouls,ncoulb)
    nmtx=ncouls
    nmtx_l=int(sqrt(dble(nmtx)**2/dble(peinf%npes)))
    nmtx_col=int(dble(nmtx)/dble(peinf%npes))+1

! divide bands over processors (this is repeated below)
    ntband=min(sig%ntband,kp%mnband)
    if (peinf%npools .le. 0 .or. peinf%npools .gt. peinf%npes) then
      call createpools(sig%ndiag,ntband,peinf%npes,npools,ndiag_max,ntband_max)
    else
      npools = peinf%npools
      if (mod(sig%ndiag,npools).eq.0) then
        ndiag_max=sig%ndiag/npools
      else
        ndiag_max=sig%ndiag/npools+1
      endif
      if (mod(ntband,peinf%npes/npools).eq.0) then
        ntband_max=ntband/(peinf%npes/npools)
      else
        ntband_max=ntband/(peinf%npes/npools)+1
      endif
    endif
    if (sig%noffdiag.gt.0) then
      if (mod(sig%noffdiag,npools).eq.0) then
        noffdiag_max=sig%noffdiag/npools
      else
        noffdiag_max=sig%noffdiag/npools+1
      endif
    endif

    scalarsize = sizeof_scalar()

! required memory
! (all products are formed in double precision to avoid default-integer
! overflow for large matrices / FFT grids)
    rmem=0.0d0
! arrays eps and epstemp in program main
    if (sig%freq_dep.eq.0.or.sig%freq_dep.eq.1) then
      rmem=rmem+(dble(nmtx_l)**2+dble(nmtx))*scalarsize
    endif
! arrays epsR, epsRtemp, epsA, and epsAtemp in program main
    if (sig%freq_dep.eq.2) then
      rmem=rmem+(dble(nmtx_l)**2+dble(nmtx)) &
        *dble(sig%nFreq)*2.0d0*scalarsize
    endif
    if (sig%iwriteint .eq. 1) then
! array epsmpi%eps in subroutine epscopy_mpi
      if (sig%freq_dep.eq.0.or.sig%freq_dep.eq.1) then
        rmem=rmem+(dble(nmtx_col)*nmtx)*dble(nq)*scalarsize
      endif
! arrays epsmpi%epsR and epsmpi%epsA in subroutine epscopy_mpi
      if (sig%freq_dep.eq.2) then
        rmem=rmem+dble(nmtx_col)*dble(nmtx)*dble(nq) &
          *dble(sig%nFreq)*2.0d0*scalarsize
      endif
    endif
! array aqs in program main
    rmem=rmem+dble(ntband_max)*dble(ncoul)*scalarsize
! array aqsaug in program main
    if (sig%noffdiag.gt.0) then
      rmem=rmem+dble(ntband_max)*dble(ncoul) &
        *dble(sig%ndiag)*dble(sig%nspin)*scalarsize
    endif
! array aqsch in program main
    if (sig%exact_ch.eq.1) then
      rmem=rmem+dble(ncoul)*scalarsize
    endif
! arrays aqsaugchd and aqsaugcho in program main
    if (sig%exact_ch.eq.1.and.nq.gt.1) then
      rmem=rmem+dble(ncoul)*dble(ndiag_max) &
        *dble(sig%nspin)*scalarsize
      if (sig%noffdiag.gt.0) then
        rmem=rmem+dble(ncoul)*dble(noffdiag_max)* &
          dble(sig%nspin)*scalarsize
      endif
    endif
! array wfnk%zk in subroutine input
    rmem=rmem+dble(ndiag_max)*dble(kp%ngkmax)* &
      dble(sig%nspin)*scalarsize
! array wfnkq%zkq in subroutine genwf
    rmem=rmem+dble(ntband_max)*dble(kp%ngkmax) &
      *dble(sig%nspin)*scalarsize
! arrays zc in subroutine input and zin in subroutine genwf
    rmem=rmem+dble(kp%ngkmax)*dble(kp%nspin)*scalarsize
! array wfnkoff%zk in program main and subroutines mtxel_vxc and mtxel_sxch
    if ((sig%exact_ch.eq.1.and.sig%noffdiag.gt.0).or. &
      (.not.sig%use_vxcdat)) then
      rmem=rmem+2.0d0*dble(kp%ngkmax)*scalarsize
    endif
! array gvec%indv in input
    rmem=rmem+dble(gvec%nktot)*4.0d0
! arrays fftbox1 and fftbox2 in subroutines mtxel and mtxel_ch
    call setup_FFT_sizes(gvec%kmax,Nfft,scale)
    rmem=rmem+dble(Nfft(1))*dble(Nfft(2))*dble(Nfft(3))*32.0d0
    if (sig%iwriteint .eq. 1) then
! arrays wfnkqmpi%isort and wfnkqmpi%cg in subroutine input
      rmem=rmem+dble(kp%ngkmax)*dble(kp%nrk)*4.0d0+ &
        dble(kp%ngkmax)*dble(ntband_max)*dble(sig%nspin)* &
        dble(kp%nrk)*dble(scalarsize)
! arrays wfnkmpi%isort and wfnkmpi%cg in subroutine input
      rmem=rmem+dble(kp%ngkmax)*dble(sig%nkn)*4.0d0+ &
        dble((ndiag_max*kp%ngkmax)/(peinf%npes/npools))* &
        dble(sig%nspin)*dble(sig%nkn)* &
        dble(scalarsize)
    endif

    write(6,989) rmem/1024.0d0**2
    write(7,989) rmem/1024.0d0**2

! random numbers
    rmem=0.0D0
    if (sig%icutv.ne.5) then
! arrays ran, qran, and qran2
! (ran is deallocated before qran2 is allocated)
      rmem=rmem+6.0D0*dble(nmc)*8.0D0
    endif
! various truncation schemes
    rmem2=0.0d0
! cell wire truncation
    if (sig%icutv.eq.4) then
      dkmax(1) = gvec%kmax(1) * n_in_wire
      dkmax(2) = gvec%kmax(2) * n_in_wire
      dkmax(3) = 1
      call setup_FFT_sizes(dkmax,dNfft,dscale)
! array fftbox_2D
      rmem2=rmem2+dble(dNfft(1))*dble(dNfft(2))*16.0d0
! array inv_indx
      rmem2=rmem2+dble(Nfft(1))*dble(Nfft(2))*dble(Nfft(3))* &
        4.0d0
! array qran
      rmem2=rmem2+3.0D0*dble(nmc)*8.0D0
    endif
! cell box truncation (parallel version only)
    if (sig%icutv.eq.5) then
      dkmax(1) = gvec%kmax(1) * n_in_box
      dkmax(2) = gvec%kmax(2) * n_in_box
      dkmax(3) = gvec%kmax(3) * n_in_box
      call setup_FFT_sizes(dkmax,dNfft,dscale)
      if (mod(dNfft(3),peinf%npes) == 0) then
        Nplane = dNfft(3)/peinf%npes
      else
        Nplane = dNfft(3)/peinf%npes+1
      endif
      if (mod(dNfft(1)*dNfft(2),peinf%npes) == 0) then
        Nrod = (dNfft(1)*dNfft(2))/peinf%npes
      else
        Nrod = (dNfft(1)*dNfft(2))/peinf%npes+1
      endif
! array fftbox_2D
      rmem2=rmem2+dble(dNfft(1))*dble(dNfft(2))*dble(Nplane)* &
        16.0d0
! array fftbox_1D
      rmem2=rmem2+dble(dNfft(3))*dble(Nrod)*16.0d0
! array dummy
!          rmem2=rmem2+dble(dNfft(1))*dble(dNfft(2))*16.0d0
! arrays dummy1 and dummy2
      rmem2=rmem2+dble(Nrod)*dble(peinf%npes+1)*16.0d0
! array inv_indx
      rmem2=rmem2+dble(Nfft(1))*dble(Nfft(2))*dble(Nfft(3))* &
        4.0d0
    endif
    if (rmem2 .gt. rmem) rmem = rmem2
    write(6,988) rmem/1024.0d0**2
    write(7,988) rmem/1024.0d0**2
    write(6,*)
    write(7,*)
  endif
989 format(1x,'Memory required for execution:',f7.1,1x,'MB per PE')
988 format(1x,'Memory required for vcoul:',f7.1,1x,'MB per PE')

  if(peinf%inode == 0) then
    write(7,'(1x,"Cell Volume =",e16.9,/)') crys%celvol

!----------------------------------------------
! Compute cell volume from wave number metric

    call get_volume(vcell,crys%bdot)
    if (abs(crys%celvol-vcell).gt.TOL_Small) then
      call die('volume mismatch')
    endif

! (gsm) check consistency of spin indices

    do k=1,sig%nspin
      if (sig%spin_index(k).lt.1.or.sig%spin_index(k).gt.kp%nspin) &
        call die('inconsistent spin indices')
    enddo

! (same condition is already checked right after reading the header;
! kept here as a safety net with a more detailed message)
    if(sig%ntband.gt.kp%mnband) then
      write(tmpstr1,660) sig%ntband
      write(tmpstr2,660) kp%mnband
      write(0,666) TRUNC(tmpstr1), TRUNC(tmpstr2)
      call die('More bands specified in sigma.inp than available in WFN_inner.')
660   format(i16)
666   format(1x,'The total number of bands (',a,') specified in sigma.inp',/, &
        3x,'is larger than the number of bands (',a,') available in WFN_inner.',/)
    endif
    do_ch_sum = .not.((sig%freq_dep == -1) .or. (sig%freq_dep == 0 .and. sig%exact_ch == 1))
! band sig%ntband+1 is accessed by the degeneracy checks below
    if(sig%ntband.eq.kp%mnband) then
      call die("You must provide one more band in WFN_inner than used in sigma.inp number_bands in order to assess degeneracy.")
    endif

!--------------------------------------------------------------------
! SIB:  Find the k-points in sig%kpt in the list kp%rk (die if not found).
! sig%indkn holds the index of the k-points in sig%kpt in kp%rk, i.e.
! kp%rk(sig%indkn(ikn))=sig%kpt(ikn)

    SAFE_ALLOCATE(sig%indkn, (sig%nkn))
    do ikn=1,sig%nkn
      sig%indkn(ikn)=0
      qk(:)=sig%kpt(:,ikn)
      do irk=1,kp%nrk
        if(all(abs(kp%rk(1:3,irk)-qk(1:3)) .lt. TOL_Small)) sig%indkn(ikn)=irk
      enddo
      if(sig%indkn(ikn) .eq. 0) then
        write(0,'(a,3f10.5,a)') 'Could not find the k-point ', (qk(i),i=1,3), ' among those read from WFN_inner :'
        write(0,'(3f10.5)') ((kp%rk(i,irk),i=1,3),irk=1,kp%nrk)
        call die('k-point in sigma.inp k_points block not available.')
      endif
    enddo

    if(do_ch_sum) then
      if(any(abs(kp%el(sig%ntband, 1:kp%nrk, 1:kp%nspin) - kp%el(sig%ntband + 1, 1:kp%nrk, 1:kp%nspin)) .lt. TOL_Degeneracy)) then
        if(sig%degeneracy_check_override) then
          write(0,'(a)') &
            "WARNING: Selected number of bands for CH sum (number_bands) breaks degenerate subspace. " // &
            "Run degeneracy_check.x for allowable numbers."
          write(0,*)
        else
          write(0,'(a)') &
            "Run degeneracy_check.x for allowable numbers, or use keyword " // &
            "degeneracy_check_override to run anyway (at your peril!)."
          call die("Selected number of bands for CH sum (number_bands) breaks degenerate subspace.")
        endif
      endif
    endif
  endif

!----------------------------------------------------------------
! Check consistency

  if (kp%mnband < maxval(sig%diag) .and. peinf%inode == 0) then
    write(0,*) 'The highest requested band is ', maxval(sig%diag),' but WFN_inner contains only ', kp%mnband,' bands.'
    call die('too few bands')
  endif
  
  SAFE_ALLOCATE(kp%elda, (sig%ntband,kp%nrk,kp%nspin))
  kp%elda(1:sig%ntband, 1:kp%nrk, 1:kp%nspin)=ryd*kp%el(1:sig%ntband, 1:kp%nrk, 1:kp%nspin)
  call scissor_shift(kp, sig%ntband, sig%evs, sig%evdel, sig%ev0, sig%ecs, sig%ecdel, sig%ec0, sig%spl_tck)

!-----------------------------------------------------------------
! If quasi-particle corrections requested, read the corrected
! quasiparticle energies from file (in eV)

  if(sig%eqp_corrections) then
    fncor='eqp.dat'
    call eqpcor(fncor,peinf%inode,peinf%npes, &
      kp,kp%mnband,1,sig%ntband, &
      kp%mnband,0,kp%mnband,0,kp%el,kp%el,kp%el,1,0)
    ! note: want in Ry since conversion occurs below
  endif

  if(peinf%inode == 0) then
    call find_efermi(sig%rfermi, sig%efermi, sig%efermi_input, kp, sig%ntband, &
      "unshifted grid", should_search = .true., should_update = .true., write7 = .true.)

    call assess_degeneracies(kp, kp%el(sig%ntband + 1, :, :), sig%ntband, sig%efermi, sig%tol, sig = sig)

    call calc_qtot(kp, crys%celvol, sig%efermi, qtot, omega_plasma, write7 = .true.)
  endif

  ! For discussion of how q-symmetry may (and may not) be used with degenerate states,
  ! see Hybertsen and Louie, Phys. Rev. B 35, 5585 (1987) Appendix A 
  if(sig%qgridsym .and. sig%noffdiag > 0) then
    if(peinf%inode == 0) then
      write(0,'(a)') "WARNING: Cannot calculate offdiagonal elements unless no_symmetries_q_grid is set."
      write(0,'(a)') "This flag is being reset to enable the calculation."
    endif
    sig%qgridsym = .false.
  endif

! From here on band energies are in eV.
  kp%el(:,:,:)=kp%el(:,:,:)*ryd

!----------------------------------------------------------------
! Distribute data
! (preprocessor directives must start in column 1: traditional-mode cpp,
! as used for Fortran sources, does not recognize an indented '#')

#ifdef MPI
  call MPI_Bcast(sig%efermi, 1, MPI_REAL_DP,0, MPI_COMM_WORLD,mpierr)
  if(sig%qgridsym .or. sig%noffdiag > 0 .or. sig%ncrit > 0) then
    if(peinf%inode.ne.0) then
      SAFE_ALLOCATE(kp%degeneracy, (sig%ntband, kp%nrk, kp%nspin))
    endif
    call MPI_Bcast(kp%degeneracy(1,1,1), sig%ntband * kp%nrk * kp%nspin, MPI_INTEGER, 0, MPI_COMM_WORLD, mpierr)
  endif
  if(peinf%inode.ne.0) then
    SAFE_ALLOCATE(sig%indkn, (sig%nkn))
  endif
  call MPI_Bcast(sig%indkn, sig%nkn, MPI_INTEGER, 0, MPI_COMM_WORLD, mpierr)
#endif

!----------------------------------------------------------------
! CHP: Set the init_freq_eval relative to the Fermi energy.
!      This part should NOT be put before the Fermi energy is set.

  sig%freqevalmin = sig%freqevalmin + sig%efermi

!--------------------------------------------------------------------
! Read in the exchange-correlation potential and store in array vxc

  if(.not. sig%use_vxcdat) then
    call logit('Reading VXC')
    SAFE_ALLOCATE(sig%vxc, (gvec%ng,kp%nspin))

    if(peinf%inode == 0) call open_file(96,file='VXC',form='unformatted',status='old')

    sheader = 'VXC'
    iflavor = 0
    call read_binary_header_type(96, sheader, iflavor, kp_dummy, gvec_dummy, syms_dummy, crys_dummy, warn = .false.)

    call check_header('WFN_inner', kp, gvec, syms, crys, 'VXC', kp_dummy, gvec_dummy, syms_dummy, crys_dummy, is_wfn = .false.)

    ! VXC G-vectors (file order) must match WFN_inner's before sorting.
    SAFE_ALLOCATE(gvec_dummy%k, (3, gvec_dummy%ng))
    call read_binary_gvectors(96, gvec_dummy%ng, gvec_dummy%ng, gvec_dummy%k)
    do i = 1, gvec%ng
      if(any(gvec_dummy%k(:,index(i)) .ne. gvec%k(:,i))) call die("gvec mismatch in VXC")
    enddo
    SAFE_DEALLOCATE_P(gvec_dummy%k)

    call read_binary_data(96, gvec_dummy%ng, gvec_dummy%ng, kp%nspin, sig%vxc, gindex = revindex)

    if(peinf%inode == 0) call close_file(96)
  endif ! not using vxc.dat

!--------------------------------------------------------------------
! Read in the charge density and store in array rho (formerly known as CD95)
! CD95 Ref: http://www.nature.com/nature/journal/v471/n7337/full/nature09897.html

  if(sig%freq_dep.eq.1) then
    call logit('Reading RHO')
    SAFE_ALLOCATE(wpg%rho, (gvec%ng,kp%nspin))

    if(peinf%inode == 0) call open_file(95,file='RHO',form='unformatted',status='old')

    sheader = 'RHO'
    iflavor = 0
    call read_binary_header_type(95, sheader, iflavor, kp_dummy, gvec_dummy, syms_dummy, crys_dummy, warn = .false.)

    call check_header('WFN_inner', kp, gvec, syms, crys, 'RHO', kp_dummy, gvec_dummy, syms_dummy, crys_dummy, is_wfn = .false.)

    ! RHO G-vectors (file order) must match WFN_inner's before sorting.
    SAFE_ALLOCATE(gvec_dummy%k, (3, gvec_dummy%ng))
    call read_binary_gvectors(95, gvec_dummy%ng, gvec_dummy%ng, gvec_dummy%k)
    do i = 1, gvec%ng
      if(any(gvec_dummy%k(:,index(i)) .ne. gvec%k(:,i))) call die("gvec mismatch in RHO")
    enddo
    SAFE_DEALLOCATE_P(gvec_dummy%k)

    call read_binary_data(95, gvec_dummy%ng, gvec_dummy%ng, kp%nspin, wpg%rho, gindex = revindex)

    if(peinf%inode == 0) call close_file(95)

    ! otherwise if nspin == 1, the 2 component may be uninitialized to NaN
    wpg%wp(1:2) = 0d0
    wpg%rhoave(1:2) = 0d0

    ! since they are sorted, if G = 0 is present, it is the first one
    if(any(gvec%k(1:3, 1) /= 0)) call die("gvectors for RHO must include G = 0")
    ! otherwise, the code below will not do what we think it does

    do k=1,kp%nspin
      wpg%rhoave(k)=dble(wpg%rho(1,k))
      wpg%wp(k)=ryd*ryd*16.0d0*PI_D*wpg%rhoave(k)/crys%celvol
    enddo
  endif ! sig%freq_dep

!------------------------------------------------------------------------
! Divide bands over processors
!
! sig%ntband           number of total (valence and conduction) bands
! peinf%npools         number of parallel sigma calculations
! peinf%ndiag_max      maximum number of diag sigma calculations
! peinf%noffdiag_max   maximum number of offdiag sigma calculations
! peinf%ntband_max     maximum number of total bands per node
!
! ntband_node          number of total bands per current node
! nvband_node          number of valence bands per current node
! peinf%indext(itb)    indices of total bands belonging to current node
!
! peinf%index_diag     band index for diag sigma calculation
! peinf%flag_diag      flag for storing diag sigma calculation
! peinf%index_offdiag  band index for offdiag sigma calculation
! peinf%flag_offdiag   flag for storing offdiag sigma calculation
!
! flags are needed in case of uneven distribution, nodes still do the
! calculation because they share epsilon and wavefunctions and need to
! participate in global communications, but the result is not stored

  if (peinf%npools .le. 0 .or. peinf%npools .gt. peinf%npes) then
    call createpools(sig%ndiag,sig%ntband,peinf%npes,npools,ndiag_max,ntband_max)
    peinf%npools = npools
    peinf%ndiag_max = ndiag_max
    peinf%ntband_max = ntband_max
  else
    if (mod(sig%ndiag,peinf%npools).eq.0) then
      peinf%ndiag_max=sig%ndiag/peinf%npools
    else
      peinf%ndiag_max=sig%ndiag/peinf%npools+1
    endif
    
    if (mod(sig%ntband,peinf%npes/peinf%npools).eq.0) then
      peinf%ntband_max=sig%ntband/(peinf%npes/peinf%npools)
    else
      peinf%ntband_max=sig%ntband/(peinf%npes/peinf%npools)+1
    endif
  endif
  
  if (sig%noffdiag.gt.0) then
    if (mod(sig%noffdiag,peinf%npools).eq.0) then
      peinf%noffdiag_max=sig%noffdiag/peinf%npools
    else
      peinf%noffdiag_max=sig%noffdiag/peinf%npools+1
    endif
  else
    ! keep the value well defined; it is reported below even when zero
    peinf%noffdiag_max=0
  endif
  
  SAFE_ALLOCATE(peinf%index_diag, (peinf%ndiag_max))
  SAFE_ALLOCATE(peinf%flag_diag, (peinf%ndiag_max))
  peinf%index_diag=1
  peinf%flag_diag=.false.
! Round-robin: element ii goes to pool mod(ii-1,npools), slot (ii-1)/npools+1.
  do ii=1,peinf%ndiag_max*peinf%npools
    jj=(ii-1)/peinf%npools+1
    kk=mod(ii-1,peinf%npools)
    if (peinf%inode/(peinf%npes/peinf%npools).eq.kk) then
      if (ii.le.sig%ndiag) then
        peinf%index_diag(jj)=ii
        peinf%flag_diag(jj)=.true.
      else
        peinf%index_diag(jj)=1
        peinf%flag_diag(jj)=.false.
      endif
    endif
  enddo

  if (sig%noffdiag.gt.0) then
    SAFE_ALLOCATE(peinf%index_offdiag, (peinf%noffdiag_max))
    SAFE_ALLOCATE(peinf%flag_offdiag, (peinf%noffdiag_max))
    peinf%index_offdiag=1
    peinf%flag_offdiag=.false.
    do ii=1,peinf%noffdiag_max*peinf%npools
      jj=(ii-1)/peinf%npools+1
      kk=mod(ii-1,peinf%npools)
      if (peinf%inode/(peinf%npes/peinf%npools).eq.kk) then
        if (ii.le.sig%noffdiag) then
          peinf%index_offdiag(jj)=ii
          peinf%flag_offdiag(jj)=.true.
        else
          peinf%index_offdiag(jj)=1
          peinf%flag_offdiag(jj)=.false.
        endif
      endif
    enddo
  endif
  
  SAFE_ALLOCATE(peinf%indext, (peinf%ntband_max))
  
  peinf%ntband_node=0
  peinf%nvband_node=0
  peinf%indext(:)=0
  
! Within each pool, band band_index(ii) goes to processor
! mod(band_index(ii)-1, processors per pool) of that pool.
  do jj=1,peinf%npools
    do ii=1,sig%ntband
      ipe=mod(sig%band_index(ii)-1,peinf%npes/peinf%npools) &
        +(jj-1)*(peinf%npes/peinf%npools)
      if (peinf%inode.eq.ipe) then
        peinf%ntband_node=peinf%ntband_node+1
        if(sig%band_index(ii).le.sig%nvband+sig%ncrit) then
          peinf%nvband_node=peinf%nvband_node+1
        endif
        peinf%indext(peinf%ntband_node)=sig%band_index(ii)
      endif
    enddo
  enddo
  
!------------------------------------------------------------------------
! Report distribution of bands over processors

  if (peinf%inode.eq.0) then
    write(6,701) peinf%npes, peinf%npools, &
      peinf%npes/peinf%npools
    if (mod(peinf%npes,peinf%npools).ne.0) &
      write(0,702) peinf%npes-(peinf%npes/peinf%npools)* &
      peinf%npools
    if (mod(sig%ndiag,peinf%npools).eq.0) then
      write(6,703) peinf%ndiag_max
    else
      write(6,704) peinf%ndiag_max-1, peinf%ndiag_max
      write(0,705) sig%ndiag, peinf%npools
    endif
    if (mod(sig%noffdiag,peinf%npools).eq.0) then
      write(6,706) peinf%noffdiag_max
    else
      write(6,707) peinf%noffdiag_max-1, peinf%noffdiag_max
      write(0,708) sig%noffdiag, peinf%npools
    endif
    if (mod(sig%ntband,peinf%npes/peinf%npools).eq.0) then
      write(6,709) peinf%ntband_max
    else
      write(6,710) peinf%ntband_max-1, peinf%ntband_max
      write(0,711) sig%ntband, peinf%npes/peinf%npools
    endif
    write(6,*)
  endif
701 format(1x,i4,1x,'processor(s),',1x,i4,1x,'pool(s),',1x,i4,1x,'processor(s) per pool')
702 format(1x,'WARNING: distribution is not ideal,',1x,i4,1x,'processor(s) is/are idle')
703 format(1x,'each pool is computing',1x,i4,1x,'diagonal sigma matrix element(s)')
704 format(1x,'each pool is computing',1x,i4,1x,'to',1x,i4,1x,'diagonal sigma matrix element(s)')
705 format(1x,'WARNING: distribution is not ideal, number of diagonal sigma',/,&
      1x,'matrix elements',1x,i4,1x,'should be divisible by number of pools',1x,i4)
706 format(1x,'each pool is computing',1x,i4,1x,'off-diagonal sigma matrix element(s)')
707 format(1x,'each pool is computing',1x,i4,1x,'to',1x,i4,1x,'off-diagonal sigma matrix element(s)')
708 format(1x,'WARNING: distribution is not ideal, number of off-diagonal sigma',/,&
      1x,'matrix elements',1x,i4,1x,'should be divisible by number of pools',1x,i4)
709 format(1x,'each processor is holding',1x,i4,1x,'valence and conduction band(s)')
710 format(1x,'each processor is holding',1x,i4,1x,'to',1x,i4,1x,'valence and conduction band(s)')
711 format(1x,'WARNING: distribution is not ideal, number of valence and conduction bands',/,&
      1x,i4,1x,'should be divisible by number of processors per pool',1x,i4)

!-----------------------------------------------------------
!
!     LOOP OVER K-POINT GRID AND READ IN WAVEFUNCTIONS
!
!-----------------------------------------------------------

!----------------------------------------
! Open temporary wavefunction files itpc and itpk

  if (sig%iwriteint .eq. 0) then
    call open_file(itpc,file=fnc,form='unformatted',status='replace')
    if (mod(peinf%inode,peinf%npes/peinf%npools).eq.0) then
      call open_file(itpk,file=fnk,form='unformatted',status='replace')
    endif
  else
    SAFE_ALLOCATE(wfnkqmpi%nkptotal, (kp%nrk))
    SAFE_ALLOCATE(wfnkqmpi%isort, (kp%ngkmax,kp%nrk))
    SAFE_ALLOCATE(wfnkqmpi%band_index, (peinf%ntband_max,kp%nrk))
    SAFE_ALLOCATE(wfnkqmpi%qk, (3,kp%nrk))
    SAFE_ALLOCATE(wfnkqmpi%el, (sig%ntband,sig%nspin,kp%nrk))
    SAFE_ALLOCATE(wfnkqmpi%cg, (kp%ngkmax,peinf%ntband_max,sig%nspin,kp%nrk))
    if (sig%nkn.gt.1) then
      ndvmax=peinf%ndiag_max*kp%ngkmax
      if (mod(ndvmax,peinf%npes/peinf%npools).eq.0) then
        ndvmax_l=ndvmax/(peinf%npes/peinf%npools)
      else
        ndvmax_l=ndvmax/(peinf%npes/peinf%npools)+1
      endif
      SAFE_ALLOCATE(wfnkmpi%nkptotal, (sig%nkn))
      SAFE_ALLOCATE(wfnkmpi%isort, (kp%ngkmax,sig%nkn))
      SAFE_ALLOCATE(wfnkmpi%qk, (3,sig%nkn))
      SAFE_ALLOCATE(wfnkmpi%el, (sig%ntband,sig%nspin,sig%nkn))
      SAFE_ALLOCATE(wfnkmpi%elda, (sig%ntband,sig%nspin,sig%nkn))
      SAFE_ALLOCATE(wfnkmpi%cg, (ndvmax_l,sig%nspin,sig%nkn))
    endif
  endif

!-----------------------------------
! Read in wavefunction information

  SAFE_ALLOCATE(wfnk%isrtk, (gvec%ng))
  SAFE_ALLOCATE(wfnk%ek, (sig%ntband,sig%nspin))
  SAFE_ALLOCATE(wfnk%elda, (sig%ntband,sig%nspin))
  
  do irk=1,kp%nrk
    
    write(tmpstr,*) 'Reading WFN_inner -> cond/val wfns irk=',irk
    call logit(tmpstr)
    qk(:)=kp%rk(:,irk)

!----------------------------
! Read in and sort gvectors

    SAFE_ALLOCATE(gvec_kpt%k, (3, kp%ngk(irk)))
    call read_binary_gvectors(25, kp%ngk(irk), kp%ngk(irk), gvec_kpt%k)

    do i = 1, kp%ngk(irk)
      call findvector(isort(i), gvec_kpt%k(1, i), gvec_kpt%k(2, i), gvec_kpt%k(3, i), gvec)
      if (isort(i) .eq. 0) then
        write(0,*) 'could not find gvec', kp%ngk(irk), i, gvec_kpt%k(1:3, i)
        call die('findvector')
      endif
    enddo
    SAFE_DEALLOCATE_P(gvec_kpt%k)
    
!--------------------------------------------------------
! Determine if Sigma must be computed for this k-point.
! If so, store the bands and wavefunctions on file itpk.
! If there is only one k-point, store directly in wfnk.

    istore=0
    do ikn=1,sig%nkn
      if(sig%indkn(ikn).eq.irk) then
        istore=1
        iknstore=ikn
        wfnk%nkpt=kp%ngk(irk)
        wfnk%ndv=peinf%ndiag_max*kp%ngk(irk)
        wfnk%k(:)=qk(:)
        wfnk%isrtk(:)=isort(:)
        do k=1,sig%nspin
          wfnk%ek(1:sig%ntband,k)= &
            kp%el(1:sig%ntband,irk,sig%spin_index(k))
          wfnk%elda(1:sig%ntband,k)= &
            kp%elda(1:sig%ntband,irk,sig%spin_index(k))
        enddo
        SAFE_ALLOCATE(wfnk%zk, (wfnk%ndv,sig%nspin))
        wfnk%zk=ZERO
      endif
    enddo ! ikn

    if (sig%iwriteint .eq. 0) then
      write(itpc) kp%ngk(irk),sig%ntband,(isort(j),j=1,gvec%ng), &
        ((kp%el(j,irk,sig%spin_index(k)),j=1,sig%ntband), &
        k=1,sig%nspin),(qk(i),i=1,3)
    else
      wfnkqmpi%nkptotal(irk) = kp%ngk(irk)
      wfnkqmpi%isort(1:kp%ngk(irk),irk) = isort(1:kp%ngk(irk))
      do k = 1, sig%nspin
        wfnkqmpi%el(1:sig%ntband,k,irk) = &
          kp%el(1:sig%ntband,irk,sig%spin_index(k))
      enddo
      wfnkqmpi%qk(1:3,irk) = qk(1:3)
    endif

!-------------------------------------------------------------------------
! SIB:  Read wave functions from file WFN_inner (unit=25) and have
! the appropriate processor write it to itpc after checking norm.
! The wavefunctions for bands where Sigma matrix elements are
! requested are stored in wfnk%zk for later writing to unit itpk.
! If band index is greater than sig%ntband, we will actually
! not do anything with this band (see code below).
!  We still have to read it though, in order
! to advance the file to get to the data for the next k-point.

    SAFE_ALLOCATE(zc, (kp%ngk(irk), kp%nspin))
    
    inum=0
    do i=1,kp%mnband

      if (peinf%inode.eq.0) call timacc(21,1,tsec)

      dont_read = (i > sig%ntband)
      call read_binary_data(25, kp%ngk(irk), kp%ngk(irk), kp%nspin, zc, dont_read = dont_read)

      if (peinf%inode.eq.0) call timacc(21,2,tsec)

      if(i <= sig%ntband) then
        if(peinf%inode.eq.0) then
          do k=1,kp%nspin
            call checknorm('WFN_inner',i,irk,k,kp%ngk(irk),zc(:,k))
          enddo
        endif

        ! nstore = local diagonal slot on this pool for band i (0 if none)
        nstore=0
        do j=1,peinf%ndiag_max
          if (.not.peinf%flag_diag(j)) cycle
          if (sig%band_index(i).eq. &
            sig%diag(peinf%index_diag(j))) nstore=j
        enddo
        if((istore.eq.1).and.(nstore.ne.0)) then
          do k=1,sig%nspin
            do j=1,kp%ngk(irk)
              wfnk%zk((nstore-1)*kp%ngk(irk)+j,k) = zc(j,sig%spin_index(k))
            enddo
          enddo
        endif
        ! j = 1 if band band_index(i) is owned by this processor
        j=0
        do ii=1,peinf%ntband_node
          if (peinf%indext(ii).eq.sig%band_index(i)) j=1
        enddo

        if (peinf%inode.eq.0) call timacc(22,1,tsec)

        if (j.eq.1) then
          if (sig%iwriteint .eq. 0) then
            write(itpc) sig%band_index(i),kp%ngk(irk), &
              ((zc(j,sig%spin_index(k)),j=1,kp%ngk(irk)), &
              k=1,sig%nspin)
          else
            inum=inum+1
            wfnkqmpi%band_index(inum,irk)=sig%band_index(i)
            do k=1,sig%nspin
              wfnkqmpi%cg(1:kp%ngk(irk),inum,k,irk)= &
                zc(1:kp%ngk(irk),sig%spin_index(k))
            enddo
          endif
        endif

        if (peinf%inode.eq.0) call timacc(22,2,tsec)

      endif
      
    enddo ! i (loop over bands)
    SAFE_DEALLOCATE(zc)
    
    if((istore.eq.1).and.(sig%nkn.gt.1)) then

      if (peinf%inode.eq.0) call timacc(22,1,tsec)

      if (sig%iwriteint .eq. 0) then
        if (mod(peinf%inode,peinf%npes/peinf%npools).eq.0) then
          write(itpk) kp%ngk(irk),wfnk%ndv,sig%ntband, &
            (wfnk%isrtk(j),j=1,gvec%ng),(qk(i),i=1,3), &
            ((wfnk%ek(j,k),j=1,sig%ntband),k=1,sig%nspin), &
            ((wfnk%elda(j,k),j=1,sig%ntband),k=1,sig%nspin), &
            ((wfnk%zk(j,k),j=1,wfnk%ndv),k=1,sig%nspin)
        endif
      else
        ikn=iknstore
        wfnkmpi%nkptotal(ikn)=kp%ngk(irk)
        wfnkmpi%isort(1:kp%ngk(irk),ikn)=wfnk%isrtk(1:kp%ngk(irk))
        wfnkmpi%qk(1:3,ikn)=qk(1:3)
        wfnkmpi%el(1:sig%ntband,1:sig%nspin,ikn)= &
          wfnk%ek(1:sig%ntband,1:sig%nspin)
        wfnkmpi%elda(1:sig%ntband,1:sig%nspin,ikn)= &
          wfnk%elda(1:sig%ntband,1:sig%nspin)
        do k=1,sig%nspin
#ifdef MPI
          ! each processor of a pool keeps its contiguous chunk g1:g2 of zk
          i=mod(peinf%inode,peinf%npes/peinf%npools)
          if (mod(wfnk%ndv,peinf%npes/peinf%npools).eq.0) then
            j=wfnk%ndv/(peinf%npes/peinf%npools)
          else
            j=wfnk%ndv/(peinf%npes/peinf%npools)+1
          endif
          g1=1+i*j
          g2=min(j+i*j,wfnk%ndv)
          if (g2.ge.g1) then
            wfnkmpi%cg(1:g2-g1+1,k,ikn)=wfnk%zk(g1:g2,k)
          endif ! g2.ge.g1
#else
          wfnkmpi%cg(1:wfnk%ndv,k,ikn)=wfnk%zk(1:wfnk%ndv,k)
#endif
        enddo
      endif

      if (peinf%inode.eq.0) call timacc(22,2,tsec)

      SAFE_DEALLOCATE_P(wfnk%zk)
    endif
    
  enddo ! irk (loop over k-points)
  
!------------------------------------------
! Close files and write out information about the crystal

  if(sig%iwriteint .eq. 0) then
    call close_file(itpc)
    if(mod(peinf%inode,peinf%npes/peinf%npools).eq.0) call close_file(itpk)
  endif

  if(peinf%inode.eq.0) then
    call close_file(25) ! WFN_inner
    write(6,111)
111 format(/,1x,'Read WFN_inner successfully.',/)
    write(7,111)    

!------------------------------------------------------
! Write out information about self-energy calculation

    write(6,120) sig%ntband,sig%nvband,sig%ncrit, &
      sig%ecutb,sig%ecuts,sig%sexcut,sig%gamma, &
      sig%fdf,sig%dw
    write(7,120) sig%ntband,sig%nvband,sig%ncrit, &
      sig%ecutb,sig%ecuts,sig%sexcut,sig%gamma, &
      sig%fdf,sig%dw
120 format(/,5x,'ntband =',i6,2x,'nvband =',i4,2x,'ncrit =',i3,/, &
      5x,'ecutb =',f8.4,2x,'ecuts =',f8.4,2x,'sexcut =',f8.4,/, &
      5x,'gamma =',f8.4,2x,'fdf =',i3,2x,'de =',f8.4,1x,'eV')
    write(6,130) sig%evs,sig%ev0,sig%evdel,sig%ecs,sig%ec0, &
      sig%ecdel,sig%evs_outer,sig%ev0_outer,sig%evdel_outer, &
      sig%ecs_outer,sig%ec0_outer,sig%ecdel_outer
    write(7,130) sig%evs,sig%ev0,sig%evdel,sig%ecs,sig%ec0, &
      sig%ecdel,sig%evs_outer,sig%ev0_outer,sig%evdel_outer, &
      sig%ecs_outer,sig%ec0_outer,sig%ecdel_outer
130 format(/,5x,'cvfit',7x,'=',6f8.4,/,5x,'cvfit_outer',1x,'=',6f8.4)
    if(sig%freq_dep.eq.1) then
      do k=1,sig%nspin
        write(6,160) sig%spin_index(k),wpg%rhoave &
          (sig%spin_index(k)),sqrt(wpg%wp(sig%spin_index(k)))
        write(7,160) sig%spin_index(k),wpg%rhoave &
          (sig%spin_index(k)),sqrt(wpg%wp(sig%spin_index(k)))
160     format(/,5x,'data for sum rule',3x,'rho(0,',i1,')=',f8.4, &
          3x,'wp =',f8.4,1x,'eV')
      enddo
      write(6,*)
      write(7,*)
    endif ! sig%freq_dep
    write(6,171) sig%ndiag
    write(6,172) sig%diag(:)
    write(6,173) sig%noffdiag
    if (sig%noffdiag>0) then
      do k=1,sig%noffdiag
        write(6,174) k, (sig%offmap(k,ii), ii = 1, 3)
      enddo
    endif
    write(6,*)
171 format(/,1x,'ndiag    =',i4)
172 format(1x,'diag(:)  =',999i4)
173 format(1x,'noffdiag =',i4)
174 format(1x,'offmap(:, ',i4,') =',3i4)
    
  endif ! node 0
  
  SAFE_DEALLOCATE(isort)
  SAFE_DEALLOCATE(index)
  SAFE_DEALLOCATE(revindex)
  SAFE_DEALLOCATE_P(kp%w)
  SAFE_DEALLOCATE_P(kp%el)
  SAFE_DEALLOCATE_P(kp%elda)

#ifdef MPI
  if(sig%iwriteint == 0) call MPI_Barrier(MPI_COMM_WORLD, mpierr)
#endif

  POP_SUB(input)
  
  return
  
end subroutine input
