a
    hP                     @   s<  d dl mZ d dlmZmZ d dlZd dlmZ ddlmZ ddl	m
Z
 ddlmZ dd	lmZmZ dd
lmZ ddlmZmZ ddlmZ ddlmZ eeddG dd deZeeddG dd deZG dd dejZeG dd deZeddG dd deZeddG dd  d eeZg d!Z dS )"    )	dataclass)OptionalUnionN)nn   )ACT2FN)Cache)GenerationMixin)BaseModelOutputWithPastModelOutput)PreTrainedModel)auto_docstringcan_return_tuple   )	AutoModel   )VipLlavaConfigzM
    Base class for VipLlava outputs, with hidden states and attentions.
    )Zcustom_introc                   @   s$   e Zd ZU dZdZeej ed< dS )VipLlavaModelOutputWithPasta  
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nimage_hidden_states)	__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__ r   r   j/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/vipllava/modeling_vipllava.pyr   &   s   
r   zT
    Base class for VipLlava causal language model (or autoregressive) outputs.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dZeeej  ed< dZeej ed< dS )	VipLlavaCausalLMOutputWithPasta]  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nlosslogitspast_key_valueshidden_states
attentionsr   )r   r   r   r   r   r   r   r   r   r    r!   listr"   tupler#   r   r   r   r   r   r   <   s   
r   c                       s*   e Zd Zed fddZdd Z  ZS )VipLlavaMultiModalProjectorconfigc                    s   t    t|jtrdnt|j}tj||jj	 |j
d| _tj||jj	 |jj	dd| _t|j | _tj|jj	|jj	dd| _d S )Nr   )epsTZbias)super__init__
isinstancevision_feature_layersintlenr   Z	LayerNormvision_confighidden_sizeZprojector_layernorm_epsprojector_layernormLineartext_configlinear_1r   Zprojector_hidden_actactlinear_2)selfr(   Znum_feature_layers	__class__r   r   r,   \   s    

z$VipLlavaMultiModalProjector.__init__c                 C   s,   |  |}| |}| |}| |}|S N)r3   r6   r7   r8   )r9   r"   r   r   r   forwardk   s
    



z#VipLlavaMultiModalProjector.forward)r   r   r   r   r,   r=   __classcell__r   r   r:   r   r&   [   s   r&   c                   @   s6   e Zd ZU eed< dZdZdZdZdZ	dZ
dZdZdS )VipLlavaPreTrainedModelr(    Tr!   N)r   r   r   r   r   Zbase_model_prefixZsupports_gradient_checkpointingZ_skip_keys_device_placementZ_supports_flash_attnZ_supports_sdpaZ_can_compile_fullgraphZ_supports_flex_attnZ_supports_attention_backendr   r   r   r   r?   s   s   
r?   zx
    The VipLlava model which consists of a vision backbone and a language model, without a language modeling head.
    c                       s   e Zd ZddiZed fddZdd Zdd	 Zd
d Zdd Z	de
jeeeee f  dddZe
je
je
jdddZede
je
jee
j ee
j ee ee
j eeeee f  ee ee ee ee ee
j eeef dddZ  ZS )VipLlavaModelzlanguage_model.modellanguage_modelr'   c                    s>   t  | t|j| _t|| _t|j| _	| 
  d S r<   )r+   r,   r   from_configr1   vision_towerr&   multi_modal_projectorr5   rB   	post_initr9   r(   r:   r   r   r,      s
    
zVipLlavaModel.__init__c                 C   s
   | j  S r<   )rB   get_input_embeddingsr9   r   r   r   rH      s    z"VipLlavaModel.get_input_embeddingsc                 C   s   | j | d S r<   )rB   set_input_embeddingsr9   valuer   r   r   rJ      s    z"VipLlavaModel.set_input_embeddingsc                 C   s
   || _ d S r<   rB   r9   decoderr   r   r   set_decoder   s    zVipLlavaModel.set_decoderc                 C   s   | j S r<   rM   rI   r   r   r   get_decoder   s    zVipLlavaModel.get_decoderNpixel_valuesr.   c                    sv   |dur|n| j j}| j|dd t|trH j| ddddf }n  fdd|D }tj|dd}| |}|S )	aW  
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
            vision_feature_layers (`Union[int, list[int]]`):
                The vision feature layer, or the list of indexes of the layers to select
                the vision feature.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        NT)output_hidden_statesr   c                    s&   g | ]} j | d d dd f qS )Nr   )r"   ).0indexZimage_outputsr   r   
