#!/usr/bin/python3

r'''Facilitate interactive human-in-the-loop corresponding-feature picking

And use features to solve for extrinsics in a stereo pair. This needs a clearer,
more general purpose. Needs to be documented, and made to work reliably and
tested


talk about extrinsics. I leave cam0 where it is, and move cam1.

default geometry selectable. baseline stable or fixed at 1.0?

We always leave the extrinsics of camera0, and move camera1 instead. By default,
we move camera1 as dictated by the solve, leaving the baseline the same if the
solve is unique up-to-scale only.

'''

import sys
import argparse
import re
import os

def parse_args(argv = None):
    r'''Parse the commandline

    If argv is None (the default) we parse sys.argv, as argparse normally
    does. An explicit list may be passed instead (useful for testing). On
    usage errors we report to stderr and exit, per the usual argparse
    behavior. Returns the parsed argparse namespace

    '''

    def positive_int(string):
        # argparse type-validator: accept strictly-positive integers only.
        # int() itself rejects non-integer text such as "1.5" or "abc", so
        # ValueError is the only failure mode we need to catch here
        try:
            value = int(string)
        except ValueError:
            raise argparse.ArgumentTypeError("argument MUST be a positive integer. Got '{}'".format(string))
        if value <= 0:
            raise argparse.ArgumentTypeError("argument MUST be a positive integer. Got '{}'".format(string))
        return value


    parser = \
        argparse.ArgumentParser(description = __doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('--valid-intrinsics-region',
                        action='store_true',
                        help='''If given, annotate the image with its
                        valid-intrinsics region''')
    parser.add_argument('--range-estimate',
                        type=float,
                        default=10.,
                        help='''Initial guess for the range of all picked
                        points. Defaults to 10m''')
    parser.add_argument('--template-size',
                        type=positive_int,
                        nargs=2,
                        default = (13,13),
                        help='''The size of the template used for feature
                        matching, in pixel coordinates of the second image. Two
                        arguments are required: width height. This is passed
                        directly to mrcal.match_feature(). We default to
                        13x13''')
    parser.add_argument('--search-radius',
                        type=positive_int,
                        default = 20,
                        help='''How far the feature-matching routine should
                        search, in pixel coordinates of the second image. This
                        should be larger if the nominal range estimate is poor,
                        especially, at near ranges. This is passed directly to
                        mrcal.match_feature(). We default to 20 pixels''')
    parser.add_argument('--initial-correspondences', '--correspondences',
                        help='''Start with the given correspondences vnl. If
                        omitted, we don't start with any features at all. The
                        given vnl must have the columns: x0 y0 x1 y1. Exclusive
                        with --initial-features''')
    parser.add_argument('--initial-features',
                        help='''Start with the given features vnl. If omitted,
                        we don't start with any features at all. At startup we
                        try to find matches in image1 for each of these
                        features; any failed matches are discarded. The given
                        vnl must have the columns: x0 y0. Exclusive with
                        --initial-correspondences''')
    parser.add_argument('--rt01',
                        help='''We always leave the extrinsics of camera0, and
                        move camera1, as dictated by the solve; we leave the
                        baseline the same if the solve is unique up-to-scale
                        only. By default we start with the geometry given by the
                        models. If --rt01 is given, we use this relative initial
                        geometry instead. This is a comma-separated list of 6
                        numbers''')
    ######## image pre-filtering; same as in mrcal-stereo; please consolidate
    parser.add_argument('--equalization',
                        choices=('clahe', 'fieldscale','stretch'),
                        help='''The equalization method to use for the input
                        images. "fieldscale" requires mrcam to be installed, and
                        can only operate on uint16 images''')

    parser.add_argument('models',
                        type=str,
                        nargs = 2,
                        help='''Camera models representing cameras used to
                        capture the images. Intrinsics only are used. A nominal
                        stereo geometry with unit baseline is assumed. Given
                        models are the left, right cameras''')
    parser.add_argument('images',
                        type=str,
                        nargs=2,
                        help='''The images to use for the matching''')

    args = parser.parse_args(argv)

    # argparse cannot express "these two options are mutually exclusive" when
    # both take values and live in different groups; check by hand
    if args.initial_correspondences is not None and \
       args.initial_features        is not None:
        print("--initial-correspondences and --initial-features are mutually exclusive",
              file=sys.stderr)
        sys.exit(1)
    return args

args = parse_args()

# arg-parsing is done before the imports so that --help works without building
# stuff, so that I can generate the manpages and README




from fltk               import *
from Fl_Gl_Image_Widget import *

import numpy as np
import numpysane as nps
import mrcal
import pyopengv
import vnlog

# Optional, equalization-specific dependencies are imported on demand
if args.equalization == 'fieldscale':
    try:
        import mrcam
    except Exception:
        # "except Exception", not a bare "except": we want the friendly
        # diagnostic for any import failure, but KeyboardInterrupt and
        # SystemExit should still propagate
        print("ERROR: the 'fieldscale' equalization method requires mrcam, but it could not be imported", file=sys.stderr)
        sys.exit(1)
if args.equalization == 'clahe':
    import cv2


def get_q01_estimate(*,
                     q,
                     index):
    r'''Estimate the corresponding pixel in the other camera

Given a pixel observation q in camera "index" (0 or 1), we unproject it, place
the point at the nominal --range-estimate distance, transform it into the
other camera using the current context['Rt01'] estimate, and project it there.

Returns a (q0,q1) pair: the given q in slot "index", and the estimate in the
other slot

    '''
    i_other = 1 - index

    # Transform mapping points in camera "index" to the OTHER camera
    if index == 0: Rt_other_this = mrcal.invert_Rt(context['Rt01'])
    else:          Rt_other_this = context['Rt01']

    v_this  = mrcal.unproject(q, *models[index].intrinsics(),
                              normalize = True)
    p_other = mrcal.transform_point_Rt(Rt_other_this,
                                       v_this * args.range_estimate)
    q_other = mrcal.project(p_other, *models[i_other].intrinsics())

    if index == 0:
        return nps.cat(q, q_other)
    return nps.cat(q_other, q)


def solve_R01__kneip(v0,v1,
                     **cookie):
    r'''Rotation from corresponding-feature observations: Kneip's eigensolver

Kneip describes a method to compute the rotation:

  L. Kneip, R. Siegwart, M. Pollefeys, "Finding the Exact Rotation Between Two
  Images Independently of the Translation", Proc. of The European Conference on
  Computer Vision (ECCV), Florence, Italy. October 2012.

  L. Kneip, S. Lynen, "Direct Optimization of Frame-to-Frame Rotation", Proc. of
  The International Conference on Computer Vision (ICCV), Sydney, Australia.
  December 2013.
'''
    # solve_Rt10() forwards its cookie (usually just t01_fixed=None) to every
    # rotation solver. This solver cannot use the cookie, so I barf only if a
    # MEANINGFUL (non-None) value was given. Checking "if cookie:" would raise
    # even for the t01_fixed=None placeholder that solve_Rt10() always passes,
    # breaking the default solve path
    if any(value is not None for value in cookie.values()):
        raise Exception("This function does not use the cookie")
    return \
        pyopengv.relative_pose_eigensolver(v0,v1,
                                           # seed
                                           mrcal.identity_R())


def solve_R01__assume_infinity(v0,v1,
                               **cookie):
    r'''Rotation from corresponding-feature observations: assume features at infinity

If every observed feature is infinitely far away, the rotation alone explains
the observations, and a procrustes fit recovers it directly. Useful for
testing, or maybe for seeding other solvers.

Any cookie keyword arguments are accepted and ignored

    '''
    return mrcal.align_procrustes_vectors_R01(v0,v1)


def solve_R01__optimize(v0,v1,
                        *,
                        # cookie
                        t01_fixed = None):
    r'''Rotation from corresponding-feature observations: direct optimization

This is for experimenting. solve_R01__kneip() should work well.

If t01_fixed is given, we optimize the rotation assuming that known
translation; otherwise we optimize the up-to-scale geometric-triangulation
error. Returns the optimized rotation matrix R01
'''
    import scipy.optimize

    # Alternative cost: the epipolar-plane normal vectors should be coplanar
    # if R01 is right. Not currently selected below; kept for experimentation
    def cost__coplanarity_normal_vectors(r01):
        n = np.cross(v0,
                     mrcal.rotate_point_r(r01,v1))
        l,v = mrcal.sorted_eig(nps.matmult(n.T,n))
        return l[0]

    def cost__triangulation_geometric(r01):
        # Up-to-scale cost: the predicted optimal triangulation error for the
        # best-fit t01 implied by this rotation
        R01 = mrcal.R_from_r(r01)
        t01,Epredicted,A,B = t01_from_R01__geometric(v0,v1,
                                                     R01 = R01)
        return Epredicted

    def cost__triangulation_have_t01(r01):
        # Cost for a KNOWN translation: triangulate with the candidate
        # rotation, and score the agreement of the triangulated points with
        # the v0 observation rays
        R01 = mrcal.R_from_r(r01)
        p0 = \
            mrcal.triangulate_leecivera_wmid2(v0, v1,
                                              v_are_local = True,
                                              Rt01        = nps.glue(R01,
                                                                     t01_fixed,
                                                                     axis = -2))

        # Have triangulated points p0. Divergent or at-infinity triangulations
        # are reported as p0=(0,0,0).
        #
        # I want to minimize sum(th_err). It's analogous to maximize
        # sum(cos(th_err)) ~ sum(inner(v0, p0/mag(p0))). I minimize
        # -sum(inner(...)) instead.
        #
        # For the divergent points I assume we're at infinity, and use inner(v0,
        # R01 v1)/2. The /2 is because the convergent cost is from v0 to the
        # midpoint, and I want to keep the same scaling for the divergent cost.
        # This inner(v0,v1) logic will treat noise incorrectly (just before
        # infinity I correctly only respond to yaw noise, but past infinity I
        # would respond to all noise). This is close-enough for now.

        mask = nps.norm2(p0) > 0

        if np.any(mask):
            cost_convergent = \
                -np.sum( nps.inner(p0[mask]/nps.dummy(nps.mag(p0[mask]),-1), v0[mask]) )
        else:
            cost_convergent = 0

        if np.any(~mask):
            cost_divergent = \
                -np.sum( nps.inner(v0[~mask], mrcal.rotate_point_R(R01, v1[~mask])) ) / 2
        else:
            cost_divergent = 0

        return cost_convergent + cost_divergent

    # Optimize in the Rodrigues-vector space, seeded at the identity rotation
    if t01_fixed is None:
        res = \
            scipy.optimize.minimize( cost__triangulation_geometric,
                                     x0 = mrcal.identity_r() )
    else:
        res = \
            scipy.optimize.minimize( cost__triangulation_have_t01,
                                     x0 = mrcal.identity_r() )

    R01 = mrcal.R_from_r(res.x)

    return R01


def t01_from_R01__geometric(v0,v1,R01):
    r'''Translation from corresponding-feature observations: Dima's method

Given corresponding feature observations from two cameras we compute the
transform between the cameras. This is unique up-to-scale, so the reported
translation has length = 1.

This method assumes a rotation is already available. Given this rotation this
method computes the translation. This may or may not be any better than
solve_Rt01_given_R01__epipolar_plane_normality(). That needs evaluation.

The derivation of this method follows. I use the geometric triangulation
expression derived here:

  https://github.com/dkogan/mrcal/blob/8be76fc28278f8396c0d3b07dcaada2928f1aae0/triangulation.cc#L112

I assume that I'm triangulating normalized v0,v1 both expressed in cam-0
coordinates. And I have a t01 translation that I call "t" from here on. This is
unique only up-to-scale, so I assume that norm2(t) = 1. The geometric
triangulation from the above link says that

  [k0] = 1/(v0.v0 v1.v1 -(v0.v1)**2) [ v1.v1   v0.v1][ t01.v0]
  [k1]                               [ v0.v1   v0.v0][-t01.v1]

  The midpoint p is

  p = (k0 v0 + t01 + k1 v1)/2

I assume that v are normalized and I represent k as a vector. I also define

  c = inner(v0,v1)

So

  k = 1/(1 - c^2) [1 c] [ v0t] t
                  [c 1] [-v1t]

I define

  A = 1/(1 - c^2) [1 c]     This is a 2x2 array
                  [c 1]

  B = [ v0t]                This is a 2x3 array
      [-v1t]

  V = [v0 v1]               This is a 3x2 array

Note that none of A,B,V depend on t.

So

  k = A B t

Then

  p = (k0 v0 + t01 + k1 v1)/2
    = (V k + t)/2
    = (V A B t + t)/2
    = (I + V A B) t/2

Each triangulated error is

  e = mag(p - k0 v0)

I split A into its rows

  A = [ a0t ]
      [ a1t ]

Then

  e = p - k0 v0
    = (I + V A B) t/2 - v0 a0t B t
    = I t/2 + V A B t/2 - v0 a0t B t
    = I t/2 + v0 a0t B t/2 + v1 a1t B t/2 - v0 a0t B t
    = I t/2 - v0 a0t B t/2 + v1 a1t B t/2
    = ((I - v0 a0t B + v1 a1t B) t) / 2
    = ((I + (- v0 a0t + v1 a1t) B) t) / 2
    = ((I - Bt A B) t) / 2

I define a joint error function I'm optimizing as the sum of all the individual
triangulation errors:

  E = sum(norm2(e_i))

Each component is

  norm2(e) = 1/4 tt (I - Bt A B)t (I - Bt A B) t
           = 1/4 tt (I - 2 Bt A B + Bt A B Bt A B ) t

  B Bt = [1  -c]
         [-c  1]

  B Bt A = 1/(1 - c^2) [1  -c] [1 c]
                       [-c  1] [c 1]
         = 1/(1 - c^2) [1-c^2  0     ]
                       [0      1-c^2 ]
         = I

-> norm2(e) = 1/4 tt (I - 2 Bt A B + Bt A B) t
            = 1/4 tt (I - Bt A B) t
            = 1/4 - 1/4 tt Bt A B t

So

    E = N/4 - 1/4 tt sum(Bt A B) t

I let

    M  = sum(Bt A B)
    M  = sum(Mi)
    Mi = Bt A B

So

    E = N/4 - 1/4 tt M t
      = N/4 - 1/4 lambda

So to minimize E I find t that is the eigenvector of M that corresponds to its
largest eigenvalue lambda. Furthermore, lambda depends on the rotation. If I
couldn't estimate the rotation from far-away features I can solve the
eigenvalue-optimization problem to maximize lambda.

More simplification:

    Mi = Bt A B = [ v0  -v1 ] A [ v0t]
                                [-v1t]
       = 1/(1-c^2) [ v0 - v1 c    v0 c - v1] [ v0t]
                                             [-v1t]
       = 1/(1-c^2) ((v0 - v1 c) v0t - (v0 c - v1) v1t)

    c  = v0t v1 ->
    F0 = v0 v0t
    F1 = v1 v1t

    -> Mi = 1/(1-c^2) (v0 v0t - v1 v1t v0 v0t + v1 v1t - v0 v0t v1 v1t)
          = 1/(1-c^2) (F0 + F1 - (F1 F0 + F0 F1))
          = (F0 - F1)^2 / (1 - c^2)
          = ((F0 - F1)/s)^2

    where s = mag(cross(v0,v1))


    tt M t = sum( norm2((F0i - F1i)/si t) )

    Let Di = (F0i - F1i)/si

    I want to maximize sum( norm2(Di t) )


    (F0 - F1)/s = (v0 v0t - v1 v1t) / mag(cross(v0,v1))
                ~ (v0 v0t - R v1 v1t Rt) / mag(cross(v0,R v1))

experiments:

          = 1/(1-c^2) (F0 + F1 - (F1 F0 + F0 F1))
          = 1/(1-c^2) (v0 v0t + v1 v1t - (c v0 v1t + c v1 v0t))

ARGUMENTS

- v0, v1: shape (N,3) normalized observation vectors, each in its own camera's
  local coordinate system (v1 is rotated into camera-0 coordinates internally)

- R01: (3,3) rotation mapping camera-1 vectors into camera-0 coordinates

RETURNS a tuple (t01, E, A, B): the unit-length translation, the predicted
optimal cost E = N/4 - lambda_max/4, and the per-feature A (shape (N,2,2)) and
B (shape (N,2,3)) arrays from the derivation above

    '''
    # Express BOTH observation sets in camera-0 coordinates; the derivation
    # above assumes this
    v1 = mrcal.rotate_point_R(R01,v1)

    # shape (N,)
    c = nps.inner(v0,v1)

    N = len(c)

    # shape (N,2,2)
    # A = 1/(1-c^2) [1 c; c 1] from the derivation above
    A = np.ones((N,2,2), dtype=float)
    A[:,0,1] = c
    A[:,1,0] = c
    A /= nps.mv(1. - c*c, -1,-3)

    # shape (N,2,3)
    # B = [v0t; -v1t] from the derivation above
    B = np.empty((N,2,3), dtype=float)
    B[:,0,:] =  v0
    B[:,1,:] = -v1

    # shape (3,3)
    # M = sum(Bt A B); minimizing E means maximizing tt M t
    M = np.sum( nps.matmult(nps.transpose(B), A, B), axis = 0 )

    l,v = mrcal.sorted_eig(M)

    # The answer is the eigenvector corresponding to the biggest eigenvalue
    t01 = v[:,-1]

    # Predicted optimal cost: E = N/4 - lambda_max/4
    E = N/4 - l[-1]/4

    return t01,E,A,B


def solve_Rt01_given_R01__geometric(v0,v1,R01,
                                    *,
                                    # cookie. Accepted for interface
                                    # compatibility with the other
                                    # solve_Rt10_from_R01 functions
                                    # (solve_Rt10() passes these
                                    # unconditionally); ignored here: the
                                    # translation is reported up-to-scale
                                    t01_fixed     = None,
                                    baseline_want = None):
    r'''Estimate pose from rotation and feature observations

See the docstring for t01_from_R01__geometric()

Returns a tuple (Rt01, mask_divergent, N_divergent): the up-to-scale transform
(mag(t01) = 1), a boolean mask of the diverging triangulations, and their
count. The cookie arguments are accepted, but not used
'''
    t01,Epredicted,A,B = t01_from_R01__geometric(v0,v1,R01)

    # Almost done. I want either t or -t. The wrong one will produce
    # mostly triangulations behind me.
    # k = A B t are the per-feature ray scale factors from the derivation
    k = nps.matmult(A,B, nps.transpose(t01))[..., 0]

    # A triangulation diverges if either ray scale factor is non-positive.
    # "+" on boolean numpy arrays is a logical OR
    mask_divergent_t    = (k[:,0] <= 0) + (k[:,1] <= 0)
    mask_divergent_negt = (k[:,0] >= 0) + (k[:,1] >= 0)
    N_divergent_t    = np.count_nonzero( mask_divergent_t )
    N_divergent_negt = np.count_nonzero( mask_divergent_negt )

    if False:
        # Epipolar plane normals. Should be planar, if R01 is right
        n = np.cross(v0,mrcal.rotate_point_R(R01,v1))
        l,vv = mrcal.sorted_eig(nps.matmult(n.T,n))
        print(f"epipolar normal t,mine: {vv[:,0]=} {t01=}")

    # Pick the sign of t producing fewer divergences; the divergences of the
    # winning candidate are reported so the caller can outlier-reject them
    if N_divergent_t < N_divergent_negt:
        return ( nps.glue(R01, t01, axis=-2),
                 mask_divergent_t,
                 N_divergent_t )
    else:
        return ( nps.glue(R01, -t01, axis=-2),
                 mask_divergent_negt,
                 N_divergent_negt )


def solve_Rt01_given_R01__epipolar_plane_normality(v0,v1,R01,
                                                   *,
                                                   # cookie. Accepted for
                                                   # interface compatibility
                                                   # with the other
                                                   # solve_Rt10_from_R01
                                                   # functions (solve_Rt10()
                                                   # passes these
                                                   # unconditionally); ignored
                                                   # here
                                                   t01_fixed     = None,
                                                   baseline_want = None):
    r'''Translation from corresponding-feature observations: Kneip's method

Given corresponding feature observations from two cameras we compute the
transform between the cameras. This is unique up-to-scale, so the reported
translation has length = 1.

This method assumes a rotation is already available (from Kneip's eigensolver,
for instance). Given this rotation this method computes the translation. This
may or may not be any better than solve_Rt01_given_R01__geometric(). That needs
evaluation.

Kneip describes a method to compute the ROTATION:

  L. Kneip, R. Siegwart, M. Pollefeys, "Finding the Exact Rotation Between Two
  Images Independently of the Translation", Proc. of The European Conference on
  Computer Vision (ECCV), Florence, Italy. October 2012.

  L. Kneip, S. Lynen, "Direct Optimization of Frame-to-Frame Rotation", Proc. of
  The International Conference on Computer Vision (ICCV), Sydney, Australia.
  December 2013.

The rotation is the hard part, but we do still need a translation. opengv should
do this for us, and I think its C++ API does, but its Python bindings are
lacking. So I compute t myself for now.

Returns a tuple (Rt01, mask_divergent, N_divergent), like
solve_Rt01_given_R01__geometric()

    '''

    Rt01 = np.zeros((4,3), dtype=float)
    Rt01[:3,:] = R01

    # shape (N,3)
    # Each epipolar-plane normal is perpendicular to t; t is the direction
    # most nearly perpendicular to ALL the normals
    c = np.cross(v0, mrcal.rotate_point_R(R01, v1))
    l,v = mrcal.sorted_eig(np.sum(nps.outer(c,c), axis=0))
    # t is the eigenvector corresponding to the smallest eigenvalue
    t01 = v[:,0]

    Rt01[3,:] = t01

    # Almost done. I want either t or -t. The wrong one will produce
    # mostly triangulations behind me
    p_t = mrcal.triangulate_geometric(v0, v1,
                                      v_are_local = True,
                                      Rt01        = Rt01 )
    mask_divergent_t = (nps.norm2(p_t) == 0)
    N_divergent_t    = np.count_nonzero( mask_divergent_t )

    Rt01_negt = Rt01 * nps.transpose(np.array((1,1,1,-1),))
    p_negt = mrcal.triangulate_geometric(v0, v1,
                                         v_are_local = True,
                                         Rt01        = Rt01_negt )
    mask_divergent_negt = (nps.norm2(p_negt) == 0)
    N_divergent_negt    = np.count_nonzero( mask_divergent_negt )

    # If NEITHER sign produces any divergence, there's nothing to distinguish
    # the two candidates, and we cannot pick one
    if N_divergent_t == 0 and N_divergent_negt == 0:
        raise Exception("Cannot unambiguously determine t: neither candidate has divergent observations")

    # We have divergences even in the best case. I pick the fewer-divergence
    # case, and report the diverging observations as outliers. All converging
    # observations are treated as inlier here, regardless of reprojection error
    if N_divergent_t < N_divergent_negt:
        return (Rt01,
                mask_divergent_t,
                N_divergent_t)
    else:
        return (Rt01_negt,
                mask_divergent_negt,
                N_divergent_negt)


def solve_Rt01__baseline_noz(v0,v1,R01,
                             *,
                             t01_fixed     = None,
                             baseline_want = None,
                             th0           = None,
                             dth           = None,
                             Nth           = None):
    r'''Translation keeping a constant z offset, sampling the xy plane

We sample candidate translation directions in the camera-0 xy plane (angle th,
in degrees), each scaled to baseline_want, and pick the angle whose
triangulated points agree best with the v0 observation rays. If th0, dth, Nth
are ALL given, we sample Nth angles spanning th0 +- dth; otherwise we sample
the full 360-deg circle with 1000 samples.

baseline_want is required; t01_fixed must NOT be given (this function computes
the translation itself).

Returns a tuple (Rt01, mask_divergent, N_divergent), like the other
solve_Rt01... functions. Note: as written, the returned mask/count always
claim zero divergences (see the comment at the bottom)
    '''
    if t01_fixed is not None:
        raise Exception("This function uses baseline_want, and NOT t01_fixed")
    if baseline_want is None:
        raise Exception("This function requires baseline_want")

    # I do this discretely
    if not (    th0 is not None \
            and dth is not None \
            and Nth is not None):
        # default: look at the full 360 span
        Nth = 1000
        thdeg = np.linspace(0, 360.,
                            num      = Nth,
                            endpoint = False)
    else:
        thdeg = np.linspace(th0 - dth,
                            th0 + dth,
                            num = Nth)
    th = thdeg * np.pi/180.


    c = np.cos(th)
    s = np.sin(th)

    # One candidate transform per sampled angle: same rotation, translation
    # (cos,sin,0)*baseline_want
    Rt01       = np.zeros((Nth,4,3), dtype=float)
    Rt01[:,:3,:] += R01
    Rt01[:, 3,:]  = nps.transpose( nps.cat(c, s, np.zeros(c.shape)) ) * baseline_want

    # shape (Nth,Npoints,3)
    # Triangulate ALL the features against ALL the candidate transforms at once
    p = mrcal.triangulate_geometric(v0, v1,
                                    v_are_local = True,
                                    Rt01        = nps.dummy(Rt01, -3) )

    Npoints = p.shape[-2]

    # shape (Nth,Npoints)
    magp = nps.mag(p)
    # shape (Nth,Npoints)
    # divergent/at-infinity triangulations are reported as p=(0,0,0)
    mask_divergent = (magp == 0)
    # shape (Nth)
    N_divergent    = np.count_nonzero( mask_divergent,
                                       axis = -1)

    # Score each candidate angle by the mean squared angular error between the
    # triangulated points and the v0 rays. The "_accept_divergences" variant
    # treats each divergent point as a perfect match (cos = 1) instead of
    # skipping it
    mean_thsq_err                    = np.zeros( (Nth,), dtype=float)
    mean_thsq_err_accept_divergences = np.zeros( (Nth,), dtype=float)
    for ith in range(Nth):
        m = ~mask_divergent[ith]
        if not np.any(m):
            # ALL the points diverge at this angle: give it a huge error
            mean_thsq_err[ith] = 1e6
        else:
            # inner ~ cos ~ 1 - th^2/2
            sum_cos = np.sum(nps.inner(p[ith][m], v0[m]) / magp[ith][m])
            mean_cos = sum_cos / np.count_nonzero(m)

            sum_cos_accept_divergences = sum_cos + np.count_nonzero(~m)
            mean_cos_accept_divergences = sum_cos_accept_divergences/Npoints

            mean_thsq_err[ith]                    = 2.*(1. - mean_cos)
            mean_thsq_err_accept_divergences[ith] = 2.*(1. - mean_cos_accept_divergences)

    # diagnostic code to visualize the scores and divergence counts
    if False:
        import gnuplotlib as gp
        gp.plot( (th/np.pi*180., nps.cat(mean_thsq_err,
                                         mean_thsq_err_accept_divergences), ),
                 (th/np.pi*180., N_divergent,
                  dict(y2=True) ),
                 ymax=1e-6)
        import IPython
        IPython.embed()
        sys.exit()

    # The winning angle is the one with the smallest mean error
    ith = np.argmin(mean_thsq_err)
    if False:
        # be truthful about divergences
        return (Rt01[ith],
                mask_divergent[ith],
                N_divergent[ith])
    else:
        # lie, and say that there were no divergences
        return (Rt01[ith],
                np.zeros( (Npoints,), dtype=bool),
                0)


def solve_Rt10(# shape (N,Nleftright=2,Nxy=2)
               q01,
               intrinsics0,
               intrinsics1,
               *,
               solve_R10           = solve_R01__kneip,
               solve_Rt10_from_R01 = solve_Rt01_given_R01__geometric,
               # cookie
               t01_fixed     = None,
               baseline_want = None):
    r'''Full-transform optimization

Given corresponding feature observations from two cameras we compute the
transform between the cameras. This is unique up-to-scale, so the reported
translation has length = 1.

The methods to compute the rotation and the translation are given in solve_R10
and solve_Rt10_from_R01.

NOTE(review): despite the "_R10"/"_Rt10" names, the quantities computed and
returned here are R01/Rt01 (the default solvers are the solve_R01...
functions, and the result feeds context['Rt01'] in the caller). The names look
historical; confirm.

Returns a tuple (Rt01, mask_inliers): the solved transform and a boolean mask
(length N) of the observations that survived the divergence-based outlier
rejection

    '''

    # shape (N,3)
    # These are in their LOCAL coord system
    v0 = mrcal.unproject(q01[...,0,:], *intrinsics0, normalize = True)
    v1 = mrcal.unproject(q01[...,1,:], *intrinsics1, normalize = True)


    # Keep all points initially
    mask_inliers = np.ones( (q01.shape[0],), dtype=bool )

    # Iterate: solve, throw out the diverging observations, re-solve. Stop
    # when a solve produces no new outliers
    while True:

        if np.count_nonzero(mask_inliers) == 0:
            raise Exception("All points thrown out as outliers")

        R01 = solve_R10(v0[mask_inliers],
                        v1[mask_inliers],
                        t01_fixed = t01_fixed)

        # With a known translation the rotation solve is the whole job: no
        # sign disambiguation or outlier rejection happens
        if t01_fixed is not None:
            Rt01 = nps.glue(R01, t01_fixed, axis=-2)
            return Rt01, mask_inliers

        Rt01, mask_outlier, Noutliers = \
            solve_Rt10_from_R01(v0[mask_inliers],
                                v1[mask_inliers],
                                R01 = R01,
                                t01_fixed     = t01_fixed,
                                baseline_want = baseline_want)
        if Noutliers == 0:
            break

        # mask_outlier is indexed into the CURRENT inlier subset; map it back
        # to indices into the full observation set before clearing those bits
        mask_inliers[np.nonzero(mask_inliers)[0][mask_outlier]] = False

    # This hopefully was already enforced above
    # if baseline_want is not None:
    #     baseline_have = nps.mag(Rt01[3,:])
    #     Rt01[3,:] *= baseline_want/baseline_have

    return \
        Rt01, \
        mask_inliers

def update_p0_triangulated_stored():
    r'''Re-triangulate all the stored feature pairs

Reads context['q01_stored'] and context['Rt01']; writes the result into
context['p0_triangulated_stored']

    '''
    global context

    q01 = context['q01_stored']
    v0,v1 = [ mrcal.unproject(q01[...,i,:],
                              *models[i].intrinsics(),
                              normalize = True) \
              for i in range(2) ]
    context['p0_triangulated_stored'] = \
        mrcal.triangulate_geometric(v0, v1,
                                    v_are_local = True,
                                    Rt01        = context['Rt01'] )

def update_q01_stored(_q01_stored):
    r'''Replace the stored feature-pair array, and refresh what derives from it

Stores the new array into the context, re-triangulates, and resizes the
feature-table widget to match, if the widget exists already

    '''
    global context

    context['q01_stored'] = _q01_stored
    update_p0_triangulated_stored()

    try:
        widget_table.rows(len(context['q01_stored']))
    except NameError:
        # We're at startup: widget_table hasn't been created yet
        pass


def line_segments_squares(# shape (..., 2)
                          q,
                          radii):
    r'''Line segments drawing axis-aligned squares centered on each q

One square is drawn around each point in q for each radius in "radii"; each
square consists of 4 line segments. Returns a flat array of segments with
shape (N, 2,2): each segment is a pair of float32 (x,y) endpoints

    '''

    Nradii = len(radii)

    # shape (..., Nradii, Nsegments_in_square=4, Npoints_in_line_segment=2, xy=2)
    segments = np.zeros(q.shape[:-1] + (Nradii,4,2,2), dtype=np.float32)
    segments += nps.dummy(q,-2,-2,-2)

    ex = np.array((1,0), dtype=np.float32)
    ey = np.array((0,1), dtype=np.float32)

    # The 4 sides of a unit square, each as a (start,end) corner pair:
    # left, right, bottom, top
    side_endpoints = ( (-ex-ey, -ex+ey),
                       (+ex-ey, +ex+ey),
                       (-ex-ey, +ex-ey),
                       (-ex+ey, +ex+ey) )

    for iradius,radius in enumerate(radii):
        for iside, (corner0,corner1) in enumerate(side_endpoints):
            segments[..., iradius,iside,0,:] += corner0*radius
            segments[..., iradius,iside,1,:] += corner1*radius

    # flatten to a list of line segments
    # shape (N, 2,2)
    return nps.clump(segments, n=segments.ndim-2)

def line_segments_crosshairs(# shape (..., 2)
                             q,
                             radius):
    r'''Line segments drawing an X-shaped crosshair centered on each q

Each crosshair is exactly 2 diagonal line segments of half-width "radius".
Returns a flat array of segments with shape (N, 2,2): each segment is a pair
of float32 (x,y) endpoints.

Bug fix: this used to allocate 4 segments per crosshair (copy-paste residue
from line_segments_squares()), leaving 2 degenerate zero-length segments at
each q. A crosshair has 2 segments, matching the shape comment below
    '''
    q = np.asarray(q)

    # shape (..., Nsegments_in_crosshair=2, Npoints_in_line_segment=2, xy=2)
    p = np.zeros(q.shape[:-1] + (2,2,2), dtype=np.float32)
    # broadcast the center point into every segment endpoint
    p += q[..., np.newaxis, np.newaxis, :]

    rx = np.array((1,0), dtype=np.float32)
    ry = np.array((0,1), dtype=np.float32)

    # line segment 0: lower-left to upper-right
    p[..., 0,0,:] += (-rx-ry)*radius
    p[..., 0,1,:] += (+rx+ry)*radius
    # line segment 1: lower-right to upper-left
    p[..., 1,0,:] += (+rx-ry)*radius
    p[..., 1,1,:] += (-rx+ry)*radius

    # flatten to a list of line segments
    # shape (N, 2,2)
    return p.reshape((-1,) + p.shape[-2:])


def set_all_overlay_lines_and_redraw():
    r'''Rebuild all the image-overlay annotations, and push them to the widgets

For each image widget we draw: a square+crosshair for every stored feature, the
active search box/center (if any), and for the table-selected features a
highlight square plus the epipolar curve in the other camera. Reads the module
globals context, models, widget_table, widgets_image
    '''

    # Annotation geometry/color constants, in pixels and RGB
    radius_stored    = 10
    crosshair_radius = 1

    color_stored = np.array((0,1,0), dtype=np.float32)

    radius_selected = 11
    color_selected  = np.array((0,0,1), dtype=np.float32)

    color_searchbox = np.array((1,1,0), dtype=np.float32)



    Nstored = len(context['q01_stored'])

    Rt01 = context['Rt01']
    baseline = nps.mag(Rt01[3,:])


    # for projected rays, I start at a small ratio of the baseline, and increase
    # the range geometrically. I add another point at infinity at the end
    k = 1.1
    N = 25
    r = baseline * 1. * (k ** np.arange(N))

    # Indices of the features currently selected in the feature table
    iselected = tuple(i for i in range(Nstored) if widget_table.row_selected(i))

    for w in widgets_image:

        ####### The squares and crosshair for all the stored features
        lines_stored = \
            [ dict(points    = nps.glue( line_segments_squares(context['q01_stored'][:, w.index,:],
                                                               (radius_stored,)),
                                         line_segments_crosshairs(context['q01_stored'][:, w.index,:],
                                                                  crosshair_radius),
                                         axis = -3 ),
                   color_rgb = color_stored ) ]


        ####### The search box
        # In the image we're searching FROM we draw a crosshair at the picked
        # point; in the OTHER image we draw the search box itself
        if context['search_center__q01_indexfrom_showfrom_searchradius'] is not None:
            q01_search_center,indexfrom,showfrom,search_radius = context['search_center__q01_indexfrom_showfrom_searchradius']
            if indexfrom == w.index and showfrom:
                lines_stored.append( dict(points    = line_segments_crosshairs(q01_search_center[w.index],
                                                                               crosshair_radius),
                                          color_rgb = color_searchbox ) )
            elif indexfrom != w.index:
                lines_stored.append( dict(points    = line_segments_squares(q01_search_center[w.index],
                                                                            (search_radius,)),
                                          color_rgb = color_searchbox ) )


        if len(iselected):

            ####### The epipolar CURVE corresponding to THIS point in the OTHER
            ####### camera

            # shape (Nselected, 2)
            q_stored       = context['q01_stored'][iselected,   w.index,:]
            q_stored_other = context['q01_stored'][iselected, 1-w.index,:]

            # shape (Nranges, Nselected, 3)
            # Sample the other camera's observation ray at the ranges r
            p = mrcal.unproject(q_stored_other, *models[1-w.index].intrinsics(),
                                normalize = True) * \
                    nps.mv(r,-1,-3)
            # Transform the sampled points into THIS camera's coord system;
            # pinf is the ray direction (the point at infinity: rotation only)
            if w.index == 0:
                pinf = mrcal.rotate_point_R(Rt01[:3,:], p[0])
                p    = mrcal.transform_point_Rt(Rt01, p)
            else:
                pinf = mrcal.rotate_point_R(mrcal.invert_R(Rt01[:3,:]), p[0])
                p    = mrcal.transform_point_Rt(mrcal.invert_Rt(Rt01), p)

            # shape (Nsegments, Nsegmentpoint=2, Nxy=2)
            qsegments = np.zeros( (0,2,2), dtype=np.float32)

            # shape (Nranges+1, Nselected, Nxyz=3)
            p = nps.glue(p,pinf, axis=-3)

            # Project each selected feature's sampled ray into this camera, and
            # connect consecutive samples into line segments
            for pselected in nps.mv(p,-2,-3):
                # pselected.shape is (Nranges+1,Nxyz=3)

                # cut off points where the ray is behind the other camera
                pselected = pselected[pselected[:,2] > 0,:]

                if len(pselected)>=2:
                    # shape (N, Nxy=2)
                    q = mrcal.project(pselected,
                                      *models[w.index].intrinsics()).astype(np.float32)
                    # shape (N-1, Nsegmentpoint=2, Nxy=2)
                    qsegments = \
                        nps.glue(qsegments,
                                 nps.mv( nps.cat(q[0:-1,...],q[1:,...]),
                                         0,1 ),
                                 axis=-3)

            lines_stored.append( dict(points    = nps.glue( line_segments_squares(q_stored,
                                                                                  (radius_selected,),),
                                                            qsegments,
                                                            axis = -3),
                                      color_rgb = color_selected ) )

        w.set_lines(*lines_stored)


def clear_q_search_center():
    r'''Forget the record of the just-completed feature search

    After this, a Ctrl-right-click will no longer attempt to re-seed from a
    stale search
    '''
    global context
    context.update(search_center__q01_indexfrom_showfrom_searchradius = None)



def widget_table_callback(*args):
    r'''Handle clicks in the feature table

    On a cell release, make sure the newly-selected feature is visible in
    image0 (image1 pans automatically), then refresh the overlays
    '''
    if widget_table.callback_context() != Fl_Table.CONTEXT_CELL or \
       Fl.event() != FL_RELEASE:
        return

    # I want the selected feature to be in view. I do that for image0
    # (image1 will be automatically panned). If it's already in view or
    # close to it, then I don't touch it
    irow   = widget_table.callback_row()
    icam   = 0
    q      = context['q01_stored'][irow,icam,:]
    widget = widgets_image[icam]

    wview = widget.w()
    hview = widget.h()
    qview = widget.map_pixel_viewport_from_image(q)

    # The margin is 1/10 or 100 pixels, whichever is smaller. If we're within
    # the margin, we pan
    margin_x = min(wview//10, 100)
    margin_y = min(hview//10, 100)
    near_edge = \
        qview[0] < margin_x or qview[0] > wview-1-margin_x or \
        qview[1] < margin_y or qview[1] > hview-1-margin_y
    if near_edge:
        widget.set_panzoom(*q,
                           visible_width_pixels = np.nan, # leave the same
                           notify_other_widgets = True)

    clear_q_search_center()
    set_all_overlay_lines_and_redraw()

def widget_solve_callback(*args):
    r'''Solve for the extrinsics relating the two cameras

    Uses the stored pixel correspondences to compute the camera0 <- camera1
    transform. camera0 is left where it is; camera1 is moved to match the
    solve. On success the stored transform, the triangulated points, the
    table and the overlays are all refreshed, and an inlier/outlier report is
    displayed in the info pane. On failure, the error is reported, and
    nothing is modified
    '''
    global context

    try:
        Rt01, mask_inliers = \
            solve_Rt10(context['q01_stored'],
                       *[m.intrinsics() for m in models],
                       solve_R10           = solve_R01__optimize,
                       t01_fixed           = context['Rt01'][3,:])
    except Exception as e:
        widget_info.value(UI_usage_message + "\n" + \
                          f"Solve failed: {e}")
        return

    # Store the new transform. The model0 extrinsics are untouched, but I move
    # model1
    context['Rt01'] = Rt01
    Rtr0 = models[0].Rt_ref_cam()
    Rtr1 = mrcal.compose_Rt(Rtr0, Rt01)
    models[1].Rt_ref_cam(Rtr1)

    update_p0_triangulated_stored()

    widget_table.redraw()
    clear_q_search_center()
    set_all_overlay_lines_and_redraw()

    N = len(context['q01_stored'])
    Noutliers = N-np.count_nonzero(mask_inliers)
    widget_info.value(UI_usage_message + "\n" + \
                      f"Extrinsics of model1 updated: {Noutliers}/{N} outliers\n" + \
                      f"Outlier points: {np.nonzero(~mask_inliers)[0]}")



    # Debugging code to write the models and correspondences to disk after each
    # solve
    if False:
        for icam in (0,1):
            filename = f"/tmp/camera-{icam}.cameramodel"
            models[icam].write(filename)
            # Report the actual filename (the f-string previously contained a
            # mangled literal instead of {filename})
            print(f"Wrote '{filename}'")

        p = context['p0_triangulated_stored']
        r = nps.mag(p)
        idx = nps.norm2(p) > 0
        r[~idx] = np.nan  # divergent features have p=0; report range as NaN

        filename = '/tmp/correspondences.vnl'
        np.savetxt(filename,
                   nps.glue( nps.clump(context['q01_stored'][:,:4],
                                       n = -2),
                             nps.dummy(r,-1),
                             axis = -1),
                   fmt = '%.2f',
                   header = "x0 y0 x1 y1 range")
        print(f"Wrote '{filename}'")




def widget_write_callback(*args):
    r'''Write the current correspondences and models to stdout

    Emits a text table of the pixel pairs with their triangulated ranges
    (NaN for features that did not triangulate), followed by both camera
    models
    '''
    p0     = context['p0_triangulated_stored']
    ranges = nps.mag(p0)
    # A failed/divergent triangulation is stored as p=0; show its range as NaN
    ranges[ nps.norm2(p0) <= 0 ] = np.nan

    table = nps.glue( nps.clump(context['q01_stored'][:,:4],
                                n = -2),
                      nps.dummy(ranges,-1),
                      axis = -1)
    np.savetxt(sys.stdout,
               table,
               fmt    = '%.2f',
               header = "x0 y0 x1 y1 range")

    for model in models:
        model.write(sys.stdout)


def find_corresponding_feature(q, index,
                               *,
                               qother_estimate = None,
                               search_radius):
    r'''Find the feature in the other image corresponding to pixel q

    q is a pixel coordinate in image "index" (0 or 1). If qother_estimate is
    None, the seed position in the other image comes from get_q01_estimate();
    otherwise the given estimate is used directly.

    Returns a tuple (message, q01_estimate, q01):

    - message: human-readable status string
    - q01_estimate: shape (2,2) array of the seed pixel pair (leftright, qxy),
      or None if q was out of bounds
    - q01: shape (2,2) array of the matched pixel pair, or None if the match
      failed
    '''
    # Reject out-of-bounds pixels. Pixel centers sit at integers, so the valid
    # extent is [-0.5, W-0.5] x [-0.5, H-0.5]
    if not (q[0] >= -0.5 and q[0] <= W-0.5 and \
            q[1] >= -0.5 and q[1] <= H-0.5):
        return f"Out of bounds: {q=}", None, None

    if qother_estimate is None:
        # shape (2,2): (leftright, qxy)
        q01_estimate = get_q01_estimate(q     = q,
                                        index = index)
    else:
        if index == 0: q01_estimate = nps.cat(q,              qother_estimate)
        else:          q01_estimate = nps.cat(qother_estimate,q)

    index_other = 1 - index
    try:
        match_feature_out = \
            mrcal.match_feature(images[index], images[index_other],
                                q0               = q01_estimate[index],
                                q1_estimate      = q01_estimate[index_other],
                                search_radius1   = search_radius,
                                template_size1   = args.template_size)
        q_other = match_feature_out[0]
    except Exception:
        # Best-effort matching: any error is reported as a failed match. This
        # was a bare "except:", which would also have swallowed
        # KeyboardInterrupt/SystemExit
        q_other = None

    if q_other is None:
        return "Error matching feature", q01_estimate, None

    # Successful match: keep the clicked pixel, replace the other-image pixel
    # with the matched one
    q01 = np.array(q01_estimate)
    q01[index_other] = q_other

    return "Feature match successful", q01_estimate, q01


class Fl_Gl_Image_Widget_Derived(Fl_Gl_Image_Widget):
    # Image widget for one of the two cameras. Implements the interactive
    # feature-picking UI on top of the base GL image widget:
    #
    # - pan/zoom is kept synchronized across both widgets
    # - right-click runs a feature search; Ctrl-right-click re-seeds the
    #   previous search from the other image
    # - left-click (without dragging) selects the nearest stored feature
    # - the Delete key removes the selected feature(s)

    def __init__(self,
                 *args,
                 index,
                 **kwargs):
        # index: which camera this widget displays (0 or 1)
        self.index = index
        return super().__init__(*args, **kwargs)

    def set_panzoom(self,
                    x_centerpixel, y_centerpixel,
                    visible_width_pixels,
                    notify_other_widgets = True):
        r'''Pan/zoom the image

        This is an override of the function to do this: any request to
        pan/zoom the widget will come here first. I dispatch any
        pan/zoom commands to all the widgets, so that they all work in
        unison. notify_other_widgets=False means: this is the redirected
        call; just call the base class

        '''
        if not notify_other_widgets:
            return super().set_panzoom(x_centerpixel, y_centerpixel,
                                       visible_width_pixels)

        # All the widgets should pan/zoom together
        result = \
            all( w.set_panzoom(x_centerpixel, y_centerpixel,
                               visible_width_pixels,
                               notify_other_widgets = False) \
                 for w in widgets_image )

        # Switch back to THIS widget
        self.make_current()
        return result


    def handle_right_mouse_button(self):
        # Run a feature search at the clicked pixel. A plain right-click
        # searches from this image; a Ctrl-right-click immediately after a
        # search in the OTHER image re-runs that search, seeded at this
        # click, with a much tighter radius. Always returns 1: the event is
        # consumed

        # Save, then clear the record of the previous search. The saved copy
        # may be used below by the Ctrl-click re-seeding path
        search_center__q01_indexfrom_showfrom_searchradius_copy = context['search_center__q01_indexfrom_showfrom_searchradius']
        clear_q_search_center()


        # Map the click from viewport coordinates to image-pixel coordinates
        try:
            q = \
                np.array( self.map_pixel_image_from_viewport( (Fl.event_x(),Fl.event_y()), ),
                          dtype=float )
        except:
            q = None
        if q is None:
            message = "Error converting pixel coordinates"
            widget_info.value(UI_usage_message + "\n" + \
                              message)
            return 1


        if Fl.event_state() & FL_CTRL:
            if search_center__q01_indexfrom_showfrom_searchradius_copy is None:
                # Ctrl-right-click but we didn't just do a search. Do nothing
                return 1

            # We just did a search. It may have succeeded, with the match
            # selected, in the last row of the table. Or it may have failed,
            # with nothing selected
            q01_search_center, indexfrom, showfrom, _ = search_center__q01_indexfrom_showfrom_searchradius_copy
            if self.index  == indexfrom:
                # Ctrl-right-click but we just did a search, from the same image. Do nothing
                return 1

            # We can rerun the search with human seeding. If the search we just
            # did succeeded, then I need to clear out that feature; I'm about to
            # redo it
            iselected = tuple(i for i in range(widget_table.rows()) if widget_table.row_selected(i))
            if len(iselected) == 0:
                # Nothing selected; the previous search failed, and we don't
                # need to clear it out
                pass
            elif len(iselected) > 1:
                print("WARNING: We're trying to rerun a search, but more than one feature is selected in the UI. This shouldn't be able to happen, and it's a bug. Doing nothing")
                return 1

            else:
                # Have one feature. Delete it
                mask = np.ones((widget_table.rows(),), dtype=bool)
                mask[iselected] = 0
                update_q01_stored(context['q01_stored'][mask, ...])
                widget_table.rows(len(context['q01_stored']))
                widget_table.select_all_rows(0) # deselect all


            # Set up the new search. I want to search FROM the other widget: I'm
            # repeating the previous search
            qother_estimate = q
            q               = q01_search_center[indexfrom]
            search_radius   = 5 # tighter bounds

        else:
            # normal path
            indexfrom       = self.index
            qother_estimate = None
            search_radius   = args.search_radius


        message,q01_estimate,q01 = find_corresponding_feature(q, indexfrom,
                                                              qother_estimate = qother_estimate,
                                                              search_radius   = search_radius)

        widget_info.value(UI_usage_message + "\n" + \
                          message)

        if q01 is not None:
            # Match succeeded: append the new pixel pair, and select its row
            # in the table
            update_q01_stored(nps.glue( context['q01_stored'],
                                        q01,
                                        axis=-3))

            Nstored = len(context['q01_stored'])
            widget_table.rows( Nstored )
            widget_table.select_all_rows(0) # deselect all
            widget_table.select_row(Nstored-1)
        else:
            # match failed
            widget_table.select_all_rows(0) # deselect all

        if q01_estimate is not None:
            # Remember this search so that a Ctrl-right-click in the other
            # widget can re-seed it. The third element ("showfrom") is True
            # iff the match failed; presumably this controls the overlay
            # display — TODO confirm against set_all_overlay_lines_and_redraw()
            context['search_center__q01_indexfrom_showfrom_searchradius'] = (q01_estimate, indexfrom, q01 is None, search_radius)

        return 1

    def handle(self, event):
        # FLTK event dispatch for this widget
        if event == FL_ENTER:
            return 1
        if event == FL_LEAVE:
            # Cursor left the widget: clear the pixel-coordinate status line
            widget_status.value("")
            return 1
        if event == FL_MOVE:
            # Show the image-pixel coordinate under the cursor in the status bar
            try:
                q = self.map_pixel_image_from_viewport( (Fl.event_x(),Fl.event_y()), )
                this = f"Image {self.index}"
                widget_status.value(f"{this}: {q[0]:.2f},{q[1]:.2f}")
            except:
                widget_status.value("")
            return 1

        if event == FL_PUSH and Fl.event_button() == FL_LEFT_MOUSE:
            # Track dragging, so that on release we can distinguish a click
            # (select feature) from a drag (pan; handled by the base class)
            self.dragged = False
            return super().handle(event)

        if event == FL_DRAG and Fl.event_state() & FL_BUTTON1:
            self.dragged = True
            return super().handle(event)

        if event == FL_RELEASE:
            if Fl.event_button() == FL_RIGHT_MOUSE:
                result = self.handle_right_mouse_button()
                set_all_overlay_lines_and_redraw()
                return result

            if Fl.event_button() == FL_LEFT_MOUSE:
                if not self.dragged:
                    clear_q_search_center()

                    # Clicked. Select the nearest feature
                    if len(context['q01_stored']) > 0:
                        qwidget = np.array((Fl.event_x(),Fl.event_y()),
                                           dtype=int)
                        qwidget_stored = \
                            np.array([self.map_pixel_viewport_from_image(q) for q in context['q01_stored'][:,self.index,:]])
                        # within 10 pixels in the UI
                        i = np.nonzero(nps.norm2(qwidget - qwidget_stored) < 10*10)[0]
                        if len(i) < 1:
                            widget_info.value(UI_usage_message + "\n" + \
                                              "No feature near click")
                            widget_table.select_all_rows(0) # deselect all
                            set_all_overlay_lines_and_redraw()
                        elif len(i) > 1:
                            widget_info.value(UI_usage_message + "\n" + \
                                              "Too many features near click")
                            widget_table.select_all_rows(0) # deselect all
                            set_all_overlay_lines_and_redraw()
                        else:
                            # Just one feature. Select it
                            i = i[0]
                            widget_table.select_all_rows(0) # deselect all
                            widget_table.select_row(int(i))
                            set_all_overlay_lines_and_redraw()
            return super().handle(event)


        # Must be key UP, not key DOWN. Because of https://github.com/fltk/fltk/issues/1044
        if event == FL_KEYUP:
            if Fl.event_key() == fltk.FL_Delete:
                # Delete all selected feature(s): keep only the unselected rows
                i_keep = tuple(i for i in range(widget_table.rows()) \
                               if not widget_table.row_selected(i))
                update_q01_stored(context['q01_stored'][i_keep, ...])
                widget_table.rows(len(context['q01_stored']))
                widget_table.select_all_rows(0) # deselect all
                clear_q_search_center()
                set_all_overlay_lines_and_redraw()
                widget_info.value(UI_usage_message + "\n" + \
                                  "Feature(s) deleted")

                return 1

        return super().handle(event)


class Fl_Table_Derived(Fl_Table_Row):
    # Table listing the stored correspondences, one row per feature: pixel
    # coordinates in both images, the triangulated range, and the camera-0
    # reprojection error of the triangulated point

    def __init__(self, x, y, w, h, *args):
        Fl_Table_Row.__init__(self, x, y, w, h, *args)

        self.col_labels = \
            [ "x0",
              "y0",
              "x1",
              "y1",
              "Triangulated range",
              "Cam0 triangulated reprojection error", ]
        # Column widths proportional to the header-label lengths, with the
        # widest/narrowest ratio capped so that no column dominates
        len_col_labels = [len(x) for x in self.col_labels]
        min_len_col_labels = min(len_col_labels)
        max_ratio_len_col_labels = 3 # limit max ratio
        normalized_len_col_labels = [min(x/min_len_col_labels, max_ratio_len_col_labels) \
                                     for x in len_col_labels]
        sum_normalized_len_col_labels = sum(normalized_len_col_labels)
        # Per-column fraction of the total table width; applied in resize()
        self.ratio_col_width = [w/sum_normalized_len_col_labels for w in normalized_len_col_labels]


        self.type(fltk.Fl_Table_Row.SELECT_MULTI)
        self.rows(len(context['q01_stored']))
        self.cols(len(self.col_labels))
        self.col_header(1)
        self.col_resize(0)

        # Fire the callback on mouse release only
        self.when(FL_WHEN_RELEASE)
        self.callback(widget_table_callback)

        self.end()

    def draw_cell(self, context_table, row, col, x, y, w, h):
        # FLTK cell-drawing callback: renders the column headers and the
        # per-feature data cells

        if context_table == self.CONTEXT_STARTPAGE:
            # Set the font once per redraw pass
            fl_font(FL_HELVETICA, 12)
            return

        if context_table == self.CONTEXT_COL_HEADER:
            text = self.col_labels[col]

            fl_push_clip(x, y, w, h)
            fl_draw_box(FL_THIN_UP_BOX, x, y, w, h, self.row_header_color())
            fl_color(FL_BLACK)
            fl_draw(text, x, y, w, h, FL_ALIGN_CENTER)
            fl_pop_clip()
            return

        if context_table == self.CONTEXT_CELL:
            if col < 4:
                # Columns 0..3 are the raw pixel coordinates: (x0,y0,x1,y1)
                iimage = col // 2
                ixy    = col %  2
                text = f"{context['q01_stored'][row,iimage,ixy]:.2f}"
            else:
                p = context['p0_triangulated_stored'][row]
                if nps.norm2(p) == 0:
                    # Divergent feature. No triangulated anything available
                    text = '-'
                else:
                    if col == 4:
                        # range
                        text = f"{nps.mag(p):.2f}"
                    else:
                        # reprojection error
                        icam = 0
                        q    = mrcal.project(p, *models[icam].intrinsics())
                        text = f"{nps.mag(context['q01_stored'][row,icam] - q):.1f}"

            fl_push_clip(x, y, w, h)
            # background color
            fl_color(self.selection_color() if self.row_selected(row) else FL_WHITE)
            fl_rectf(x, y, w, h)

            # text
            fl_color(FL_BLACK)
            fl_draw(text, x, y, w, h, FL_ALIGN_CENTER)

            # border
            fl_color(FL_LIGHT2)
            fl_rect(x, y, w, h)
            fl_pop_clip()

            return

        return

    def resize(self, x,y,w,h):
        # Keep the column widths proportional to ratio_col_width as the table
        # is resized; the last column absorbs the rounding slack
        Fl_Table_Row.resize(self, x,y,w,h)
        Ncols = self.cols()
        width = self.w()
        width_margin = 2 # for some reason, I need to cut this many pixels to
                         # avoid creating a scrollbar
        x0 = 0
        for icol in range(Ncols-1):
            w_here = int(self.ratio_col_width[icol] * width)
            self.col_width(icol, w_here)
            x0 += w_here
        # last col
        self.col_width(Ncols-1, width-x0 - width_margin)



# Load the two camera models given on the commandline
models = [mrcal.cameramodel(m) for m in args.models]


# --rt01, if given, overrides the relative geometry: camera1 is re-placed at
# compose(rt_ref_cam(camera0), rt01)
if args.rt01 is not None:
    try:
        rt01 = np.array([float(d) for d in args.rt01.split(',')])
    except:
        print("--rt01 must be a comma-separated list of 6 numbers; couldn't parse as numbers into an array", file=sys.stderr)
        sys.exit(1)
    if rt01.shape != (6,):
        print("--rt01 must be a comma-separated list of 6 numbers; incorrect number of values given", file=sys.stderr)
        sys.exit(1)

    models[1].rt_ref_cam( mrcal.compose_rt(models[0].rt_ref_cam(),
                                           rt01) )

# To make it possible to concisely print out the models
for m in models:
    m.optimization_inputs_reset()

# Imager dimensions used for the bounds checks. NOTE(review): only camera0 is
# consulted; this assumes both imagers have the same size — confirm
W,H = models[0].imagersize()

# Optional contrast equalization, selected by --equalization
if args.equalization == 'clahe':
    clahe = cv2.createCLAHE()
    clahe.setClipLimit(8)
images = [mrcal.load_image(f) for f in args.images]
for i in range(2):
    if images[i].ndim == 3:
        # Color image: collapse the channel axis to grayscale. NOTE(review):
        # computing the mean with the integer input dtype truncates the
        # fractional part — confirm this is intended
        images[i] = np.mean(images[i], axis=-1, dtype=images[i].dtype)
if args.equalization == 'clahe':
    images = [ clahe.apply(image) for image in images ]
elif args.equalization == 'fieldscale':
    images = [ mrcam.equalize_fieldscale(image) for image in images ]
elif args.equalization == 'stretch':
    # Linear stretch of the full intensity range to [0,255]
    images = [ ((image - np.min(image))/(np.max(image)-np.min(image))*255.).astype(np.uint8) for image in images ]

# Burn the valid-intrinsics region outlines into the images, if requested
if args.valid_intrinsics_region:
    for i in range(2):
        mrcal.annotate_image__valid_intrinsics_region(images[i], models[i])

# To make it possible to concisely print out the models
for m in models:
    m.valid_intrinsics_region( (), )

# Help text shown in the info pane; status messages are appended to it
UI_usage_message = r'''Usage:

Left mouse button click/drag: pan
Mouse wheel up/down/left/right: pan
Ctrl-mouse wheel up/down: zoom
'u': reset view: zoom out, pan to the center

Right-click: Find matching feature

Ctrl-right-click immediately after a right-click in the other widget: re-run
feature search starting with the new click position as a seed, and with a much
smaller search radius

Click in table: select feature(s)

Delete: delete selected feature(s)
'''



## q01_stored and p0_triangulated_stored correspond to each other row-by-row.
## If you update one, do update the other
context = dict( # shape (N, Nimages = 2,xy=2)
                q01_stored = np.zeros((0,2,2), dtype=float),

                # If non-None, describes the feature search that just
                # occurred: a tuple (q01_estimate, index-searched-from,
                # match-failed flag, search radius). Consumed by
                # Ctrl-right-click to re-seed the search
                search_center__q01_indexfrom_showfrom_searchradius = None,

                # Triangulated points, initialized empty. NOTE(review):
                # draw_cell() projects these through camera intrinsics, so
                # they're presumably 3D points in camera-0 coordinates; the
                # (0,2) placeholder shape looks inconsistent with that, but is
                # harmless while empty — confirm against
                # update_p0_triangulated_stored()
                p0_triangulated_stored = np.zeros((0,2), dtype=float),

                # Current camera0 <- camera1 transform, derived from the models
                Rt01 = mrcal.compose_Rt( models[0].Rt_cam_ref(),
                                         models[1].Rt_ref_cam() ) )

# Pre-load correspondences from --initial-correspondences, if given
if args.initial_correspondences is not None:
    try:
        context['q01_stored'] = vnlog.slurp(args.initial_correspondences,
                                            dtype = np.dtype([('x0 y0 x1 y1', float, (4,))]))
    except:
        print(f"--initial-correspondences expects a text table with columns 'x0 y0 x1 y1'; couldn't load text table from '{args.initial_correspondences}'", file=sys.stderr)
        sys.exit(1)

    N = len(context['q01_stored'])

    # Reshape the flat (N,4) rows to (N, Nimages=2, xy=2), and recompute the
    # derived state
    update_q01_stored(context['q01_stored']['x0 y0 x1 y1'].reshape((N,2,2),))





# GUI layout constants, in pixels
WINDOW_W = 800
IMAGES_H = 300
TABLE_H  = 300
BUTTON_H = 30
STATUS_H = 20

WINDOW_H = IMAGES_H + TABLE_H + BUTTON_H + STATUS_H

window = Fl_Window(WINDOW_W, WINDOW_H, "mrcal feature picker")
# "body" is the resizable portion: the two images plus the table/info pane
body   = Fl_Group(0,0,
                  WINDOW_W, IMAGES_H + TABLE_H)

y = 0
# The two camera views, side by side
widgets_image = (Fl_Gl_Image_Widget_Derived(0,    y,
                                            WINDOW_W//2,IMAGES_H,
                                            index = 0),
                 Fl_Gl_Image_Widget_Derived(WINDOW_W//2,  y,
                                            WINDOW_W//2,IMAGES_H,
                                            index = 1))
y += IMAGES_H


# The feature table on the left, the info/usage pane on the right
widget_table  = Fl_Table_Derived(   0,  y,
                                    WINDOW_W//2,TABLE_H)
widget_info   = Fl_Multiline_Output(WINDOW_W//2,y,
                                    WINDOW_W//2,TABLE_H)
y += TABLE_H
body.end()

# A row of buttons. The last one absorbs any integer-division slack in the
# widths
Nbuttons = 2
BUTTON_W = WINDOW_W//Nbuttons
ibutton = 0

widget_solve_button = Fl_Button(ibutton*BUTTON_W, y,
                                BUTTON_W if ibutton < Nbuttons-1 else WINDOW_W-ibutton*BUTTON_W,
                                BUTTON_H, "Solve for the extrinsics")
widget_solve_button.callback(widget_solve_callback)
ibutton+=1

widget_write_button = Fl_Button(ibutton*BUTTON_W, y,
                                BUTTON_W if ibutton < Nbuttons-1 else WINDOW_W-ibutton*BUTTON_W,
                                BUTTON_H, "Write output to stdout")
widget_write_button.callback(widget_write_callback)
ibutton+=1

y += BUTTON_H

# Single-line status bar at the bottom: shows cursor pixel coordinates
widget_status = Fl_Output(0,y,
                          WINDOW_W,STATUS_H)
widget_info.value(UI_usage_message)
y += STATUS_H


window.resizable(body)
window.end()
window.show()

# Push the (possibly equalized/annotated) image data into the GL widgets
for i in range(2):
    widgets_image[i]. \
      update_image(decimation_level = 0,
                   image_data       = images[i])

# --initial-features: run the feature matcher on each given image-0 pixel, and
# store the successful matches
if args.initial_features is not None:

    try:
        q0 = vnlog.slurp(args.initial_features,
                          dtype = np.dtype([('x0 y0', float, (2,))]))
    except:
        print(f"--initial-features expects a text table with columns 'x0 y0'; couldn't load text table from '{args.initial_features}'", file=sys.stderr)
        sys.exit(1)

    q0 = q0['x0 y0']
    N = len(q0)
    mask_good = np.zeros((N,), dtype=bool)
    q01 = np.zeros( (N,2,2), dtype=float)
    for i in range(N):
        message,q01_estimate,_q01 = \
            find_corresponding_feature(q0[i,:], 0,
                                       search_radius = args.search_radius)
        if _q01 is not None:
            q01[i] = _q01
            mask_good[i] = True

    update_q01_stored(q01[mask_good])


if len(context['q01_stored']):
    # Drawing lines at startup requires gl-image-display 0.19, so I only do it
    # if I need to. Most of the time we will start out with no features, so
    # older gl-image-display would be fine
    set_all_overlay_lines_and_redraw()

# Enter the FLTK event loop; blocks until the window is closed
Fl.run()

sys.exit(0)





# Unreachable: sys.exit(0) above terminates the process before this point.
# This is a developer scratch-pad kept as a discarded string literal
r'''
todo notes:

'u' should pan to the center

ctrl-drag should lock panning

arrow-keys in the table should work

better UI messages


solve diagnostics. residuals. UI for ranged fiducial/at-infinity
'''
