from dataclasses import dataclass
from typing import Callable, Optional, Union

import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import can_return_tuple
from ..auto.modeling_auto import AutoModelForKeypointDetection
from .configuration_lightglue import LightGlueConfig


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
    the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
    batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
    tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
    matching information.
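
    A minimal sketch of reading a padded batch back out (an illustrative assumption: `outputs` was
    produced by `LightGlueForKeypointMatching` on at least one image pair):

    ```python
    >>> real_points = outputs.mask[0] == 1                # actual keypoints vs. padding entries
    >>> scores = outputs.matching_scores[0][real_points]  # scores of the non-padded entries only
    ```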
    )Zcustom_introc                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeej ed< dZeeej  ed	< dZeeej  ed
< dS )LightGlueKeypointMatchingOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
        Pruning mask indicating which keypoints are removed and at which layer.
    mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
        Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
        information.
    hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)` returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`
    attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)` returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`
    Nlossmatchesmatching_scores	keypointsprunemaskhidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r   Z	IntTensorr   r   tupler     r)   r)   l/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/lightglue/modeling_lightglue.pyr   '   s   
r   c                       sV   e Zd Zed fddZdejee e	e
ej e
ejejf f dddZ  ZS )	LightGluePositionalEncoderconfigc                    s,   t    tjd|j|j d dd| _d S )Nr   FZbias)super__init__r   Lineardescriptor_dimnum_attention_heads	projectorselfr-   	__class__r)   r*   r0   U   s    
z#LightGluePositionalEncoder.__init__F)r   output_hidden_statesreturnc                 C   sJ   |  |}|jddd}t|}t|}||f}|r@||fn|f}|S )Nr   dim)r4   repeat_interleaver%   cossin)r6   r   r9   Zprojected_keypointsZ
embeddingsZcosinesZsinesoutputr)   r)   r*   forwardY   s    


z"LightGluePositionalEncoder.forward)F)r!   r"   r#   r   r0   r%   Tensorr   boolr   r(   rB   __classcell__r)   r)   r7   r*   r+   T   s    
r+   c                 C   sB   | dd d df }| ddd df }t j| |gddd}|S )N.r   r   r;   r<   )r%   stackflatten)xx1Zx2Zrot_xr)   r)   r*   rotate_halfe   s    rK   c           	      C   sj   | j }|  } | }||}||}| | t| |  }|| t||  }|j|d|j|dfS )a  Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    dtype)rM   float	unsqueezerK   to)	qkr?   r@   Zposition_idsZunsqueeze_dimrM   Zq_embedZk_embedr)   r)   r*   apply_rotary_pos_embm   s    

rS   )r   n_repr:   c                 C   s^   | j \}}}}|dkr| S | dddddddddf |||||} | ||| ||S )z
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class LightGlueAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LightGlueConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Keys and values come from the other image's descriptors when cross-attending
        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            current_attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class LightGlueMLP(nn.Module):
    def __init__(self, config: LightGlueConfig):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class LightGlueTransformerLayer(nn.Module):
    def __init__(self, config: LightGlueConfig, layer_idx: int):
        super().__init__()

        self.self_attention = LightGlueAttention(config, layer_idx)
        self.self_mlp = LightGlueMLP(config)
        self.cross_attention = LightGlueAttention(config, layer_idx)
        self.cross_mlp = LightGlueMLP(config)

    def forward(
        self,
        descriptors: torch.Tensor,
        keypoints: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (descriptors,)

        batch_size, num_keypoints, descriptor_dim = descriptors.shape

        # Self-attention: each image attends to its own keypoints, with rotary keypoint embeddings
        attention_output, self_attentions = self.self_attention(
            descriptors,
            position_embeddings=keypoints,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        intermediate_states = torch.cat([descriptors, attention_output], dim=-1)
        output_states = self.self_mlp(intermediate_states)
        self_attention_descriptors = descriptors + output_states

        if output_hidden_states:
            self_attention_hidden_states = (attention_output, output_states)

        # Cross-attention: flip the two images of each pair so that each image attends to the other one
        encoder_hidden_states = (
            self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim)
            .flip(1)
            .reshape(batch_size, num_keypoints, descriptor_dim)
        )
        encoder_attention_mask = (
            attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
            if attention_mask is not None
            else None
        )
        cross_attention_output, cross_attentions = self.cross_attention(
            self_attention_descriptors,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1)
        cross_output_states = self.cross_mlp(cross_intermediate_states)
        descriptors = self_attention_descriptors + cross_output_states

        if output_hidden_states:
            cross_attention_hidden_states = (cross_attention_output, cross_output_states)
            all_hidden_states = all_hidden_states + self_attention_hidden_states + cross_attention_hidden_states
        if output_attentions:
            all_attentions = all_attentions + (self_attentions,) + (cross_attentions,)

        return descriptors, all_hidden_states, all_attentions


def sigmoid_log_double_softmax(
    similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity"""
    batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape
    certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
    scores0 = nn.functional.log_softmax(similarity, 2)
    scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0)
    scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties
    scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))
    scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))
    return scores


class LightGlueMatchAssignmentLayer(nn.Module):
    def __init__(self, config: LightGlueConfig):
        super().__init__()

        self.descriptor_dim = config.descriptor_dim
        self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
        self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)

    def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        batch_size, num_keypoints, descriptor_dim = descriptors.shape

        # Project the descriptors and compute the pairwise similarity between the two images of each pair
        m_descriptors = self.final_projection(descriptors)
        m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25
        m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim)
        m_descriptors0 = m_descriptors[:, 0]
        m_descriptors1 = m_descriptors[:, 1]
        similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2)
        if mask is not None:
            mask = mask.reshape(batch_size // 2, 2, num_keypoints)
            mask0 = mask[:, 0].unsqueeze(-1)
            mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2)
            mask = mask0 * mask1
            similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min)

        matchability = self.matchability(descriptors)
        matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1)
        matchability_0 = matchability[:, 0]
        matchability_1 = matchability[:, 1]
        scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1)
        return scores

    def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
        """Get matchability of descriptors as a probability"""
        matchability = nn.functional.sigmoid(self.matchability(descriptors)).squeeze(-1)
        return matchability


class LightGlueTokenConfidenceLayer(nn.Module):
    def __init__(self, config: LightGlueConfig):
        super().__init__()

        self.token = nn.Linear(config.descriptor_dim, 1)

    def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
        token = self.token(descriptors.detach())
        token = nn.functional.sigmoid(token).squeeze(-1)
        return token


class LightGluePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config: LightGlueConfig
    base_model_prefix = "lightglue"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False
    _supports_flash_attn = True
    _supports_sdpa = True


def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
    """obtain matches from a score matrix [Bx M+1 x N+1]"""
    batch_size, _, _ = scores.shape
    # For each keypoint, get the best match in the other image (ignoring the dustbin row/column)
    max0 = scores[:, :-1, :-1].max(2)
    max1 = scores[:, :-1, :-1].max(1)
    matches0 = max0.indices
    matches1 = max1.indices

    # Keep only mutual (cycle-consistent) nearest neighbors
    indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
    indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
    mutual0 = indices0 == matches1.gather(1, matches0)
    mutual1 = indices1 == matches0.gather(1, matches1)

    # Convert log-scores to probabilities and zero out non-mutual matches
    max0 = max0.values.exp()
    zero = max0.new_tensor(0)
    matching_scores0 = torch.where(mutual0, max0, zero)
    matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)

    # Discard matches whose confidence falls below the threshold; -1 marks "no match"
    valid0 = mutual0 & (matching_scores0 > threshold)
    valid1 = mutual1 & valid0.gather(1, matches1)
    matches0 = torch.where(valid0, matches0, -1)
    matches1 = torch.where(valid1, matches1, -1)
    matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
    matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)

    return matches, matching_scores


def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Normalize keypoints locations based on the image shape.

    Args:
        keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
            Keypoints locations in (x, y) format.
        height (`int`):
            Image height.
        width (`int`):
            Image width.

    Returns:
        Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
    """
    size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
    shift = size / 2
    scale = size.max(-1).values / 2
    keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
    return keypoints


@auto_docstring(
    custom_intro="""
    LightGlue model taking images as inputs and outputting the matching of them.
    """
)
class LightGlueForKeypointMatching(LightGluePreTrainedModel):
    """
    LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
    SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
    It consists of :
        1. Keypoint Encoder
        2. A Graph Neural Network with self and cross attention layers
        3. Matching Assignment layers

    The correspondence ids use -1 to indicate non-matching points.

    Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
    In ICCV 2023. https://huggingface.co/papers/2306.13643
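
    Example (a minimal usage sketch; the checkpoint name and image URLs are illustrative
    assumptions — any LightGlue checkpoint on the Hub should behave the same way):

    ```python
    >>> import torch
    >>> import requests
    >>> from PIL import Image
    >>> from transformers import AutoImageProcessor, LightGlueForKeypointMatching

    >>> url_image_0 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
    >>> image_0 = Image.open(requests.get(url_image_0, stream=True).raw)
    >>> url_image_1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
    >>> image_1 = Image.open(requests.get(url_image_1, stream=True).raw)

    >>> processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
    >>> model = LightGlueForKeypointMatching.from_pretrained("ETH-CVG/lightglue_superpoint")

    >>> inputs = processor([image_0, image_1], return_tensors="pt")
    >>> with torch.no_grad():
    ...     outputs = model(**inputs)
    ```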
    """

    def __init__(self, config: LightGlueConfig):
        super().__init__(config)

        self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)

        self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
        self.descriptor_dim = config.descriptor_dim
        self.num_layers = config.num_hidden_layers
        self.filter_threshold = config.filter_threshold
        self.depth_confidence = config.depth_confidence
        self.width_confidence = config.width_confidence

        if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
            self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
        else:
            self.input_projection = nn.Identity()

        self.positional_encoder = LightGluePositionalEncoder(config)

        self.transformer_layers = nn.ModuleList(
            [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        self.match_assignment_layers = nn.ModuleList(
            [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.token_confidence = nn.ModuleList(
            [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
        )

        self.post_init()

    def _get_confidence_threshold(self, layer_index: int) -> float:
        """scaled confidence threshold for a given layer"""
        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
        return np.clip(threshold, 0, 1)

    def _keypoint_processing(
        self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        descriptors = descriptors.detach().contiguous()
        projected_descriptors = self.input_projection(descriptors)
        keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
        return projected_descriptors, keypoint_encoding_output

    def _get_early_stopped_image_pairs(
        self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
    ) -> torch.Tensor:
        """evaluate whether we should stop inference based on the confidence of the keypoints"""
        batch_size, _ = keypoint_confidences.shape
        if layer_index < self.num_layers - 1:
            # Padded keypoints are given confidence 1 so they never count as "unconfident"
            keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
            keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
            threshold = self._get_confidence_threshold(layer_index)
            ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=-1) / num_points
            early_stopped_pairs = ratio_confident > self.depth_confidence
        else:
            # Always stop after the last layer
            early_stopped_pairs = torch.ones(batch_size // 2, dtype=torch.bool)
        return early_stopped_pairs

    def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
        if early_stops is not None:
            descriptors = descriptors[early_stops]
            mask = mask[early_stops]
        scores = self.match_assignment_layers[layer_index](descriptors, mask)
        matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
        return matches, matching_scores

    def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
        """mask points which should be removed"""
        keep = scores > (1 - self.width_confidence)
        if confidences is not None:
            keep |= confidences <= self._get_confidence_threshold(layer_index)
        return keep

    def _do_layer_keypoint_pruning(
        self,
        descriptors: torch.Tensor,
        keypoints: torch.Tensor,
        mask: torch.Tensor,
        indices: torch.Tensor,
        prune_output: torch.Tensor,
        keypoint_confidences: torch.Tensor,
        layer_index: int,
    ):
        """
        For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
        descriptors.
        """
        batch_size, _, _ = descriptors.shape
        descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
        pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
        pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, False)

        pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
            [tensor[m] for tensor, m in zip(tensors, pruned_keypoints_mask)]
            for tensors in [descriptors, keypoints[0], keypoints[1], mask, indices]
        )
        # Count one more surviving pruning iteration for every keypoint that is kept
        for i in range(batch_size):
            prune_output[i, pruned_indices[i]] += 1

        pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_indices = (
            pad_sequence(pruned_tensor, batch_first=True)
            for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_indices]
        )
        pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
        pruned_mask = pad_sequence(pruned_mask, batch_first=True, padding_value=-1)
        return pruned_descriptors, pruned_keypoints, pruned_mask, pruned_indices, prune_output

    def _concat_early_stopped_outputs(
        self,
        early_stops_indices,
        matches,
        matching_scores,
        final_pruned_keypoints_indices,
        final_pruned_keypoints_iterations,
    ):
        # Pairs stop at different layers, so restore the original image order before batching
        early_stops_indices = torch.stack(early_stops_indices).argsort()
        matches, final_pruned_keypoints_indices = (
            pad_sequence(tensor, batch_first=True, padding_value=-1)
            for tensor in [matches, final_pruned_keypoints_indices]
        )
        matching_scores, final_pruned_keypoints_iterations = (
            pad_sequence(tensor, batch_first=True, padding_value=0)
            for tensor in [matching_scores, final_pruned_keypoints_iterations]
        )
        matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
            tensor[early_stops_indices]
            for tensor in [matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations]
        )
        return matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations

    def _do_final_keypoint_pruning(
        self,
        indices: torch.Tensor,
        matches: torch.Tensor,
        matching_scores: torch.Tensor,
        num_keypoints: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # (batch_size * 2, num_keypoints) -> (batch_size, 2, num_keypoints)
        batch_size, _ = indices.shape
        indices, matches, matching_scores = (
            tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
        )
        indices0 = indices[:, 0]
        indices1 = indices[:, 1]
        matches0 = matches[:, 0]
        matches1 = matches[:, 1]
        matching_scores0 = matching_scores[:, 0]
        matching_scores1 = matching_scores[:, 1]

        # Scatter the matches computed on the pruned keypoints back into full-size tensors
        _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
        _matching_scores = torch.zeros(
            (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
        )
        for i in range(batch_size // 2):
            _matches[i, 0, indices0[i]] = torch.where(
                matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
            )
            _matches[i, 1, indices1[i]] = torch.where(
                matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
            )
            _matching_scores[i, 0, indices0[i]] = matching_scores0[i]
            _matching_scores[i, 1, indices1[i]] = matching_scores1[i]
        return _matches, _matching_scores

    def _match_image_pair(
        self,
        keypoints: torch.Tensor,
        descriptors: torch.Tensor,
        height: int,
        width: int,
        mask: torch.Tensor = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if keypoints.shape[2] == 0:  # no keypoints detected
            shape = keypoints.shape[:-1]
            return (
                keypoints.new_full(shape, -1, dtype=torch.int),
                keypoints.new_zeros(shape),
                keypoints.new_zeros(shape),
                all_hidden_states,
                all_attentions,
            )

        device = keypoints.device
        batch_size, _, initial_num_keypoints, _ = keypoints.shape
        num_points_per_pair = torch.sum(mask, dim=(1, 2))

        # Flatten the image-pair dimension: (batch_size, 2, ...) -> (batch_size * 2, ...)
        keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
        descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
        mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
        image_indices = torch.arange(batch_size * 2, device=device)

        keypoints = normalize_keypoints(keypoints, height, width)
        descriptors, keypoint_encoding_output = self._keypoint_processing(
            descriptors, keypoints, output_hidden_states=output_hidden_states
        )
        keypoint_embeddings = keypoint_encoding_output[0]

        do_early_stop = self.depth_confidence > 0
        do_keypoint_pruning = self.width_confidence > 0

        early_stops_indices = []
        matches = []
        matching_scores = []
        final_pruned_keypoints_indices = []
        final_pruned_keypoints_iterations = []

        pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
        pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)
        keypoint_confidences = None

        for layer_index in range(self.num_layers):
            num_keypoints = descriptors.shape[1]
            if mask is not None:
                extended_attention_mask = self.get_extended_attention_mask(mask, (batch_size * 2, num_keypoints))
            else:
                extended_attention_mask = torch.ones((batch_size * 2, num_keypoints), device=descriptors.device)

            layer_output = self.transformer_layers[layer_index](
                descriptors,
                keypoint_embeddings,
                attention_mask=extended_attention_mask,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
            )
            descriptors, hidden_states, attention = layer_output
            if output_hidden_states:
                all_hidden_states = all_hidden_states + hidden_states
            if output_attentions:
                all_attentions = all_attentions + attention

            if do_early_stop:
                if layer_index < self.num_layers - 1:
                    keypoint_confidences = self.token_confidence[layer_index](descriptors)
                    early_stopped_pairs = self._get_early_stopped_image_pairs(
                        keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
                    )
                else:
                    # Matching is always performed at the last layer for the remaining pairs
                    early_stopped_pairs = torch.ones(num_points_per_pair.shape[0], dtype=torch.bool)

                if torch.any(early_stopped_pairs):
                    # Match the keypoints of the early-stopped pairs and remove them from the forward pass
                    early_stops = early_stopped_pairs.repeat_interleave(2)
                    early_stopped_image_indices = image_indices[early_stops]
                    early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
                        descriptors, mask, layer_index, early_stops=early_stops
                    )
                    early_stops_indices.extend(list(early_stopped_image_indices))
                    matches.extend(list(early_stopped_matches))
                    matching_scores.extend(list(early_stopped_matching_scores))
                    if do_keypoint_pruning:
                        final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
                        final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))

                    descriptors, keypoints_0, keypoints_1, mask, image_indices = tuple(
                        tensor[~early_stops]
                        for tensor in [descriptors, keypoint_embeddings[0], keypoint_embeddings[1], mask, image_indices]
                    )
                    keypoint_embeddings = (keypoints_0, keypoints_1)
                    if do_keypoint_pruning:
                        pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
                            tensor[~early_stops]
                            for tensor in [pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences]
                        )
                    num_points_per_pair = num_points_per_pair[~early_stopped_pairs]

                if torch.all(early_stopped_pairs):
                    break

            if do_keypoint_pruning:
                (
                    descriptors,
                    keypoint_embeddings,
                    mask,
                    pruned_keypoints_indices,
                    pruned_keypoints_iterations,
                ) = self._do_layer_keypoint_pruning(
                    descriptors,
                    keypoint_embeddings,
                    mask,
                    pruned_keypoints_indices,
                    pruned_keypoints_iterations,
                    keypoint_confidences,
                    layer_index,
                )

        if do_early_stop and do_keypoint_pruning:
            (
                matches,
                matching_scores,
                final_pruned_keypoints_indices,
                final_pruned_keypoints_iterations,
            ) = self._concat_early_stopped_outputs(
                early_stops_indices,
                matches,
                matching_scores,
                final_pruned_keypoints_indices,
                final_pruned_keypoints_iterations,
            )
            matches, matching_scores = self._do_final_keypoint_pruning(
                final_pruned_keypoints_indices,
                matches,
                matching_scores,
                initial_num_keypoints,
            )
            prune = final_pruned_keypoints_iterations
        else:
            matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
            prune = torch.ones_like(matching_scores) * self.num_layers

        prune = prune.reshape(batch_size, 2, -1)

        return matches, matching_scores, prune, all_hidden_states, all_attentions

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, LightGlueKeypointMatchingOutput]:
        loss = None
        if labels is not None:
            raise ValueError("LightGlue is not trainable, no labels should be provided.")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
            raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")

        batch_size, _, channels, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
        keypoint_detections = self.keypoint_detector(pixel_values)

        keypoints, _, descriptors, mask = keypoint_detections[:4]
        keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
        descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
        mask = mask.reshape(batch_size, 2, -1)

        # The detector returns relative coordinates; scale them back to pixel space for matching
        absolute_keypoints = keypoints.clone()
        absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
        absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height

        matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
            absolute_keypoints,
            descriptors,
            height,
            width,
            mask=mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        return LightGlueKeypointMatchingOutput(
            loss=loss,
            matches=matches,
            matching_scores=matching_scores,
            keypoints=keypoints,
            prune=prune,
            mask=mask,
            hidden_states=hidden_states,
            attentions=attentions,
        )


__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching"]