!==========================================================================
!
! Routines:
!
! (1) input_co()        Originally By ?         Last Modified 4/19/2009 (gsm)
!
!     input: kp, crys, gvec, syms, xct, flagbz, eqp types
!
!     output: kg, distgwfco types (kp, xct and eqp are also modified)
!
!     Reads in the coarse-grid wavefunctions from file WFN_co and
!     distributes them between processors (if xct%iwriteint=1)
!     or writes them in temporary files (if xct%iwriteint=0).
!     The k-point grid is stored in kg.
!
!==========================================================================

#include "f_defs.h"

subroutine input_co(kp,crys,gvec,kg,syms,xct,flagbz,distgwfco,eqp)
!
! Reads the coarse-grid wavefunction file WFN_co (unit 25), checks its header
! against the fine-grid data passed in, selects the requested valence and
! conduction bands at each coarse k-point, and either distributes them over
! processors (xct%iwriteint = 1, stored in distgwfco) or writes them from the
! root node to the unformatted files INT_VWFN_CO and INT_CWFN_CO
! (xct%iwriteint = 0).
!
! Arguments:
!   kp        (inout) k-point data that check_header compares with WFN_co
!                     under the label WFN_fi. kp%rk is deallocated here when
!                     xct%iskipinterp == 1; nothing else in kp is changed.
!   crys      (in)    crystal data; header check, fullbz and checkbz.
!   gvec      (in)    G-space; header check, and findvector lookup of the
!                     G-vectors of each k-point.
!   kg        (out)   coarse k-point grid: kg%r and kg%nr are set here, then
!                     kg is handed to fullbz for unfolding.
!   syms      (in)    symmetries; header check and fullbz.
!   xct       (inout) run parameters. This routine may modify xct%inteqp,
!                     xct%idimensions and xct%iperiodic, and passes
!                     xct%rfermi, xct%efermi to find_efermi.
!   flagbz    (in)    1: call fullbz with a single operation (no symmetries);
!                     otherwise all syms%ntran operations are used.
!   distgwfco (out)   distributed coarse-grid wavefunctions; its components
!                     are only set and allocated when xct%iwriteint == 1.
!   eqp       (inout) eqp%evshift_co and eqp%ecshift_co are allocated and
!                     zeroed here and passed to eqpcor; the scissor
!                     parameters in eqp are passed to scissor_shift.
!
! The routine aborts through die() on inconsistent grids, missing bands,
! broken degenerate subspaces (unless overridden), G-vectors that cannot be
! found, or a mismatch between xct%nkpt_co and the unfolded grid size.
!

  use global_m
  use eqpcor_m
  use fullbz_m
  use input_utils_m
  use misc_m
  use wfn_rho_vxc_io_m
  implicit none

  type (kpoints), intent(inout) :: kp
  type (crystal), intent(in) :: crys
  type (gspace), intent(in) :: gvec
  type (grid), intent(out) :: kg
  type (symmetry), intent(in) :: syms
  type (xctinfo), intent(inout) :: xct
  integer, intent(in) :: flagbz
  type (tdistgwf), intent(out) :: distgwfco
  type (eqpinfo), intent(inout) :: eqp

  type (crystal) :: crys_co
  type (symmetry) :: syms_co
  type (wavefunction) :: wfnv,wfnc
  type (kpoints) :: kp_co
  character :: filenamev*20,filenamec*20
  character :: tmpfn*16
  character :: fncor*32
  integer :: itpv,itpc
  integer :: irk,irks
  integer :: ii,jj,kk,is
  real(DP) :: kt(3),div,tol
  integer, allocatable :: dist(:)
  integer, allocatable :: indxk(:)
  SCALAR, allocatable :: cg(:,:), cgarray(:)

  character(len=3) :: sheader
  integer :: iflavor
  type(gspace) :: gvec_co, gvec_kpt
  logical :: skip_checkbz, broken_degeneracy

! FHJ: used to check if we should update the number of dimensions
! for the interpolation routines
  integer :: ikb, new_idimensions
  integer :: new_iperiodic(3)

  PUSH_SUB(input_co)

!-------------------------
! Print to stdout

  if (peinf%inode.eq.0) write(6,900)
900 format(/,1x,'Started reading coarse-grid wavefunctions from WFN_co',/)

