a
    hXr                    @   s8  d Z ddlZddlZddlmZ ddlmZ ddlmZmZm	Z	 ddl
Z
ddlZ
ddl
mZ ddlmZ dd	lmZ dd
lmZmZ ddlmZ ddlmZmZ ddlmZmZmZmZ ddlmZm Z m!Z!m"Z"m#Z# e$e%Z&dZ'dZ(dZ)e	e#e!e"f Z*eeddG dd deZ+eeddG dd deZ,eeddG dd deZ-G dd dej.Z/G dd dej.Z0G d d! d!ej.Z1G d"d# d#ej.Z2G d$d% d%ej.Z3G d&d' d'ej.Z4G d(d) d)ej.Z5G d*d+ d+ej.Z6G d,d- d-eZ7G d.d/ d/ej.Z8G d0d1 d1ej.Z9eG d2d3 d3eZ:eG d4d5 d5e:Z;eG d6d7 d7e:Z<eG d8d9 d9e:Z=eG d:d; d;e:Z>G d<d= d=ej.Z?G d>d? d?ej.Z@G d@dA dAej.ZAedBdG dCdD dDe:ZBG dEdF dFej.ZCG dGdH dHej.ZDG dIdJ dJej.ZEG dKdL dLej.ZFedMdG dNdO dOe:ZGg dPZHdS )QzPyTorch FLAVA model.    N)OrderedDict)	dataclass)AnyOptionalUnion)nn   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPooling)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputauto_docstringlogging	torch_int   )FlavaConfigFlavaImageCodebookConfigFlavaImageConfigFlavaMultimodalConfigFlavaTextConfigzfacebook/flava-image-codebookg$(~k@a  
    Output from FlavaModel containing embeddings and outputs from individual encoders.

    Note that `image_embeddings` and `text_embeddigns` returned are similar to pooled output returned from a
    transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
    `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
    )Zcustom_introc                   @   s   e Zd ZU dZdZeej ed< dZ	ee
 ed< dZeej ed< dZee
 ed< dZeej ed< dZee
 ed< ee d	d
dZdS )FlavaModelOutputa  
    image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
        The image embeddings which are basically the pooled output of [`FlavaImageModel`].
    image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
        The output of the [`FlavaImageModel`].
    text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
        The text embeddings which are basically the pooled output of [`FlavaTextModel`].
    text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
        The output of the [`FlavaTextModel`].
    multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
        The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
    multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
        The output of the [`FlavaMultimodalModel`].
    Nimage_embeddingsimage_outputtext_embeddingstext_outputmultimodal_embeddingsmultimodal_outputreturnc                    s   t  fdd  D S )Nc                 3   s,   | ]$}|d vr | nt  | V  qdS ))r   r   r    Ngetattrto_tuple.0kself d/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/flava/modeling_flava.py	<genexpr>V   s   z,FlavaModelOutput.to_tuple.<locals>.<genexpr>tuplekeysr)   r+   r)   r,   r%   U   s    zFlavaModelOutput.to_tuple)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r   r   r    r/   r   r%   r+   r+   r+   r,   r   4   s   
r   z@
    Class representing pretraining losses from FLAVA model
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeej ed< ed	d
dZdS )FlavaLossesa  
    mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.):
        Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
    mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.):
        Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
    itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.):
        Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
        masked pairs in FLAVA.
    global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.):
        Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
        data. This is calculated on unmasked images and texts.
    mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.):
        Masked Multimodal Modeling loss's image component calculated on paired image-text data.
    mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.):
        Masked Multimodal Modeling loss's text component calculated on paired image-text data.
    Nmimmlmitmglobal_contrastive	mmm_imagemmm_textr!   c                 C   s&   d}|   D ]}|d urd} q"q|S )NTF)values)r*   all_nonevr+   r+   r,   r@   {   s    zFlavaLosses.all_none)r1   r2   r3   r4   r9   r   r5   r6   r7   r:   r;   r<   r=   r>   boolr@   r+   r+   r+   r,   r8   \   s   
r8   a  
    Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders.

    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
    transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
    `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
    c                   @   s  e Zd ZU dZdZeej ed< dZ	e
ed< dZeej ed< dZee ed< dZeej ed< dZee ed< dZeej ed	< dZee ed
< dZeej ed< dZee ed< dZeej ed< dZee ed< dZeej ed< dZee ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< ee  dddZ!dS )FlavaForPreTrainingOutputay  
    loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True):
        Total loss calculated for this model.
    loss_info (`FlavaLosses`):
        Detailed info for FLAVA Pretraining losses. Check `FlavaLosses` class description for the information on
        the keys.
    image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
        The image embeddings which are basically the pooled output of [`FlavaImageModel`].
    image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
        The output of the [`FlavaImageModel`].
    text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
        The text embeddings which are basically the pooled output of [`FlavaTextModel`].
    text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
        The output of the [`FlavaTextModel`].
    multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
        The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
    multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
        The output of the [`FlavaMultimodalModel`].
    image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
        The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos`
        to create masked images.
    image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
        The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images.
    text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
        The text embeddings which are basically the pooled output of [`FlavaTextModel`].
    text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
        The output of the [`FlavaTextModel`].
    multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present):
        The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
    multimodal_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
        The output of the [`FlavaMultimodalModel`].
    mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)` , *optional*, returned when `pixel_values` are present and `input_ids_masked` are not):
        The logits for MIM unimodal loss. Uses `book_masked_pos` to get masked patches. The flattened output is
            returned when `bool_masked_pos` has some of the patches masked.
    mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not):
        The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of
            the tokens masked.
    itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
        The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA.
    contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
        `image_projection` and `text_projection` layers respectively. This represents the image-text similarity
        scores. This is calculated on unmasked images and texts.
    contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
        `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
        texts.
    mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape`(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
        The logits for MMM image multimodal loss. Uses `book_masked_pos` to get masked patches. The flattened
            output is returned when `bool_masked_pos` has some of the patches masked.
    mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(`(total_masked_seq_length, text_vocab_size)`), *optional*, returned when `pixel_values` and `input_ids_masked` are present):
        The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has
            some of the tokens masked.
    Nloss	loss_infor   r   r   r   r   r    image_masked_embeddingsimage_masked_outputtext_masked_embeddingstext_masked_outputmultimodal_masked_embeddingsmultimodal_masked_output