<listcomp>       z4VipLlavaModel.get_image_features.<locals>.<listcomp>)dim)	r(   r.   rD   r-   r/   r"   r   catrE   )r9   rS   r.   image_featuresr   rW   r   get_image_features   s    

z VipLlavaModel.get_image_features)	input_idsinputs_embedsr]   c                 C   s   |du r8||   tj| jjtj|jdk}|d}n|| jjk}| }|	d
||j}|jd |jd  }||  | krtd| d| |S )z
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        N)dtypedevicerZ   r   r   z6Image features and image tokens do not match: tokens: z, features )rH   r   Ztensorr(   Zimage_token_idlongrb   allsumZ	unsqueezeZ	expand_astoshapeZnumel
ValueError)r9   r_   r`   r]   special_image_maskZn_image_tokensZn_image_featuresr   r   r   get_placeholder_mask   s    z"VipLlavaModel.get_placeholder_mask)r_   rS   attention_maskposition_idsr!   r`   r.   	use_cacheoutput_attentionsrT   return_dictcache_positionreturnc                 K   s  |	dur|	n| j j}	|
dur |
n| j j}
|dur4|n| j j}|durH|n| j j}|du |duA rhtd|du r||  |}|dur| j||d}||j	|j
}| j|||d}|||}| jf ||||||	|
d|d	|}t|j|j|j|j|dur|ndd}|r|S | S )z
        vision_feature_layers (`Union[int, list[int]]`, *optional*):
            The vision feature layer, or the list of indexes of the layers to select
            the vision feature.
        Nz:You must specify exactly one of input_ids or inputs_embedsrR   )r`   r]   T)	rk   rl   r!   r`   rm   rn   rT   ro   rp   )last_hidden_stater!   r"   r#   r   )r(   rn   rT   use_return_dictr.   rh   rH   r^   rf   rb   ra   rj   Zmasked_scatterrB   r   rr   r!   r"   r#   Zto_tuple)r9   r_   rS   rk   rl   r!   r`   r.   rm   rn   rT   ro   rp   	lm_kwargsr]   ri   outputsoutputr   r   r   r=      sP    
zVipLlavaModel.forward)N)NNNNNNNNNNNN)r   r   r   _checkpoint_conversion_mappingr   r,   rH   rJ   rP   rQ   r   r   r   r   r/   r$   r^   
LongTensorrj   r   Tensorr   boolr%   r   r=   r>   r   r   r:   r   rA      sP                 
rA   zV
    The VIPLLAVA model which consists of a vision backbone and a language model.
    c                       sH  e Zd ZdddddZdgZed fdd	Zd
d Zdd Ze	j
dddZdd Zdd Zd%ejeeeee f  dddZedd Zedd Zedd Zeed&ejejeej eej ee eej eeeee f  eej ee ee ee ee eej eeejf eee f d d!d"Z!d' fd#d$	Z"  Z#S )( VipLlavaForConditionalGenerationzmodel.language_modelzmodel.vision_towerzmodel.multi_modal_projectorlm_head)z^language_model.modelz^vision_towerz^multi_modal_projectorz^language_model.lm_headzlm_head.weightr'   c                    s<   t  | t|| _tj|jj|jjdd| _	| 
  d S )NFr*   )r+   r,   rA   modelr   r4   r5   r2   