! Only the root node opens the file; the header is read through the
! library routine and compared with the data passed in by the caller.
  if (peinf%inode.eq.0) call open_file(unit=25,file='WFN_co',form='unformatted',status='old')
  sheader = 'WFN'
  iflavor = 0
  call read_binary_header_type(25, sheader, iflavor, kp_co, gvec_co, syms_co, crys_co, warn = .false.)

  call check_header('WFN_fi', kp, gvec, syms, crys, 'WFN_co', kp_co, gvec_co, syms_co, crys_co, is_wfn = .true.)

  if(xct%iskipinterp == 1) then
    ! we check this first to be sure the second comparison will not segfault
    if(kp%nrk /= kp_co%nrk .or. any(kp%kgrid(1:3) /= kp_co%kgrid(1:3))) then
      call die("Cannot skip interpolation if coarse and fine grids differ.")
    endif
    if(any(abs(kp%rk(1:3, 1:kp%nrk) - kp_co%rk(1:3, 1:kp%nrk)) > TOL_Zero)) then
      call die("Cannot skip interpolation if coarse and fine k-points differ.")
    endif
    ! now we do not need the k-points anymore. This is the only place we change kp.
    SAFE_DEALLOCATE_P(kp%rk)
  endif

! The global G-vector record has to be read to advance the file, but its
! contents are not used here, so the buffer is released immediately.
  SAFE_ALLOCATE(gvec_co%k, (3, gvec%ng))
  call read_binary_gvectors(25, gvec%ng, gvec%ng, gvec_co%k)
  SAFE_DEALLOCATE_P(gvec_co%k)