mim_logits
mlm_logits
itm_logitscontrastive_logits_per_imagecontrastive_logits_per_textmmm_image_logitsmmm_text_logitsr!   c                    s$   g dt  fdd  D S )N)r   r   r    rI   rG   rK   c                 3   s,   | ]$}|vr | nt  | V  qd S Nr#   r&   r*   Ztransformer_outputsr+   r,   r-          z5FlavaForPreTrainingOutput.to_tuple.<locals>.<genexpr>r.   r)   r+   rT   r,   r%      s    z"FlavaForPreTrainingOutput.to_tuple)"r1   r2   r3   r4   rD   r   r5   r6   r7   rE   r8   r   r   r   r   r   r   r    rF   rG   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   r/   r   r%   r+   r+   r+   r,   rC      s.   
7rC   c                       sd   e Zd ZdZdeedd fddZeje	e	ejddd	Z
dejeej eejd
ddZ  ZS )FlavaImageEmbeddingszb
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    FN)configuse_mask_tokenr"   c                    s   t    |p|j}ttdd|j| _|rFttdd|jnd | _t	|j
|j|j|jd| _| jj}ttd|d |j| _t|j| _|j| _|| _d S )Nr   
image_size
patch_sizenum_channels	embed_dim)super__init__
mask_tokenr   	Parameterr5   zeroshidden_size	cls_tokenPatchEmbeddingsrZ   r[   r\   patch_embeddingsnum_patchesposition_embeddingsDropouthidden_dropout_probdropoutrW   )r*   rW   rX   rg   	__class__r+   r,   r_      s    

 zFlavaImageEmbeddings.__init__)
embeddingsheightwidthr"   c                 C   s   |j d d }| jj d d }tj s>||kr>||kr>| jS | jddddf }| jddddf }|j d }|| j }	|| j }
t|d }|d|||}|dddd}t	j
j||	|
fdd	d
}|dddddd|}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   Ng      ?r   r      ZbicubicF)sizemodeZalign_cornersdim)shaperh   r5   Zjit
is_tracingr[   r   Zreshapepermuter   
functionalZinterpolateviewcat)r*   rn   ro   rp   rg   Znum_positionsZclass_pos_embedZpatch_pos_embedrv   Z
new_heightZ	new_widthZsqrt_num_positionsr+   r+   r,   interpolate_pos_encoding  s(    



z-FlavaImageEmbeddings.interpolate_pos_encoding)pixel_valuesbool_masked_posr}   r"   c                 C   s   |j \}}}}| j||d}| \}}	}
|d ur| j||	d}| dkr`||dd}|d|}|d|  ||  }| j	|dd}t
j||fdd}|r|| ||| }n
|| j }| |}|S )N)r}   rq   r   r         ?r   ru   )rw   rf   rs   r`   expandrv   r{   Z	unsqueezeZtype_asrd   r5   r|   r}   rh   rk   )r*   r~   r   r}   
batch_sizer\   ro   rp   rn   seq_len_Zmask_tokensmask
cls_tokensr+   r+   r,   forward*  s     

zFlavaImageEmbeddings.forward)F)NF)r1   r2   r3   r4   r   rB   r_   r5   Tensorintr}   r   
BoolTensorr   __classcell__r+   r+   rl   r,   rV      s   +  rV   c                       sV   e Zd ZdZdeeeeeef f eed fddZdej	e
ej	d
ddZ  ZS )re   z#
    Image to Patch Embedding.
          r      rY   c                    s   t    t|tjjs ||f}t|tjjs6||f}|d |d  |d |d   }|| _|| _|| _t	j
||||d| _d S )Nr   r   )kernel_sizeZstride)r^   r_   
isinstancecollectionsabcIterablerZ   r[   rg   r   Conv2d
projection)r*   rZ   r[   r\   r]   rg   rl   r+   r,   r_   S  s    
 zPatchEmbeddings.__init__F)r~   r}   r"   c              
   C   sx   |j \}}}}|s\|| jd ks.|| jd kr\td| d| d| jd  d| jd  d	| |ddd}|S )Nr   r   zInput image size (*z) doesn't match model (z).rr   )rw   rZ   
ValueErrorr   flatten	transpose)r*   r~   r}   r   r\   ro   rp   xr+   r+   r,   r   f  s    zPatchEmbeddings.forward)r   r   r   r   )F)r1   r2   r3   r4   r   r   r/   r_   r5   r   rB   r   r   r+   r+   rl   r,   re   N  s       re   c                       sF   e Zd ZdZ fddZdeej eej eej dddZ  Z	S )	FlavaTextEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _t|dd| _| jdt|jddd | jd	tj| j tjd
dd d S )N)padding_idxepsposition_embedding_typeabsoluteposition_ids)r   rq   F)
persistenttoken_type_ids)dtype)r^   r_   r   	Embedding
vocab_sizerc   Zpad_token_idword_embeddingsZmax_position_embeddingsrh   Ztype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsri   rj   rk   r$   r   Zregister_bufferr5   aranger   rb   r   rs   longr*   rW   rl   r+   r,   r_   u  s    
zFlavaTextEmbeddings.__init__N	input_idsr   r   c                 C   s   |  }|d }|d u r.| jd d d |f }|d u rt| drl| jd d d |f }||d |}|}ntj|tj| jjd}| 	|}| 
|}	||	 }
| jdkr| |}|
|7 }
| |
}
| |
}
|
S )Nr   r   r   )r   devicer   )rs   r   hasattrr   r   r5   rb   r   r   r   r   r   rh   r   rk   )r*   r   r   r   input_shape
seq_lengthZbuffered_token_type_idsZ buffered_token_type_ids_expandedZinputs_embedsr   rn   rh   r+   r+   r,   r     s&    