vocab_sizer|   rF   rG   r:   r   r   r,   )  s    
z)VipLlavaForConditionalGeneration.__init__c                 C   s
   | j  S r<   )r}   rH   rI   r   r   r   rH   /  s    z5VipLlavaForConditionalGeneration.get_input_embeddingsc                 C   s   | j | d S r<   )r}   rJ   rK   r   r   r   rJ   2  s    z5VipLlavaForConditionalGeneration.set_input_embeddings)rq   c                 C   s   | j S r<   )r|   rI   r   r   r   get_output_embeddings5  s    z6VipLlavaForConditionalGeneration.get_output_embeddingsc                 C   s   | j | d S r<   )r}   rP   rN   r   r   r   rP   8  s    z,VipLlavaForConditionalGeneration.set_decoderc                 C   s
   | j  S r<   )r}   rQ   rI   r   r   r   rQ   ;  s    z,VipLlavaForConditionalGeneration.get_decoderNrR   c                 C   s   | j j||dS )NrR   )r}   r^   )r9   rS   r.   r   r   r   r^   >  s    z3VipLlavaForConditionalGeneration.get_image_featuresc                 C   s   | j jS r<   )r}   rB   rI   r   r   r   rB   D  s    z/VipLlavaForConditionalGeneration.language_modelc                 C   s   | j jS r<   )r}   rD   rI   r   r   r   rD   H  s    z-VipLlavaForConditionalGeneration.vision_towerc                 C   s   | j jS r<   )r}   rE   rI   r   r   r   rE   L  s    z6VipLlavaForConditionalGeneration.multi_modal_projectorr   )r_   rS   rk   rl   r!   r`   r.   labelsrm   rn   rT   ro   rp   logits_to_keeprq   c                 K   s   |
dur|
n| j j}
|dur |n| j j}|dur4|n| j j}|durH|n| j j}| jf |||||||	||
|d|d|}|d }t|trt| dn|}| 	|dd|ddf }d}|dur| j
||| j jjd}t|||j|j|j|jdS )a  
        vision_feature_layers (`Union[int, list[int]]`, *optional*):
            The vision feature layer, or the list of indexes of the layers to select
            the vision feature.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, VipLlavaForConditionalGeneration

        >>> model = VipLlavaForConditionalGeneration.from_pretrained("llava-hf/vip-llava-7b-hf", device_map="auto", dtype=torch.float16)
        >>> processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")

        >>> prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n{}###Assistant:"
        >>> question = "Can you please describe this image?"
        >>> prompt = prompt.format(question)
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=text, images=image, return_tensors="pt").to(0, torch.float16)

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=20)
        >>> processor.decode(generate_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
        The image features a brown and white cat sitting on a green surface, with a red ball in its
        ```NT)r_   rS   rk   rl   r!   r`   rm   r.   rn   rT   ro   rp   r   )r    r   r~   )r   r    r!   r"   r#   r   )r(   rn   rT   rs   r.   r}   r-   r/   slicer|   Zloss_functionr5   r~   r   r!   r"   r#   r   )r9   r_   rS   rk   rl   r!   r`   r.   r   rm   rn   rT   ro   rp   r   rt   ru   r"   Zslice_indicesr    r   r   r   r   r=   P  sH    6z(VipLlavaForConditionalGeneration.forwardc           
         s8   t  j|f|||||d|}	|d dkr4||	d< |	S )N)r!   r`   rk   rp   r   r   rS   )r+   prepare_inputs_for_generation)
r9   r_   r!   r`   rS   rk   rp   r   kwargsZmodel_inputsr:   r   r   r     s    
z>VipLlavaForConditionalGeneration.prepare_inputs_for_generation)N)NNNNNNNNNNNNNr   )NNNNNN)$r   r   r   rw   Z_tied_weights_keysr   r,   rH   rJ   r   Moduler   rP   rQ   r   r   r   r   r/   r$   r^   propertyrB   rD   rE   r   r   rx   ry   r   rz   r%   r   r=   r   r>   r   r   r:   r   r{     s|    


              
b      r{   )rA   r{   r?   )!dataclassesr   typingr   r   r   r   Zactivationsr   Zcache_utilsr   Z
generationr	   Zmodeling_outputsr
   r   Zmodeling_utilsr   utilsr   r   autor   Zconfiguration_vipllavar   r   r   r   r&   r?   rA   r{   __all__r   r   r   r   <module>   sF     1