!-----------------------------------------------------------------------
! Read eqp_co.dat for possible interpolation

  SAFE_ALLOCATE(eqp%evshift_co, (xct%nvb_co,kp_co%nrk,kp_co%nspin))
  SAFE_ALLOCATE(eqp%ecshift_co, (xct%ncb_co,kp_co%nrk,kp_co%nspin))
  eqp%evshift_co=0D0
  eqp%ecshift_co=0D0
  
  fncor = ''

  if(xct%eqp_corrections .and. xct%iskipinterp == 1) fncor = 'eqp.dat'
  ! not interpolating, coarse and fine grids are identical.
  ! must correct this grid or wreck occupations. we need kp_co%el for efermi below.
  ! DAS: this is a hack to make things work. Better is never to read anything of the coarse grid if not interpolating!

  ! Note: eqp_co_corrections and iskipinterp == 1 are incompatible and blocked in inread
  if(xct%eqp_co_corrections) then
    fncor = 'eqp_co.dat'
    xct%inteqp = 1
  endif

  ! keep a copy of the energies as read from WFN_co: eqpcor and scissor_shift
  ! below may change kp_co%el, while the degeneracy checks use kp_co%elda.
  SAFE_ALLOCATE(kp_co%elda, (kp_co%mnband, kp_co%nrk, kp_co%nspin))
  kp_co%elda(:,:,:) = kp_co%el(:,:,:)

  if(trim(fncor) /= '') then
    ! FIXME: for metals this is asking for a few more bands than actually needed on some k-points
    call eqpcor(fncor,peinf%inode,peinf%npes,kp_co,kp%mnband,&
      minval(kp_co%ifmax(:,:)-xct%nvb_co+1),maxval(kp_co%ifmax(:,:)+xct%ncb_co), &
      xct%nvb_co,0,xct%ncb_co,0,kp_co%el,eqp%evshift_co,eqp%ecshift_co,1,0)
  endif

  if(xct%eqp_co_corrections .and. xct%eqp_corrections) xct%inteqp = 0
  ! if we have the fine-grid QP energies, we do not need to interpolate from the coarse grid

  ! scissor shift is only needed for consistency with Fermi level shift here
  ! since 'interpolating' to fine grid is the same as just applying scissor shift directly
  call scissor_shift(kp_co, kp_co%mnband, eqp%evs, eqp%evdel, eqp%ev0, eqp%ecs, eqp%ecdel, eqp%ec0, eqp%spl_tck)

  ! JRD: If we are using eqp_co.dat, then this is our best estimate of the true fermi qp fermi
  ! energy, so we update it. If we don`t update it, we may also mess up the occupations of the
  ! coarse grid because qp efermi of coarse grid may not be lda efermi of fine grids.
  ! NOTE(review): kp%mnband (from the kp argument) is passed here together
  ! with kp_co, while scissor_shift above receives kp_co%mnband - confirm
  ! that this is intended.
  if(xct%eqp_co_corrections) then
    call find_efermi(xct%rfermi, xct%efermi, xct%efermi_input, kp_co, kp%mnband, &
      "coarse grid", should_search = .true., should_update = .true., write7 = .false.)
  else
    call find_efermi(xct%rfermi, xct%efermi, xct%efermi_input, kp_co, kp%mnband, &
      "coarse grid", should_search = .true., should_update = .false., write7 = .false.)
  endif

  ! now we call again to initialize the eqp arrays
  if(xct%eqp_co_corrections) then
    call eqpcor(fncor,peinf%inode,peinf%npes,kp_co,1,0,0, &
      xct%nvb_co,xct%nvb_co,xct%ncb_co,xct%ncb_co,kp_co%el,eqp%evshift_co,eqp%ecshift_co,1,2,dont_write=.true.)
  endif

  if(any(kp_co%ifmax(:,:) == 0)) & 
    call die("BSE codes cannot handle a system where some k-points have no occupied bands.", only_root_writes = .true.) 

  ! number of valence and conduction bands available at every k-point and spin
  kp_co%nvband=minval(kp_co%ifmax(:,:)-kp_co%ifmin(:,:))+1
  kp_co%ncband=kp_co%mnband-maxval(kp_co%ifmax(:,:))

!----------------------------------------------------------------
! (gsm) check whether the requested number of bands
!       is available in the wavefunction file

  if(xct%nvb_co .gt. kp_co%nvband) then
    call die("The requested number of valence bands is not available in WFN_co.", only_root_writes = .true.)
  endif
  if(xct%ncb_co .gt. kp_co%ncband) then
    call die("The requested number of conduction bands is not available in WFN_co.", only_root_writes = .true.)
  endif

! DAS: degenerate subspace check
! Performed on the root node only, using the energies saved in kp_co%elda.
! The selection is flagged when the outermost selected band and the first
! band left out differ by less than TOL_Degeneracy at any k-point and spin.

  if (peinf%inode.eq.0) then
    ! the conduction test below reads band ifmax + ncb_co + 1
    if(xct%ncb_co .eq. kp_co%ncband) then
      call die("You must provide one more conduction band in WFN_co in order to assess degeneracy.")
    endif
    broken_degeneracy = .false.
    do jj = 1, kp_co%nspin
      do ii = 1, kp_co%nrk
        if(kp_co%ifmax(ii, jj) - xct%nvb_co > 0) then
          ! no need to compare against band 0 if all valence are included
          if(abs(kp_co%elda(kp_co%ifmax(ii, jj) - xct%nvb_co + 1, ii, jj) &
            - kp_co%elda(kp_co%ifmax(ii, jj) - xct%nvb_co, ii, jj)) .lt. TOL_Degeneracy) then
            broken_degeneracy = .true.
          endif
        endif
      enddo
    enddo

    if(broken_degeneracy) then
      if(xct%degeneracy_check_override) then
        write(0,'(a)') &
          "WARNING: Selected number of valence bands breaks degenerate subspace in WFN_co. " // &
          "Run degeneracy_check.x for allowable numbers."
        write(0,*)
      else
        write(0,'(a)') &
          "Run degeneracy_check.x for allowable numbers, or use keyword " // &
          "degeneracy_check_override to run anyway (at your peril!)."
        call die("Selected number of valence bands breaks degenerate subspace in WFN_co.")
      endif
    endif
    
    broken_degeneracy = .false.
    do jj = 1, kp_co%nspin
      do ii = 1, kp_co%nrk
        if(abs(kp_co%elda(kp_co%ifmax(ii, jj) + xct%ncb_co, ii, jj) &
          - kp_co%elda(kp_co%ifmax(ii, jj) + xct%ncb_co + 1, ii, jj)) .lt. TOL_Degeneracy) then
          broken_degeneracy = .true.
        endif
      enddo
    enddo

    if(broken_degeneracy) then
      if(xct%degeneracy_check_override) then
        write(0,'(a)') &
          "WARNING: Selected number of conduction bands breaks degenerate subspace in WFN_co. " // &
          "Run degeneracy_check.x for allowable numbers."
        write(0,*)
      else
        write(0,'(a)') &
          "Run degeneracy_check.x for allowable numbers, or use keyword " // &
          "degeneracy_check_override to run anyway (at your peril!)."
        call die("Selected number of conduction bands breaks degenerate subspace in WFN_co.")
      endif
    endif
  endif

  SAFE_DEALLOCATE_P(kp_co%elda)

!-----------------------------------------------------------------------
!     Read k-points from file kpoints_co (if it exists) or from WFN_co
!     Array indxk has the same meaning as in input
!     Here indxk(j) is the index in WFN_co of the j-th point of kg%r,
!     or 0 when that point is not present in WFN_co.

  if (xct%skpt.eq.1) then
    if (peinf%inode.eq.0) then
      call open_file(9,file='kpoints_co',form='formatted',status='old')
      read(9,*) kg%nr
      SAFE_ALLOCATE(kg%r, (3,kg%nr))
      do ii=1,kg%nr
        ! each line holds three components and a common divisor
        read(9,*) (kg%r(jj,ii),jj=1,3),div
        kg%r(:,ii) = kg%r(:,ii)/div
      enddo
      call close_file(9)
    endif ! node 0
#ifdef MPI
    call MPI_BCAST(kg%nr,1,MPI_INTEGER,0,MPI_COMM_WORLD,mpierr)
    if (peinf%inode.ne.0) then
      SAFE_ALLOCATE(kg%r, (3,kg%nr))
    endif
    call MPI_BCAST(kg%r,3*kg%nr,MPI_REAL_DP,0,MPI_COMM_WORLD,mpierr)
#endif
    tol = TOL_Small
    SAFE_ALLOCATE(indxk, (kg%nr))
    indxk=0
    do jj=1,kg%nr
      do ii=1,kp_co%nrk
        kt(:) = kg%r(:,jj) - kp_co%rk(:,ii)
        if (all(abs(kt(1:3)).lt.tol)) then
          ! a second match only triggers a warning; the last match wins
          if (indxk(jj).ne.0) then
            if (peinf%inode.eq.0) write(6,996) jj,indxk(jj),kg%r(:,jj)
          endif
          indxk(jj)=ii
        endif
      enddo
      if (indxk(jj).eq.0) then
        if (peinf%inode.eq.0) write(0,995) kg%r(:,jj)
      endif
    enddo
  else
    ! no kpoints_co file: use the k-points of WFN_co in file order
    kg%nr=kp_co%nrk
    SAFE_ALLOCATE(kg%r, (3,kg%nr))
    kg%r(1:3,1:kg%nr)=kp_co%rk(1:3,1:kp_co%nrk)
    SAFE_ALLOCATE(indxk, (kg%nr))
    do ii=1,kg%nr
      indxk(ii)=ii
    enddo
  endif
#ifdef MPI
  call MPI_BCAST(indxk,kg%nr,MPI_INTEGER,0,MPI_COMM_WORLD,mpierr)
#endif
996 format(1x,'WARNING: Multiple definition of k-point',2i4,3f10.6)
995 format(1x,'WARNING: Could not find k-point',3f10.6,1x,'in WFN_co')
  
!-----------------------------------------------------------------------
! Initialization of distributed wavefunctions

  if (xct%iwriteint.eq.1) then
    
    distgwfco%ngm=kp_co%ngkmax
    distgwfco%nk=kg%nr
    distgwfco%ns=kp_co%nspin
    distgwfco%nv=xct%nvb_co
    distgwfco%nc=xct%ncb_co
    
    ! Deal the ngkmax plane-wave slots out one at a time, starting with the
    ! last processor and going down, so that the counts differ by at most one.
    SAFE_ALLOCATE(dist, (peinf%npes))
    dist=0
    jj=peinf%npes
    do ii=1,kp_co%ngkmax
      dist(jj)=dist(jj)+1
      jj=jj-1
      if (jj.eq.0) jj=peinf%npes
    enddo
    ! ngl: number of slots held by this processor
    distgwfco%ngl=dist(peinf%inode+1)
    ! tgl: number of slots held by all lower-ranked processors (local offset)
    jj=0
    do ii=1,peinf%inode
      jj=jj+dist(ii)
    enddo
    distgwfco%tgl=jj
    SAFE_DEALLOCATE(dist)
    
    SAFE_ALLOCATE(distgwfco%ng, (distgwfco%nk))
    SAFE_ALLOCATE(distgwfco%isort, (distgwfco%ngl,distgwfco%nk))
    SAFE_ALLOCATE(distgwfco%zv, (distgwfco%ngl,distgwfco%nv,distgwfco%ns,distgwfco%nk))
    SAFE_ALLOCATE(distgwfco%zc, (distgwfco%ngl,distgwfco%nc,distgwfco%ns,distgwfco%nk))
    
    distgwfco%ng(:)=0
    distgwfco%isort(:,:)=0
    distgwfco%zv(:,:,:,:)=ZERO
    distgwfco%zc(:,:,:,:)=ZERO
    
  endif ! xct%iwriteint.eq.1
  
!-----------------------------------------------------------------------
!     Generate full Brillouin zone from irreducible wedge, rk -> fk

  if (flagbz.eq.1) then
    call fullbz(crys,syms,kg,1,skip_checkbz,wigner_seitz=.true.,paranoid=.true.)
  else
    call fullbz(crys,syms,kg,syms%ntran,skip_checkbz,wigner_seitz=.true.,paranoid=.true.)
  endif
  tmpfn='WFN_co'
  if (.not. skip_checkbz) then
    call checkbz(kg%nf,kg%f,kp_co%kgrid,kp_co%shift,crys%bdot, &
      tmpfn,'k',.true.,xct%freplacebz,xct%fwritebz)
  endif
  
  if (flagbz.eq.0.and.peinf%inode.eq.0) write(6,801)
  if (flagbz.eq.1.and.peinf%inode.eq.0) write(6,802)
801 format(1x,'Using symmetries to expand the coarse-grid sampling',/)
802 format(1x,'No symmetries used in the coarse-grid sampling',/)

  if (xct%nkpt_co.ne.kg%nf) then
   if(peinf%inode == 0) write(0,994) xct%nkpt_co,kg%nf
994 format('The given number of points in the coarse grid (',i4, &
      ') does not match the number of points in file WFN_co after unfolding (',i4,').')
    call die('If you are sure WFN_co is correct, please change the .inp file and try again.', only_root_writes = .true.)
  endif

!-----------------------------------------------------------------------
! FHJ: update number of dimensions if the coarse grid has more k-pts than the
! fine one.
! A direction counts as periodic when its kgrid entry is larger than 1. The
! count stays 0, and nothing is updated, if any kgrid entry is zero.

  new_idimensions = 0
  if (all(kp_co%kgrid(:)/=0)) then
    do ikb = 1, 3
      new_iperiodic(ikb) = 0
      if (kp_co%kgrid(ikb) > 1) then
        new_idimensions = new_idimensions + 1
        new_iperiodic(ikb) = 1
      end if
    enddo
  endif
  if (new_idimensions > xct%idimensions) then
    xct%idimensions = new_idimensions
    xct%iperiodic(:) = new_iperiodic(:)
    if (peinf%inode .eq. 0) then
      write(6,'(1x,a)') 'Updating number of dimensions using the coarse grid.'
      write(6,'(1x,a,i1,a)') 'A ',xct%idimensions,'-D interpolation algorithm will be employed.'
    endif
  endif

!-----------------------------------------------------------------------
! Read the wavefunctions and distribute or write to temp files

  if (xct%iwriteint.eq.0) then
    
    write(filenamec,'(a)') 'INT_CWFN_CO'
    itpc=126
    write(filenamev,'(a)') 'INT_VWFN_CO'
    itpv=127
    
    if (peinf%inode.eq.0) then
      call open_file(itpc,file=filenamec,form='unformatted',status='replace')
      call open_file(itpv,file=filenamev,form='unformatted',status='replace')
    endif ! node 0
    
  endif ! xct%iwriteint.eq.0
  
  SAFE_ALLOCATE(wfnv%isort, (gvec%ng))
  wfnv%nband=xct%nvb_co
  wfnv%nspin=kp_co%nspin
  wfnc%nband=xct%ncb_co
  wfnc%nspin=kp_co%nspin
  
  do irk=1,kp_co%nrk
    ! irks: position of this WFN_co k-point in kg%r, or 0 if it is not wanted
    irks = 0
    do ii=1,kg%nr
      if (indxk(ii) == irk) then
        irks=ii
        exit
      endif
    enddo

    ! The G-vector record is read for every k-point, wanted or not, so that
    ! the file stays positioned correctly.
    SAFE_ALLOCATE(gvec_kpt%k, (3, kp_co%ngk(irk)))
    call read_binary_gvectors(25, kp_co%ngk(irk), kp_co%ngk(irk), gvec_kpt%k)

    SAFE_ALLOCATE(cg, (kp_co%ngk(irk),kp_co%nspin))
    if(irks > 0) then
      ! Only the first ngk(irk) entries are filled below, but the temp-file
      ! writes further down output all gvec%ng entries. Clear the array so
      ! the tail holds zeros instead of undefined or stale values.
      wfnv%isort(:) = 0
      do ii = 1, kp_co%ngk(irk)
        call findvector(wfnv%isort(ii), gvec_kpt%k(1, ii), gvec_kpt%k(2, ii), gvec_kpt%k(3, ii), gvec)
        if (wfnv%isort(ii) == 0) call die('Could not find g-vector.')
      enddo
      
      wfnv%ng=kp_co%ngk(irk)
      wfnc%ng=kp_co%ngk(irk)
      if(peinf%inode == 0) then
        SAFE_ALLOCATE(wfnv%cg, (wfnv%ng,wfnv%nband,wfnv%nspin))
        SAFE_ALLOCATE(wfnc%cg, (wfnc%ng,wfnc%nband,wfnc%nspin))
        SAFE_ALLOCATE(cgarray, (kp_co%ngk(irk)))
      endif
    endif
    ! The per-k-point G-vector list is not needed past this point. Release it
    ! here, otherwise the allocation at the top of the loop leaks the pointer
    ! target of every previous iteration.
    SAFE_DEALLOCATE_P(gvec_kpt%k)

! Loop over the bands

    do ii=1,kp_co%mnband

! Read planewave coefficients for band ii
      call read_binary_data(25, kp_co%ngk(irk), kp_co%ngk(irk), kp_co%nspin, cg)

      ! unwanted k-point: the record is consumed but its contents are dropped
      if(irks == 0) cycle
        
      if(peinf%inode == 0) then
        do is=1, kp_co%nspin
          ! keep only the nvb_co highest valence and ncb_co lowest conduction bands
          if (ii .gt. (kp_co%ifmax(irk,is)-xct%nvb_co) .and. ii .le. (kp_co%ifmax(irk,is)+xct%ncb_co)) then
            
            do kk=1, kp_co%ngk(irk)
              cgarray(kk)=cg(kk, is)
            end do
#ifdef VERBOSE
            write(*,'(a, 3i7, 2(f18.13))') 'input_co', irks, ii, is, cgarray(1)
#endif
            call checknorm('WFN_co',ii,irks,is,kp_co%ngk(irk),cgarray)

            ! valence bands are stored counting downwards from band ifmax
            ! (index 1 is band ifmax); conduction bands count upwards from
            ! band ifmax+1.
            if ((ii.le.kp_co%ifmax(irk,is)).and. &
              (ii.gt.kp_co%ifmax(irk,is)-xct%nvb_co)) &
              wfnv%cg(1:wfnv%ng,kp_co%ifmax(irk,is)-ii+1,is)=cgarray
            if ((ii.gt.kp_co%ifmax(irk,is)).and. &
              (ii.le.kp_co%ifmax(irk,is)+xct%ncb_co)) &
              wfnc%cg(1:wfnc%ng,ii-kp_co%ifmax(irk,is),is)=cgarray
          end if
        end do
      endif
        
    enddo ! ii (loop over bands)
    
    SAFE_DEALLOCATE(cg)
    if(peinf%inode == 0) then
      SAFE_DEALLOCATE(cgarray)
    endif

    ! Nothing was stored for a k-point that is not in kg%r: wfnv%cg and
    ! wfnc%cg were never allocated and irks = 0 is not a valid index into
    ! the distgwfco arrays. indxk is identical on all processors, so every
    ! rank skips the same k-points and the collectives below stay matched.
    if (irks == 0) cycle

    if (xct%iwriteint.eq.1) then

#ifdef MPI
      if (peinf%inode.ne.0) then
        SAFE_ALLOCATE(wfnv%cg, (wfnv%ng,wfnv%nband,wfnv%nspin))
        SAFE_ALLOCATE(wfnc%cg, (wfnc%ng,wfnc%nband,wfnc%nspin))
      endif
      call MPI_BCAST(wfnv%cg(1,1,1),wfnv%ng*wfnv%nband*wfnv%nspin, &
        MPI_SCALAR,0,MPI_COMM_WORLD,mpierr)
      call MPI_BCAST(wfnc%cg(1,1,1),wfnc%ng*wfnc%nband*wfnc%nspin, &
        MPI_SCALAR,0,MPI_COMM_WORLD,mpierr)
#endif

      ! Copy the local slice (global indices tgl+1 .. tgl+ngl) of the sort
      ! index and of the coefficients. Slots beyond the number of plane waves
      ! of this k-point keep the zeros set at initialization.
      distgwfco%ng(irks)=wfnv%ng
      do ii=1,distgwfco%ngl
        if (ii+distgwfco%tgl.le.wfnv%ng) &
          distgwfco%isort(ii,irks)=wfnv%isort(ii+distgwfco%tgl)
      enddo
      do kk=1,distgwfco%ns
        do jj=1,distgwfco%nv
          do ii=1,distgwfco%ngl
            if (ii+distgwfco%tgl.le.wfnv%ng) &
              distgwfco%zv(ii,jj,kk,irks)=wfnv%cg(ii+distgwfco%tgl,jj,kk)
          enddo
        enddo
      enddo
      do kk=1,distgwfco%ns
        do jj=1,distgwfco%nc
          do ii=1,distgwfco%ngl
            if (ii+distgwfco%tgl.le.wfnv%ng) &
              distgwfco%zc(ii,jj,kk,irks)=wfnc%cg(ii+distgwfco%tgl,jj,kk)
          enddo
        enddo
      enddo

    endif ! xct%iwriteint.eq.1

    if (xct%iwriteint.eq.0) then
      
      ! Two records per k-point and file: sizes first, then the sort index
      ! (shared by valence and conduction) followed by the coefficients.
      if (peinf%inode.eq.0) then
        write(itpv) irks,wfnv%ng,wfnv%nband,wfnv%nspin
        write(itpv) (wfnv%isort(ii),ii=1,gvec%ng), &
          (((wfnv%cg(ii,jj,kk),ii=1,wfnv%ng),jj=1,wfnv%nband),kk=1,wfnv%nspin)
        write(itpc) irks,wfnc%ng,wfnc%nband,wfnc%nspin
        write(itpc) (wfnv%isort(ii),ii=1,gvec%ng), &
          (((wfnc%cg(ii,jj,kk),ii=1,wfnc%ng),jj=1,wfnc%nband),kk=1,wfnc%nspin)
      endif ! node 0

    endif ! xct%iwriteint.eq.0

    if (peinf%inode.eq.0 .or. xct%iwriteint == 1) then
      SAFE_DEALLOCATE_P(wfnv%cg)
      SAFE_DEALLOCATE_P(wfnc%cg)
    endif ! node 0

  enddo ! irk (loop over k-points)

  SAFE_DEALLOCATE_P(wfnv%isort)    
  SAFE_DEALLOCATE(indxk)
  
  if (peinf%inode.eq.0) then
    write(6,301)
    write(6,302) kg%nr
    write(6,303) ((kg%r(ii,jj),ii=1,3),jj=1,kg%nr)
    write(6,304) kg%nf,kg%sz
301 format(/,1x,'Finished reading coarse-grid wavefunctions from tape WFN_co',/)
302 format(6x,'nrk =',i4,/)
303 format(6x,3f10.6)
304 format(/,6x,'nfk =',i6,1x,'ksz =',f10.6,/)
    if (xct%iwriteint.eq.0) then
      call close_file(itpv)
      call close_file(itpc)
    endif ! xct%iwriteint.eq.0
    call close_file(25)
  endif ! node 0

  SAFE_DEALLOCATE_P(kp_co%rk)
  SAFE_DEALLOCATE_P(kp_co%ifmin)
  SAFE_DEALLOCATE_P(kp_co%ifmax)
  SAFE_DEALLOCATE_P(kp_co%el)
  
  ! when the root wrote temp files, hold the other processors back until
  ! the files have been written and closed
  if(xct%iwriteint == 0) then
#ifdef MPI
    call MPI_Barrier(MPI_COMM_WORLD, mpierr)
#endif
  endif

  POP_SUB(input_co)

  return
end subroutine input_co