zFlavaTextEmbeddings.forward)NNN)
r1   r2   r3   r4   r_   r   r5   r   r   r   r+   r+   rl   r,   r   r  s      r   c                	       sd   e Zd Zedd fddZd	ejeej eej ee	e
ejejf e
ej f dddZ  ZS )
FlavaSelfAttentionNrW   r"   c                    s   t    |j|j dkr>t|ds>td|j d|j d|j| _t|j|j | _| j| j | _t	j
|j| j|jd| _t	j
|j| j|jd| _t	j
|j| j|jd| _t	|j| _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .bias)r^   r_   rc   num_attention_headsr   r   r   attention_head_sizeall_head_sizer   LinearZqkv_biasquerykeyvalueri   Zattention_probs_dropout_probrk   r   rl   r+   r,   r_     s    

zFlavaSelfAttention.__init__Fhidden_statesattention_mask	head_maskoutput_attentionsr"   c                 C   s*  |j \}}}| ||d| j| jdd}| ||d| j| jdd}	| ||d| j| jdd}
t	||	dd}|t
| j }|d ur|| }tjj|dd}| |}|d ur|| }t	||
}|dddd }| d d | jf }|j| }|r ||fn|f}|S )Nrq   r   rr   ru   r   r   )rw   r   r{   r   r   r   r   r   r5   matmulmathsqrtr   rz   Zsoftmaxrk   ry   
contiguousrs   r   )r*   r   r   r   r   r   r   r   Zquery_layerZ	key_layerZvalue_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr+   r+   r,   r     sB    




zFlavaSelfAttention.forward)NNF)r1   r2   r3   FlavaPossibleConfigsr_   r5   r   r   rB   r   r/   r   r   r+   r+   rl   r,   r     s      r   c                       s@   e Zd ZdZedd fddZejejejdddZ  Z	S )	FlavaSelfOutputz
    The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
    models), due to the layernorm applied before each block.
    Nr   c                    s.   t    t|j|j| _t|j| _d S rS   )	r^   r_   r   r   rc   denseri   rj   rk   r   rl   r+   r,   r_     s    
zFlavaSelfOutput.__init__r   input_tensorr"   c                 C   s   |  |}| |}|S rS   r   rk   r*   r   r   r+   r+   r,   r     s    

zFlavaSelfOutput.forward)
r1   r2   r3   r4   r   r_   r5   r   r   r   r+   r+   rl   r,   r     s   r   c                	       sx   e Zd Zedd fddZee ddddZdej	e
ej	 e
ej	 eeeej	ej	f eej	 f d	d
dZ  ZS )FlavaAttentionNr   c                    s*   t    t|| _t|| _t | _d S rS   )r^   r_   r   	attentionr   outputsetpruned_headsr   rl   r+   r,   r_     s    


zFlavaAttention.__init__)headsr"   c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   ru   )lenr   r   r   r   r   r   r   r   r   r   r   r   union)r*   r   indexr+   r+   r,   prune_heads  s    zFlavaAttention.prune_headsFr   c                 C   s8   | j ||||d}| |d |}|f|dd   }|S N)r   r   r   r   r   )r   r   )r*   r   r   r   r   Zself_outputsattention_outputr   r+   r+   r,   r     s    zFlavaAttention.forward)NNF)r1   r2   r3   r   r_   r   r   r   r5   r   r   rB   r   r/   r   r   r+   r+   rl   r,   r     s      r   c                       s8   e Zd Zedd fddZejejdddZ  ZS )FlavaIntermediateNr   c                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S rS   )r^   r_   r   r   rc   intermediate_sizer   r   
hidden_actstrr	   intermediate_act_fnr   rl   r+   r,   r_   1  s
    
zFlavaIntermediate.__init__)r   r"   c                 C   s   |  |}| |}|S rS   )r   r   r*   r   r+   r+   r,   r   :  s    

zFlavaIntermediate.forward	r1   r2   r3   r   r_   r5   r   r   r   r+   r+   rl   r,   r   0  s   	r   c                       s<   e Zd Zedd fddZejejejdddZ  ZS )FlavaOutputNr   c                    s.   t    t|j|j| _t|j| _	d S rS   )
r^   r_   r   r   r   rc   r   ri   rj   rk   r   rl   r+   r,   r_   B  s    
zFlavaOutput.__init__r   c                 C   s    |  |}| |}|| }|S rS   r   r   r+   r+   r,   r   H  s    

zFlavaOutput.forwardr   r+   r+   rl   r,   r   A  s   r   c                	       sh   e Zd ZdZedd fddZd
ejeej eej e	e
eejejf eej f ddd	Z  ZS )
FlavaLayerz?This corresponds to the Block class in the timm implementation.Nr   c                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S )Nr   r   )r^   r_   Zchunk_size_feed_forwardZseq_len_dimr   r   r   intermediater   r   r   r   rc   r   layernorm_beforelayernorm_afterr   rl   r+   r,   r_   T  s    



zFlavaLayer.__init__Fr   c           	      C   sb   | j | ||||d}|d }|dd  }|| }| |}| |}| ||}|f| }|S r   )r   r   r   r   r   )	r*   r   r   r   r   Zself_attention_outputsr   r   Zlayer_outputr+   r+   r,   r   `  s    


zFlavaLayer.forward)NNF)r1   r2   r3   r4   r   r_   r5   r   r   rB   r   r/   r   r   r+   r+   rl   r,   r   Q  s      r   c                
       sV   e Zd Zedd fddZd
ejeej eej eeee	e
ef ddd	Z  ZS )FlavaEncoderNr   c                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r+   )r   r'   r   rW   r+   r,   
<listcomp>  rU   z)FlavaEncoder.__init__.<locals>.<listcomp>F)	r^   r_   rW   r   Z
ModuleListrangenum_hidden_layerslayerZgradient_checkpointingr   rl   r   r,   r_     s    
 zFlavaEncoder.__init__FTr   r   r   r   output_hidden_statesreturn_dictr"   c                 C   s   |rdnd }|rdnd }t | jD ]R\}	}
|r8||f }|d urH||	 nd }|
||||}|d }|r"||d f }q"|r||f }|stdd |||fD S t|||dS )Nr+   r   r   c                 s   s   | ]}|d ur|V  qd S rS   r+   )r'   rA   r+   r+   r,   r-     rU   z'FlavaEncoder.forward.<locals>.<genexpr>)last_hidden_stater   
attentions)	enumerater   r/   r   )r*   r   r   r   r   r   r   Zall_hidden_statesZall_self_attentionsiZlayer_moduleZlayer_head_maskZlayer_outputsr+   r+   r,   r     s"    	

zFlavaEncoder.forward)NNFFT)r1   r2   r3   r   r_   r5   r   r   rB   r   r/   r   r   r   r+   r+   rl   r,   r     s   	     
r   c                       s2   e Zd Zed fddZejdddZ  ZS )FlavaPoolerr   c                    s*   t    t|j|j| _t | _d S rS   )r^   r_   r   r   rc   r   ZTanh
activationr   rl   r+   r,   r_     s    
zFlavaPooler.__init__)r   c                 C   s(   |d d df }|  |}| |}|S Nr   )r   r   )r*   r   Zfirst_token_tensorpooled_outputr+   r+   r,   r     s    

zFlavaPooler.forwardr   r+   r+   rl   r,   r     s   r   c                   @   s>   e Zd ZU eed< dZdZeej	ej
ejf ddddZdS )FlavaPreTrainedModelrW   flavaTN)moduler"   c                 C   s>  t |tjtjfr@|jjjd| jjd |j	dur>|j	j
  nt |tjr|jjjd| jjd |jdur~|jj|j 
  nt |tjr|j	j
  |jjd nt |tr|j	j
  nzt |tr|jj
  |jj
  |jdur|jj
  n@t |tr|jr:|jj
  nt |tr:|jj| jj dS )zInitialize the weightsg        )meanZstdNr   )r   r   r   r   weightdataZnormal_rW   Zinitializer_ranger   Zzero_r   r   r   Zfill_FlavaMaskedPredictionHeadrV   rd   rh   r`   FlavaMultimodalModeluse_cls_token
FlavaModellogit_scalelogit_scale_init_value)r*   r   r+   r+   r,   _init_weights  s.    




z"FlavaPreTrainedModel._init_weights)r1   r2   r3   r   r7   base_model_prefixsupports_gradient_checkpointingr   r   r   r   r   r  r+   r+   r+   r,   r     s   
r   c                       s   e Zd ZU eed< dZdZdeed fddZe	j
dd	d
Ze	j
dddZeeee f ddddZedeej eej ee eej eej ee ee ee eeef d	ddZ  ZS )FlavaImageModelrW   zflava.image_modelr~   TrW   add_pooling_layerc                    sX   t  | || _t|| _t|| _tj|j	|j
d| _|rFt|nd| _|   dS v
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        r   N)r^   r_   rW   rV   rn   r   encoderr   r   rc   r   	layernormr   pooler	post_initr*   rW   r  rl   r+   r,   r_     s    

zFlavaImageModel.__init__r!   c                 C   s   | j jS rS   rn   rf   r)   r+   r+   r,   get_input_embeddings  s    z$FlavaImageModel.get_input_embeddingsr   c                 C   s   || j _d S rS   r  r*   r   r+   r+   r,   set_input_embeddings  s    z$FlavaImageModel.set_input_embeddingsNheads_to_pruner"   c                 C   s*   |  D ]\}}| jj| j| qdS z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        Nitemsr  r   r   r   r*   r  r   r   r+   r+   r,   _prune_heads  s    zFlavaImageModel._prune_heads	r~   r   r}   r   r   r   r   r   r"   c	                 C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}|du rLtd| || j j}| j|||d}	| j|	|||||d}
|
d }| 	|}| j
dur| 
|nd}|s||f|
dd  S t|||
j|
jdS )z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)r   r}   r   r   r   r   r   r   r   r   Zpooler_outputr   r   )rW   r   r   use_return_dictr   get_head_maskr   rn   r  r  r  r   r   r   )r*   r~   r   r}   r   r   r   r   r   embedding_outputencoder_outputssequence_outputr   r+   r+   r,   r     s:    
zFlavaImageModel.forward)T)NNNNNNNN)r1   r2   r3   r   r7   r	  main_input_namerB   r_   r   Moduler  r  dictr   listr   r   r   r5   r   r   r   r/   r   r   r   r+   r+   rl   r,   r    s6   
        
r  c                       s   e Zd ZU eed< dZdeed fddZeddd	Z	e
jd
ddZeeee f ddddZedeej eej eej eej eej ee ee ee eeef d	ddZ  ZS )FlavaTextModelrW   zflava.text_modelTr  c                    sX   t  | || _t|| _t|| _tj|j	|j
d| _|rFt|nd| _|   dS r  )r^   r_   rW   r   rn   r   r  r   r   rc   r   r  r   r  r  r  rl   r+   r,   r_   B  s    

zFlavaTextModel.__init__r!   c                 C   s   | j jS rS   rn   r   r)   r+   r+   r,   r  R  s    z#FlavaTextModel.get_input_embeddingsr  c                 C   s   || j _d S rS   r.  r  r+   r+   r,   r  U  s    z#FlavaTextModel.set_input_embeddingsNr  c                 C   s*   |  D ]\}}| jj| j| qdS r  r  r  r+   r+   r,   r   X  s    zFlavaTextModel._prune_heads)	r   r   r   r   r   r   r   r   r"   c	                 C   s  |dur|n| j j}|dur |n| j j}|dur4|n| j j}|du rLtd| }	|du rltj|	|jd}| 	|| j j
}| ||	|j}
| j|||d}| j||
||||d}|d }| |}| jdur| |nd}|s||f|dd  S t|||j|jdS )	a  
        input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
        NzYou have to specify input_idsr   r   r"  r   r   r#  )rW   r   r   r$  r   rs   r5   onesr   r%  r   get_extended_attention_maskrn   r  r  r  r   r   r   )r*   r   r   r   r   r   r   r   r   r   extended_attention_maskr&  r'  r(  r   r+   r+   r,   r   `  sJ    
zFlavaTextModel.forward)T)NNNNNNNN)r1   r2   r3   r   r7   r	  rB   r_   re   r  r   r*  r  r+  r   r,  r   r   r   r5   r   r   r/   r   r   r   r+   r+   rl   r,   r-  <  s4   
        
r-  c                       s   e Zd ZU eed< dZdZded fddZee	e
e	 f dd	d
dZedejeej eej ee ee ee eeef dddZ  ZS )r  rW   zflava.multimodal_modelr   Tr   c                    sv   t  | || _| jj| _| jr:ttdd|j| _	t
|| _tj|j|jd| _|rdt|nd| _|   dS )r  r   r   N)r^   r_   rW   r  r   ra   r5   rb   rc   rd   r   r  r   r   r  r   r  r  r  rl   r+   r,   r_     s    

zFlavaMultimodalModel.__init__Nr  c                 C   s*   |  D ]\}}| jj| j| qdS r  r  r  r+   r+   r,   r     s    z!FlavaMultimodalModel._prune_headsr   c                 C   s(  |dur|n| j j}|dur |n| j j}|dur4|n| j j}| \}}}	| jrz| j|dd}
tj	|
|fdd}|d7 }|du rtj
||f|jd}| || j j}| |||f|j}| j||||||d}|d }| |}| jdur| |nd}|s||f|dd  S t|||j|jdS )	z
        hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
            The concatenated hidden states of unimodal encoders.
        Nrq   r   ru   r/  r"  r   r#  )rW   r   r   r$  rs   r  rd   r   r5   r|   r0  r   r%  r   r1  r  r  r  r   r   r   )r*   r   r   r   r   r   r   r   r   r   r   r2  r'  r(  r   r+   r+   r,   r     sD    
zFlavaMultimodalModel.forward)T)NNNNN)r1   r2   r3   r   r7   r	  r)  r_   r+  r   r,  r   r   r5   r   r   rB   r   r/   r   r   r   r+   r+   rl   r,   r    s(   
     
r  c                       s*  e Zd ZU eed< ed fddZedeej	 eej	 eej	 eej	 ee
 ee
 ee
 ejdddZedeej	 eej ee
 eej	 eej	 ee
 ee
 ee
 ejd		d
dZedeej eej eej	 eej	 eej	 eej eej	 ee
 ee
 e
ee
 eeef dddZ  ZS )r  rW   r   c                    s0  t  | t|jts.tdt|j dt|jtsPtdt|j dt|j	t
svtddt|j	 d |j}|j}|j	}|j| _|j| _|j| _|j| _t|| _t|| _t|| _t| j| j| _t| j| j| _tt| jj| _t| j| j| _ t| j| j| _!| "  d S )NzLconfig.text_config is expected to be of type FlavaTextConfig but is of type r   zNconfig.image_config is expected to be of type FlavaImageConfig but is of type zMconfig.multimodal_config is expected to be of type FlavaMultimodalConfig but zis of type )#r^   r_   r   text_configr   	TypeErrortypeimage_configr   multimodal_configr   Zprojection_dimrc   Ztext_hidden_sizeZimage_hidden_sizeZmm_hidden_sizer-  
text_modelr  image_modelr  multimodal_modelr   r   image_projectiontext_projectionra   r5   ZtensorrW   r  r  image_to_mm_projectiontext_to_mm_projectionr  )r*   rW   r3  r6  r7  rl   r+   r,   r_     sF    


zFlavaModel.__init__N)r   r   r   r   r   r   r   r"   c              	   C   s.   | j |||||||d}|d }	| |	}
|
S )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)

        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`FlavaTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoProcessor, FlavaModel

        >>> model = FlavaModel.from_pretrained("{0}")
        >>> processor = AutoProcessor.from_pretrained("{0}")

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
        ... )
        >>> text_features = model.get_text_features(**inputs)
        ```
        )r   r   r   r   r   r   r   r   )r8  r<  )r*   r   r   r   r   r   r   r   Ztext_outputsr   Ztext_featuresr+   r+   r,   get_text_features<  s    )

zFlavaModel.get_text_featuresr!  c	              
   C   s0   | j ||||||||d}	|	d }
| |
}|S )a  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`FlavaImageModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, FlavaModel

        >>> model = FlavaModel.from_pretrained("{0}")
        >>> processor = AutoProcessor.from_pretrained("{0}")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```
        )r~   r   r   r   r   r   r}   r   r   )r9  r;  )r*   r~   r   r}   r   r   r   r   r   Zimage_outputsr   Zimage_featuresr+   r+   r,   get_image_featurest  s    &
zFlavaModel.get_image_featuresT)r   r~   r   r   r   r   image_attention_maskskip_multimodal_encoderr   r   r   r"   c              	   C   s  |dur|n| j j}|
s tdd}d}d}d}|durn| j||||	|
|d}|d |d  }}| |d }d}d}d}d}|dur| j|||||	|
|d}|d |d  }}| |d }d}d}|dur\|dur\|s\|dur.|j\}}}| jj	r|d7 }t
j|||jd	}t
j||gdd
}nd}t
j||gdd
}| j|||d}|d }|sr||||||fS t||||||dS )a	  
        input_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        image_attention_mask (`torch.Tensor` of shape `(batch_size, image_num_patches)`, *optional*):
            Mask to avoid performing attention on padding pixel values for image inputs. Mask values selected in `[0, 1]`:
            - 1 for pixel values that are real (i.e., **not masked**),
            - 0 for pixel values that are padding (i.e., **masked**).
        skip_multimodal_encoder (*bool*, *optional*):
            Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, FlavaModel

        >>> model = FlavaModel.from_pretrained("facebook/flava-full")
        >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)

        >>> outputs = model(**inputs)

        >>> image_embeddings = outputs.image_embeddings
        >>> text_embeddings = outputs.text_embeddings
        >>> multimodal_embeddings = outputs.multimodal_embeddings

        >>> outputs.image_embeddings.shape
        torch.Size([1, 197, 768])

        >>> text_embeddings.shape
        torch.Size([1, 7, 768])

        >>> multimodal_embeddings.shape
        torch.Size([1, 205, 768])
        ```
        NzRFLAVA model requires hidden states to work. Please set `output_hidden_states=True`)r~   r   r   r   r   r   r   rr   rq   )r   r   r   r   r   r   r   r   r/  ru   )r   r   )r   r   r   r   r   r    )rW   r   r   r9  r=  r8  r>  rw   r:  r  r5   r0  r   r|   r   )r*   r   r~   r   r   r   r   rA  rB  r   r   r   r   Zimage_statesZimage_mm_projectionr   r   Ztext_statesZtext_mm_projectionr   r   r    r   r   r   Zattention_mask_imageZattention_multimodalZmultimodal_inputr+   r+   r,   r     s    C


	zFlavaModel.forward)NNNNNNN)NNNNNNNN)NNNNNNNNNTN)r1   r2   r3   r   r7   r_   r   r   r5   r   rB   r6   r?  r   r@  
LongTensorr   r/   r   r   r   r+   r+   rl   r,   r    s   
+       7        5           
r  c                       s8   e Zd Zeed fddZejejdddZ  ZS )FlavaImageCodebookResPath)in_sizeout_sizec                    s   t    |d }t }t |d< tj||ddd|d< t |d< tj||ddd|d< t |d	< tj||ddd|d
< t |d< tj||ddd|d< t|| _d S )N   Zrelu_1r   r   r   paddingZconv_1Zrelu_2Zconv_2Zrelu_3Zconv_3Zrelu_4r   Zconv_4)r^   r_   r   r   ReLUr   
Sequentialpath)r*   rE  rF  kwargsZhid_sizerL  rl   r+   r,   r_   :  s    
z"FlavaImageCodebookResPath.__init__r   r"   c                 C   s
   |  |S rS   )rL  r*   r   r+   r+   r,   r   J  s    z!FlavaImageCodebookResPath.forward	r1   r2   r3   r   r_   r5   r   r   r   r+   r+   rl   r,   rD  9  s   rD  c                       s:   e Zd Zeeed fddZejejdddZ  ZS )FlavaImageCodebookBlock)rE  rF  
num_layersc                    sP   t    d|d  | _||kr6tj||ddd| _n
t | _t||| _d S )Nr   rr   r   rH  )	r^   r_   	post_gainr   r   id_pathZIdentityrD  res_path)r*   rE  rF  rR  rM  rl   r+   r,   r_   O  s    

z FlavaImageCodebookBlock.__init__rN  c                 C   s   |  || j| |  S rS   )rT  rS  rU  rO  r+   r+   r,   r   [  s    zFlavaImageCodebookBlock.forwardrP  r+   r+   rl   r,   rQ  N  s   rQ  c                       s@   e Zd Zdeeeeed fddZejejdddZ  Z	S )	FlavaImageCodebookLayerGroupT)
num_blocksrR  rE  rF  use_poolc                    s   t    t }t|D ]B}|dkr@t||||d|d  < qt||||d|d  < q|rptjdd|d< t|| _d S )Nr   Zblock_r   rr   )r   pool)	r^   r_   r   r   rQ  r   Z	MaxPool2drK  group)r*   rW  rR  rE  rF  rX  blocksr   rl   r+   r,   r_   `  s    
z%FlavaImageCodebookLayerGroup.__init__rN  c                 C   s
   |  |S rS   )rZ  rO  r+   r+   r,   r   n  s    z$FlavaImageCodebookLayerGroup.forward)T)
r1   r2   r3   r   rB   r_   r5   r   r   r   r+   r+   rl   r,   rV  _  s   rV  a"  
    The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used
    to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
    `get_codebook_indices` to get image tokens for an image.
    c                       sv   e Zd ZU dZeed< dZdZeed fddZ	e
je
jdd	d
Ze
je
jdddZe
je
jdddZ  ZS )FlavaImageCodebook rW   r~   F)rW   rM  c                    sd  t  | || _|j| _|j| _|j| _|j| _|j| _| j| j }t }t	
 |d< t	jd| j | jddd|d< t }t	j| jd| j ddd|d	< t| j|d| j d| j |d
< t| j|d| j d| j |d< t| j|d| j d| j |d< t| j|d| j d| j dd|d< t	||d< t	|| _|   | jjr`|  D ]}d|_qRd S )NZrelu   r   r   rH  conv   r   inputZgroup_1rr   Zgroup_2rG  Zgroup_3F)rX  Zgroup_4r   )r^   r_   rW   Z
num_groupsinput_channelsZnum_blocks_per_grouprc   r   r   r   rJ  r   rV  rK  r[  r  freeze
parametersZrequires_grad)r*   rW   rM  rR  Zoutput_blocksr[  paramrl   r+   r,   r_     s>    

zFlavaImageCodebook.__init__)r~   r"   c                 C   s*   dt  dt  d | |}tj|ddS )Na)  
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
                `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoImageProcessor, FlavaImageCodebook

        >>> model = FlavaImageCodebook.from_pretrained("E")
        >>> image_processor = AutoImageProcessor.from_pretrained("a  ")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
        >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)

        >>> outputs = model.get_codebook_indices(**inputs)
        ```
        r   )Zaxis)_CHECKPOINT_FOR_CODEBOOK_DOCr[  r5   Zargmaxr*   r~   Zz_logitsr+   r+   r,   get_codebook_indices  s    
z'FlavaImageCodebook.get_codebook_indicesc                 C   s   |  |}tjdd|S )Nr   ru   )r[  r   ZSoftmaxrh  r+   r+   r,   get_codebook_probs  s    
z%FlavaImageCodebook.get_codebook_probsc                 C   sh   dt  dt  d t|jdkr2td|j d|jd | jkr^td|jd  d	| j | |S )
Na*  
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
                `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoImageProcessor, FlavaImageCodebook

        >>> model = FlavaImageCodebook.from_pretrained("rf  a  ")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
        >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)

        >>> outputs = model(**inputs)
        >>> print(outputs.shape)
        (1, 196)
        ```
        rG  zinput shape z
 is not 4dr   z
input has z channels but model built for )rg  r   rw   r   rb  r[  )r*   r~   r+   r+   r,   r     s    zFlavaImageCodebook.forward)r1   r2   r3   r	  r   r7   r)  r
  r   r_   r5   r   ri  rj  r6   r   r   r+   r+   rl   r,   r\  s  s   
,r\  c                       s$   e Zd Z fddZdd Z  ZS )FlavaPredictionHeadTransformc                    sV   t    t|j|j| _t|jtr6t	|j | _
n|j| _
tj|j|jd| _d S )Nr   )r^   r_   r   r   rc   r   r   r   r   r	   transform_act_fnr   r   r   rl   r+   r,   r_     s    
z%FlavaPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S rS   )r   rl  r   r   r+   r+   r,   r     s    


z$FlavaPredictionHeadTransform.forwardr1   r2   r3   r_   r   r   r+   r+   rl   r,   rk    s   	rk  c                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	r  Nc                    sb   t    || _t|| _tj|j|jdd| _	t
t|j| _|d urT|| j	_| j| j	_d S )NFr   )r^   r_   rW   rk  	transformr   r   rc   r   decoderra   r5   rb   r   r   )r*   rW   r   rl   r+   r,   r_     s    

z"FlavaMaskedPredictionHead.__init__c                 C   s   | j | j_ d S rS   )r   ro  r)   r+   r+   r,   _tie_weights  s    z&FlavaMaskedPredictionHead._tie_weightsc                 C   s   |  |}| |}|S rS   )rn  ro  rO  r+   r+   r,   r     s    

z!FlavaMaskedPredictionHead.forward)N)r1   r2   r3   r_   rp  r   r   r+   r+   rl   r,   r     s   r  c                       s$   e Zd Z fddZdd Z  ZS )FlavaITMHeadc                    s.   t    || _t|| _t|jd| _d S )Nrr   )	r^   r_   rW   r   r  r   r   rc   seq_relationshipr   rl   r+   r,   r_     s    

zFlavaITMHead.__init__c                 C   s   |  |}| |}|S rS   )r  rr  rO  r+   r+   r,   r     s    

zFlavaITMHead.forwardrm  r+   r+   rl   r,   rq    s   rq  c                       s$   e Zd Z fddZdd Z  ZS )FlavaGlobalContrastiveHeadc                    s   t    || _|j| _d S rS   )r^   r_   rW   global_backprop_contrastiver   rl   r+   r,   r_   $  s    
z#FlavaGlobalContrastiveHead.__init__c                    s2  t |}t j rt j sBt j d jd} g}g}n d}t j }	| j	r~t jj
j }t jj
j}nHfddt|	D } fddt|	D }t j|  t j| |t j  t j| jd }t |}t |}t  |dd| }
t |dd| }|
||fS )Nr   r/  c                    s   g | ]}t  qS r+   r5   Z
zeros_liker   )r   r+   r,   r   9  rU   z6FlavaGlobalContrastiveHead.forward.<locals>.<listcomp>c                    s   g | ]}t  qS r+   ru  r   )r   r+   r,   r   :  rU   r   )r5   expdistributedZis_availableZis_initializedr   rs   r   Zget_world_sizert  r   rz   Z
all_gatherr   Zget_rankr|   r   r   )r*   r   r   r  ZtemperaturelabelsZimage_embeddings_allZtext_embeddings_allZlocal_batch_sizeZ
world_sizelogits_per_imagelogits_per_textr+   )r   r   r,   r   )  s,    




z"FlavaGlobalContrastiveHead.forwardrm  r+   r+   rl   r,   rs  #  s   rs  zk
    The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
    c                       s   e Zd Zg dZdeeej d fddZe	j
dddZedee	j ee	j ee	j ee	j ee	j
 ee	j
 ee	j
 ee	j ee	j
 ee ee	j
 ee	j
 ee	j
 ee eee ee eee	j
 ef d
ddZ  ZS )FlavaForPreTraining)zmmm_text_head.decoder.biaszmmm_image_head.decoder.biaszmlm_head.decoder.biaszmim_head.decoder.biasN)rW   image_codebookc                    s   t  | t|| _|| _| jdu r8|jr8t|j| _t|j	| _
t|j| _t|| _t|j	| _t|j| _t|| _|j	j| _|jj| _|j| _|j| _|j| _|j| _|j| _|j| _|j| _|j| _|   dS )z
        image_codebook ([`nn.Module`]):
            If passed, the image codebook will be set to this. Otherwise, it will be initialized using the
            image_codebook_config defined in the config first as the first parameter.
        N)r^   r_   r  r   r|  Zinit_codebookr\  Zimage_codebook_configr  r6  mim_headr3  mlm_headrq  itm_headmmm_image_headmmm_text_headrs  global_contrastive_headr   image_vocab_sizetext_vocab_size
mlm_weight
mim_weightglobal_contrastive_weightce_ignore_index
itm_weightmmm_image_weightmmm_text_weight skip_unmasked_multimodal_encoderr  )r*   rW   r|  rl   r+   r,   r_   Y  s,    




zFlavaForPreTraining.__init__)r   c                 C   s"   |  dkr||dd}|S )Nrr   r   rq   )rv   r{   rs   rO  r+   r+   r,   _resize_to_2d|  s    z!FlavaForPreTraining._resize_to_2dT)r   input_ids_maskedr~   codebook_pixel_valuesr   r   r   r   rA  r  
mlm_labels
mim_labels
itm_labelsr   r   r   return_lossr"   c           6      C   sz  |dur|n| j j}|dur |n| j j}|
dur4|
n| j}
|du rX|durXtd |}| j||||||	|
||dd
}| j|||||	|||dd	}d}|j}|j}|j}|j}|j	}d } } } } } }} d }! }" }#}$d }% }&}'|dus|dur@|du r@|r@| j
du r"td|du r4td| j
|}| jdkr|dur|du r|}(|dur| |}| |}| j||d< |(dd|d	 dddf }(|| j})||) }*|(|)ddf }(| |(}!|rtj|!d
| j|*d
}|| j9 }n
| |(}!| jdkr|dur|du r|}+|dur| |}|+dd|d	 dddf }+|| j})||) },|+|)ddf }+| |+}"|rtj|"d
| j|,d
}|| j9 }n
| |+}"| jdkr|dur| |}%|dur|d}-t|-  |-|-!dg}|rJtj|%|} | | j9 } |dur\|| }|durn|| }|dur|| }|| }|durd| j"dkrd|}(|d	d	 }.|(dddd|. ddf }(|durZ| |}| |}| j||d< || j})||) }*|(|)ddf }(| #|(}$|rdtj|$d
| j|*d
}|| j"9 }n
| #|(}$|dur| j$dkr|}+|+dd|d	 dddf }+|dur| |}|| j})||) },|+|)ddf }+| %|+}#|rtj|#d
| j|,d
}|| j$9 }n
| %|+}#|dur|dur| j&dkr| j'|dddddf }/tjj(|/d
d}/| j)|dddddf }0tjj(|0d
d}0| jj*j+,t-t. | /|0|/| jj*\}&}'}1|dur|&| }&|'| }'|1| }1|rtj|&|1}2tj|'|1}3|2|3 d }|| j&9 }t0||| |||d}4|rV|41 sVt2dd |43 D }|s:||j4durt|j45 nd||j6dur|j65 nd|j	|j7dur|j75 nd||j4dur|j45 nd||j6dur|j65 nd||j7dur|j75 nd|!|"|%|&|&|$|#f}5|r(|41 s(||4f|5 }5t8dd |5D S t9||4||j4||j6|j	|j7||j4||j6||j7|!|"|%|&|'|$|#dS )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        input_ids_masked (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
            Indices of input sequence tokens in the vocabulary. These ones are the masked version of the original task
            to be used with MLM. Indices can be obtained using [`AutoTokenizer`] along with
            [`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
        codebook_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_image_patches, patch_size, patch_size, 3)`, *optional*):
            Pixel values for image patches that are used to compute the image codebook labels for masked image modeling.
        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        image_attention_mask (`torch.FloatTensor` of shape `(batch_size, image_num_patches)`, *optional*):
            Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
            in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        skip_unmasked_multimodal_encoder (*bool*, *optional*):
            Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked
            multimodal embeddings or outputs as of now.
        mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
            Labels for computing the left-to-right language and multimodal masked modeling loss (next word prediction).
            Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids` docstring). Tokens with
            indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0,
            ..., text_config.vocab_size - 1]`.
        mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
            Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
            image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
            computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are
            generated automatically using the image codebook assigned to the model. By default, it uses
            [`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate mim_labels.
        itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
            The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
        return_loss (`bool`, *optional*, default to None):
            Whether to return calculated loss or not.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import FlavaForPreTraining, AutoProcessor

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
        >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")

        >>> text = ["a photo of a cat"]

        >>> inputs = processor(
        ...     images=[image],
        ...     text=text,
        ...     return_masks=True,
        ...     return_codebook_pixels=True,
        ...     padding=True,
        ...     max_length=77,
        ...     return_tensors="pt",
        ... )


        >>> output = model(**inputs)
        ```
        Nz`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if you are doing inference on unmasked text...T)
r   r~   r   r   r   rA  rB  r   r   r   )	r   r~   r   r   rA  r   r   r   r   z`return_loss` is set to True but the image codebook is not initialized and no `mim_labels`  have been passed. Reinstantiate the model with `init_codebook` set to True or pass in your custom `mim_labels`z`codebook_pixel_value` are required to generate `mim_labels` if loss is expected. Call `AutoProcessor` with `return_codebook_pixels` set to Truer   r   rq   rr   ru   )r9   r:   r;   r<   r=   r>   c                 s   s   | ]}|d ur|ndV  qd S r   r+   )r'   rD   r+   r+   r,   r-     rU   z.FlavaForPreTraining.forward.<locals>.<genexpr>c                 s   s   | ]}|d u r|V  qd S rS   r+   )r'   r   r+   r+   r,   r-     rU   )rD   rE   r   r   r   r   r   r    rF   rG   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   ):rW   r$  r  r  loggerwarningr   r   r   r   r|  RuntimeErrorr   ri  r  r  r  ners   r}  r   rz   Zcross_entropyr{   r  r  r~  r  r  r  r5   whereanynewr  r  r  r  r  r<  	normalizer;  r  r  Zclamp_LOGIT_SCALE_CLAMP_MINLOGIT_SCALE_CLAMP_MAXr  r8   r@   sumr?   r   r%   r   r    r/   rC   )6r*   r   r  r~   r  r   r   r   r   rA  r  r  r  r  r   r   r   r  Zflava_outputZflava_masked_outputZpos_maskr   r   rF   rH   rJ   Z
total_lossZmim_lossZmlm_lossZmmm_text_lossZmmm_image_lossZgc_lossZitm_lossrL   rM   rR   rQ   rN   ry  rz  Zsequence_for_imageZmasked_tokensZmim_labels_filteredZsequence_for_textZmlm_labels_filteredZ	pos_pairsZ	end_indexZtext_embeddingZimage_embeddingZ	gc_labelsZgc_loss_imageZgc_loss_textZflava_lossesr   r+   r+   r,   r     s   _
 


"

 

"













"



 



	zFlavaForPreTraining.forward)N)NNNNNNNNNNNNNNTNN)r1   r2   r3   Z_tied_weights_keysr   r   r   r*  r_   r5   r   r  r   rC  r6   rB   r   r/   rC   r   r   r+   r+   rl   r,   r{  K  sR   #                 r{  )r{  r\  r  r  r  r   r-  )Ir4   r   r   r   dataclassesr   typingr   r   r   r5   Ztorch.utils.checkpointr   Zactivationsr	   Zmodeling_layersr
   Zmodeling_outputsr   r   Zmodeling_utilsr   Zpytorch_utilsr   r   utilsr   r   r   r   Zconfiguration_flavar   r   r   r   r   Z
get_loggerr1   r  rg  r  r  r   r   r8   rC   r*  rV   re   r   r   r   r   r   r   r   r   r   r   r  r-  r  r  rD  rQ  rV  r\  rk  r  rq  rs  r{  __all__r+   r+   r+   r,   <module>   s   	
	"	\d$9I*.*"`p_  -u(   #