import copy
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import (
    ClassifierFreeGuidanceLogitsProcessor,
    GenerationMixin,
    GenerationMode,
    LogitsProcessorList,
)
from ...generation.utils import GenerateDecoderOnlyOutput
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    is_torch_available,
    logging,
    torch_int,
)
from ..auto import AutoModel
from .configuration_janus import JanusConfig, JanusVisionConfig, JanusVQVAEConfig


if is_torch_available():
    import torch.nn.functional as F


logger = logging.get_logger(__name__)


@auto_docstring
class JanusPreTrainedModel(PreTrainedModel):
    config: JanusConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_param_buffer_assignment = False
r"   z9
    Base class for Janus VQ-VAE mode model outputs.
    )Zcustom_introc                   @   s2   e Zd ZU dZdZeej ed< dZ	ejed< dS )JanusVQVAEOutputz
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`):
        Embedding loss.
    Ndecoded_pixel_valuesembedding_loss)
r(   r)   r*   __doc__r1   r   torchFloatTensorr+   r2   r.   r.   r.   r/   r0   F   s   
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class JanusBaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.

        If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`, and optionally, if
        `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
        encoder_sequence_length, embed_size_per_head)`.

        Contains pre-computed hidden-states (key and values in the self-attention blocks and, optionally, if
        `config.is_encoder_decoder=True`, in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.

        Image hidden states of the model produced by the vision encoder.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus causal language model (or autoregressive) outputs.
    """
)
class JanusCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.

        Image hidden states of the model produced by the vision encoder.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
class JanusVisionEmbeddings(nn.Module):
    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and no class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            pos_embeds = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            pos_embeds = self.position_embedding(self.position_ids)
        embeddings = embeddings + pos_embeds
        return embeddings
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class JanusVisionAttention(nn.Module):
    """Attention Class for Janus Vision Encoder"""

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        proj_dropout = config.projection_dropout
        qk_norm = config.use_qk_norm
        self.is_causal = False

        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
        self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()

        self.q_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        batch_size, seq_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scale,
            is_causal=self.is_causal,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)

        output = self.projection_layer(attn_output)
        output = self.projection_dropout(output)
        return output, attn_weights
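# Illustrative sketch: `repeat_kv` above lets a small number of key/value heads
# serve a larger number of query heads (grouped-query attention). Here 2 KV heads
# are expanded to match 8 attention heads (n_rep = 4); the sizes are arbitrary
# example values.
def _demo_repeat_kv() -> torch.Tensor:
    key_states = torch.randn(1, 2, 10, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
    expanded = repeat_kv(key_states, n_rep=4)
    assert expanded.shape == (1, 8, 10, 16)  # (batch, num_attention_heads, seq_len, head_dim)
    return expanded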
zJanusVisionAttention.forward)N)r(   r)   r*   r3   r    rK   r4   rs   r   r   r   rr   rv   r.   r.   rY   r/   r     s     r   c                       s6   e Zd Zed fddZejejdddZ  ZS )JanusVisionMLPrA   c                    sr   t    || _t|j|j | _t|j | _	t
|j| j| _t
| j|j| _t
|j| _t
|j| _d S N)rJ   rK   r#   rt   rL   Z	mlp_ratioZintermediate_sizer   
hidden_actactivation_fnr   r   fc1fc2r   Zhidden_dropout_ratedropout1dropout2rW   rY   r.   r/   rK   [  s    
zJanusVisionMLP.__init__r8   r^   c                 C   s6   |  |}| |}| |}| |}| |}|S r   )r   r   r   r   r   rX   r8   r.   r.   r/   rr   e  s    




zJanusVisionMLP.forward)	r(   r)   r*   r    rK   r4   rs   rr   rv   r.   r.   rY   r/   r   Z  s   
r   c                       sF   e Zd Zed fddZdejejee e	ej
 dddZ  ZS )	r%   rA   c                    sX   t    |j| _tj| j|jd| _t|| _	tj| j|jd| _
t|| _|| _d S N)eps)rJ   rK   rL   rM   r   r   layer_norm_epslayer_norm1r   	self_attnlayer_norm2r   mlpr#   rW   rY   r.   r/   rK   o  s    


z JanusVisionEncoderLayer.__init__F)r8   r   output_attentionsr^   c                 C   sb   |}|  |}| j|||d\}}|| }|}| |}| |}|| }|f}|r^||f7 }|S )a=  
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r8   r   r   )r   r   r   r   )rX   r8   r   r   residualr   outputsr.   r.   r/   rr   x  s     




zJanusVisionEncoderLayer.forward)F)r(   r)   r*   r    rK   r4   rs   r   ru   r;   r5   rr   rv   r.   r.   rY   r/   r%   n  s    r%   c                       sN   e Zd ZdZed fddZed	eej	 ee
 ee
 edddZ  ZS )
JanusVisionEncoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`JanusVisionEncoderLayer`].

    Args:
        config: JanusVisionConfig
    rA   c                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r.   )r%   .0rq   rA   r.   r/   
<listcomp>      z/JanusVisionEncoder.__init__.<locals>.<listcomp>F)	rJ   rK   r#   r   
ModuleListrangenum_hidden_layerslayersgradient_checkpointingrW   rY   rA   r/   rK     s    
 zJanusVisionEncoder.__init__N)r   r   output_hidden_statesr^   c           
      C   s   |dur|n| j j}|dur |n| j j}|r0dnd}|r<dnd}|}| jD ]:}|r\||f }||||d}	|	d }|rJ||	d f }qJ|r||f }t|||dS )ad  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr.   )r   r   r   )r7   r8   r9   )r#   r   r   r   r   )
rX   inputs_embedsr   r   r   Zencoder_statesZall_attentionsr8   Zencoder_layerZlayer_outputsr.   r.   r/   rr     s0    


zJanusVisionEncoder.forward)NNN)r(   r)   r*   r3   r    rK   r   r   r4   rs   ru   r   rr   rv   r.   r.   rY   r/   r     s      r   c                
       sp   e Zd ZU dZeed< ed fddZedee	j
 ee ee ee eeeef dd	d
Zdd Z  ZS )JanusVisionModelrl   r#   rA   c                    sJ   t  | || _|j}t|| _t|| _tj	||j
d| _|   d S r   )rJ   rK   r#   rL   r@   r[   r   encoderr   r   r   post_layernorm	post_init)rX   r#   rM   rY   r.   r/   rK     s    

zJanusVisionModel.__init__NF)rl   r   r   return_dictrk   r^   c           
      C   s   |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}|d u rLtd| j||d}| j||||d}|d }| |}|d d dd d f }	| |	}	|s||	f|dd   S t||	|j	|j
dS )Nz You have to specify pixel_values)rk   )r   r   r   r   r   r   )r7   Zpooler_outputr8   r9   )r#   r   r   Zuse_return_dictr   r[   r   r   r   r8   r9   )
rX   rl   r   r   r   rk   r8   Zencoder_outputsr7   Zpooled_outputr.   r.   r/   rr     s2    	

zJanusVisionModel.forwardc                 C   s   | j S r   )r[   rX   r.   r.   r/   get_input_embeddings*  s    z%JanusVisionModel.get_input_embeddings)NNNNF)r(   r)   r*   main_input_namer    r+   rK   r   r   r4   r5   ru   r   r;   r   rr   r   rv   r.   r.   rY   r/   r     s$   
     
*r   c                       s*   e Zd Zed fddZdd Z  ZS )JanusVisionAlignerMLPrA   c                    sN   t    t j j| _t fddtd j	D | _
t j | _d S )Nc                    s   g | ]}t  j jqS r.   r   r   projection_dimr   rA   r.   r/   r   4  r   z2JanusVisionAlignerMLP.__init__.<locals>.<listcomp>r   )rJ   rK   r   r   rL   r   r   r   r   depthhidden_layersr   r   r   rW   rY   rA   r/   rK   /  s    
zJanusVisionAlignerMLP.__init__c                 C   s,   |  |}| jD ]}| |}||}q|S r   r   r   r   rX   r8   layerr.   r.   r/   rr   8  s
    



zJanusVisionAlignerMLP.forward)r(   r)   r*   r    rK   rr   rv   r.   r.   rY   r/   r   .  s   	r   c                       sJ   e Zd ZdZed fddZejdddZej	ej
dd	d
Z  ZS )JanusVQVAEVectorQuantizera  
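# Illustrative sketch: the quantizer below picks the nearest codebook entry via the
# expansion |z - e|^2 = |z|^2 + |e|^2 - 2 z.e, which is what its einsum-based
# distance computation implements. This standalone check verifies the identity
# against `torch.cdist` with small, arbitrary example sizes.
def _demo_nearest_codebook() -> torch.Tensor:
    z = torch.randn(5, 8)          # 5 flattened latents of dimension 8
    codebook = torch.randn(16, 8)  # 16 codebook entries
    distances = z.pow(2).sum(1, keepdim=True) + codebook.pow(2).sum(1) - 2 * z @ codebook.t()
    indices = distances.argmin(dim=1)
    assert torch.equal(indices, torch.cdist(z, codebook).argmin(dim=1))
    return indices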
class JanusVQVAEVectorQuantizer(nn.Module):
    """
    A module for vector quantization using learned embedding vectors.

    This module implements the quantization process similar to the one described in
    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
    input vectors into discrete codebook vectors, which are learned during training.
    Current implementation improves over previous ones by avoiding costly matrix multiplications
    and allowing for post-hoc remapping of indices.
    """

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.quant_state_dims = [config.num_patches] * 2

    def forward(self, hidden_state: torch.Tensor):
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)

        # distances from z to the codebook entries e: (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
        )

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)

        # compute loss for embedding
        loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
            (hidden_state_quant - hidden_state.detach()) ** 2
        )

        # preserve gradients (straight-through estimator)
        hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()

        # reshape back to match the original input shape
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()

        return hidden_state_quant, loss, min_encoding_indices

    def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
        batch_size = image_tokens.shape[0]
        emb_dim: int = self.embedding.weight.shape[-1]

        # get quantized latent vectors
        hidden_state_quant = self.embedding(image_tokens)
        # l2 normalization on the last dimension
        hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1)

        # reshape back to match the original input shape
        hidden_state_quant = hidden_state_quant.view(batch_size, *self.quant_state_dims, emb_dim)
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()

        return hidden_state_quant


class JanusVQVAEResnetBlock(nn.Module):
    def __init__(self, config, in_channels, out_channels=None, conv_shortcut=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


class JanusVQVAEAttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        query_states = self.q(hidden_states)
        key_states = self.k(hidden_states)
        value_states = self.v(hidden_states)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, height * width)
        attn_weights = torch.bmm(query_states, key_states)
        attn_weights = attn_weights * (int(channels) ** -0.5)
        attn_weights = F.softmax(attn_weights, dim=2)

        # attend to values
        value_states = value_states.reshape(batch_size, channels, height * width)
        attn_weights = attn_weights.permute(0, 2, 1)
        attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width)

        attn_output = self.proj_out(attn_output)
        return residual + attn_output


class JanusVQVAEConvDownsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        # no asymmetric padding in torch conv, must do it ourselves
        hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class JanusVQVAEConvUpsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        hidden_states = self.conv(hidden_states)
        return hidden_states


class JanusVQVAEMidBlock(nn.Module):
    def __init__(self, config: JanusVQVAEConfig, channels: int):
        super().__init__()
        self.block_1 = JanusVQVAEResnetBlock(config=config, in_channels=channels, out_channels=channels)
        self.attn_1 = JanusVQVAEAttnBlock(channels)
        self.block_2 = JanusVQVAEResnetBlock(config=config, in_channels=channels, out_channels=channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.block_1(hidden_states)
        hidden_states = self.attn_1(hidden_states)
        hidden_states = self.block_2(hidden_states)
        return hidden_states


class JanusVQVAEEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)

        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(JanusVQVAEResnetBlock(config=config, in_channels=block_in, out_channels=block_out))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn.append(JanusVQVAEAttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = JanusVQVAEConvDownsample(block_in)
            self.down.append(down)

        self.mid = JanusVQVAEMidBlock(config, block_in)

        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.LongTensor):
        # downsampling
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](hidden_states[-1])
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))

        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state


class JanusVQVAEDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        latent_channels = config.latent_channels
        out_channels = config.out_channels

        # compute block_in at the lowest resolution
        block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = JanusVQVAEMidBlock(config, block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = base_channels * config.channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(JanusVQVAEResnetBlock(config=config, in_channels=block_in, out_channels=block_out))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn.append(JanusVQVAEAttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = JanusVQVAEConvUpsample(block_in)
            self.up.append(up)

        # end
        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
        # z to block_in
        hidden_state = self.conv_in(hidden_state)

        # middle
        hidden_state = self.mid(hidden_state)

        # upsampling
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                hidden_state = self.up[i_level].block[i_block](hidden_state)
                if len(self.up[i_level].attn) > 0:
                    hidden_state = self.up[i_level].attn[i_block](hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_state = self.up[i_level].upsample(hidden_state)

        hidden_state = self.norm_out(hidden_state)
        hidden_state *= torch.sigmoid(hidden_state)
        hidden_state = self.conv_out(hidden_state)
        return hidden_state
@auto_docstring(
    custom_intro="""
    The VQ-VAE model used in Janus for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
    Taigman](https://huggingface.co/papers/2203.13131).
    """
)
class JanusVQVAE(JanusPreTrainedModel):
    config: JanusVQVAEConfig
    _no_split_modules = [
        "JanusVQVAEAttnBlock",
        "JanusVQVAEResnetBlock",
        "JanusVQVAEVectorQuantizer",
    ]
    main_input_name = "pixel_values"

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__(config)
        self.encoder = JanusVQVAEEncoder(config)
        self.quantize = JanusVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
        self.eval()  # Janus's VQ model is frozen

        self.decoder = JanusVQVAEDecoder(config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def encode(self, pixel_values: torch.FloatTensor):
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quant, emb_loss, indices = self.quantize(hidden_states)
        return quant, emb_loss, indices

    def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
        """
        Decodes quantized token IDs into pixel values.
        Args:
            image_tokens (torch.LongTensor): Batch of token IDs.
        Returns:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                Pixel values decoded from the token IDs.
        """
        if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]:
            raise ValueError(
                f"Expected `image_tokens` to have shape `(batch_size, "
                f"{self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, "
                f"but got shape `{image_tokens.shape}`."
            )
        codebook_entry = self.quantize.get_codebook_entry(image_tokens)
        hidden_states = self.post_quant_conv(codebook_entry)
        pixel_values = self.decoder(hidden_states)
        return pixel_values

    def forward(self, pixel_values: torch.FloatTensor) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        batch_size = pixel_values.shape[0]
        quant, embedding_loss, indices = self.encode(pixel_values)
        decoded_pixel_values = self.decode(indices.view(batch_size, -1))
        return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)


class JanusVQVAEAlignerMLP(nn.Module):
    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.embed_dim, config.projection_dim)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.depth)]
        )
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        for layer in self.hidden_layers:
            hidden_states = self.activation_fn(hidden_states)
            hidden_states = layer(hidden_states)
        return hidden_states


class JanusVQVAEHead(nn.Module):
    """Head used for sampling tokens in image generation, replacing the usual lm head."""

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()
        self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
        self.activation_fn = ACT2FN[config.hidden_act]
        self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.proj_out(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.vision_head(hidden_states)
        return hidden_states
    c                       s   e Zd Zed fddZdd Zdd Zdd	 Zej	ej
ej
d
ddZeedej	ej
eej eej	 ee eej	 eej
 ee eeejf d	ddZ  ZS )
JanusModelrA   c                    s   t  | || _t|j| _t| jj| _t	|j
| _t| jjj| jjj| _t| jj| _t| jj| _tj|jd| _d| _|   d S )NrA   F)rJ   rK   r#   r   _from_configZvision_configvision_modelr   alignerr'  Z	vq_configvqmodelr   rT   r   rM   generation_embeddingsr1  generation_alignerr2  generation_headr   from_configtext_configlanguage_modelr   r   rW   rY   r.   r/   rK     s    zJanusModel.__init__c                 C   s
   | j  S r   )r?  r   r   r.   r.   r/   r     s    zJanusModel.get_input_embeddingsc                 C   s   | j | d S r   )r?  set_input_embeddingsrX   r   r.   r.   r/   r@    s    zJanusModel.set_input_embeddingsc                 C   s   |  |}| |j}|S r   )r7  r8  r7   )rX   rl   image_embedsr.   r.   r/   get_image_features  s    
zJanusModel.get_image_features)	input_idsr   image_featuresc                 C   s   |du r8||   tj| jjtj|jdk}|d}n|| jjk}| }|	d
||j}||  | kr|jd |jd  }td| d| |S )z
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        Nrm   devicerH   r   r   z6Image features and image tokens do not match: tokens: z, features )r   r4   r4  r#   Zimage_token_idlongrG  allr   rd   Z	expand_asrn   Znumelra   r   )rX   rD  r   rE  Zspecial_image_maskZn_image_tokensZn_image_featuresr.   r.   r/   get_placeholder_mask  s    zJanusModel.get_placeholder_maskNr   )	rD  rl   r   rG   r&   cache_positionr   	use_cachelogits_to_keepc
              
   K   s   |d u |d uA rt d|d u r,|  |}|d ur|| |}|d|jd }||j|j}| j|||d}|	||}| j
f |||||||	d|
}t|j|j|j|j|d ur|nd dS )NzaYou cannot specify both input_ids and inputs_embeds at the same time, and must specify either onerH   )r   rE  )r   r   rG   r&   rL  rK  rM  )r7   r&   r8   r9   r:   )r   r   rC  re   ra   rn   rG  rm   rJ  Zmasked_scatterr?  r6   r7   r&   r8   r9   )rX   rD  rl   r   rG   r&   rK  r   rL  rM  r   rB  rE  Zimage_attention_maskZ	lm_outputr.   r.   r/   rr   1  s@    
zJanusModel.forward)	NNNNNNNNr   )r(   r)   r*   r   rK   r   r@  rC  r4   r   r5   rJ  r   r   r   rs   r	   ru   r   rt   rr   rv   r.   r.   rY   r/   r5    s8            r5  c                       s   e Zd ZddgZdZed fddZdd Zd	d
 Ze	j
e	j
dddZeede	je	jee	j
 ee	j ee ee	j ee	j ee	j ee eee	j
f ee dddZd fdd	Ze	j
dddZe	jde	j
ee	j ee d fddZ  ZS )JanusForConditionalGenerationz(model.language_model.embed_tokens.weightzlm_head.weightTrA   c                    sB   t  | || _t|| _tj|jj|jj	dd| _
|   d S )NFr   )rJ   rK   r#   r5  r$   r   r   r>  rL   
vocab_sizelm_headr   rW   rY   r.   r/   rK   h  s
    
z&JanusForConditionalGeneration.__init__c                 C   s   | j j S r   )r$   r?  r   r   r.   r.   r/   r   q  s    z2JanusForConditionalGeneration.get_input_embeddingsc                 C   s   | j j| d S r   )r$   r?  r@  rA  r.   r.   r/   r@  t  s    z2JanusForConditionalGeneration.set_input_embeddings)inputsr^   c                 C   s   | j |}| j |}|S r   )r$   r:  r;  )rX   rQ  r   r.   r.   r/   'prepare_embeddings_for_image_generationw  s    zEJanusForConditionalGeneration.prepare_embeddings_for_image_generationNr   )rD  rl   r   rG   r&   rK  r   labelsrL  rM  r   c                 K   s   | j f |||||||	|d|}|j}t|
tr>t|
 dn|
}| |dd|ddf }d}|dur| jf ||| jjj	d|}t
|||j|j|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        )rD  rl   r   rG   r&   r   rL  rK  N)r>   rS  rO  )r=   r>   r&   r8   r9   r:   )r$   r7   
isinstancert   slicerP  Zloss_functionr#   r>  rO  r<   r&   r8   r9   r:   )rX   rD  rl   r   rG   r&   rK  r   rS  rL  rM  r   r   r8   Zslice_indicesr>   r=   r.   r.   r/   rr   |  s<    	z%JanusForConditionalGeneration.forwardc           
         s8   t  j|f|||||d|}	|d dkr4||	d< |	S )N)r&   r   r   rK  rM  r   rl   )rJ   prepare_inputs_for_generation)
rX   rD  rl   r&   r   r   rK  rM  r   model_inputsrY   r.   r/   rV    s    z;JanusForConditionalGeneration.prepare_inputs_for_generation)r   c                 C   s"   | j j|}|dddd}|S )a,  
    def decode_image_tokens(self, image_tokens: torch.Tensor):
        """
        Decodes generated image tokens from language model to continuous pixel values
        with VQGAN module via upsampling.
        Args:
            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                The tensors corresponding to the input images.
        """
        decoded_image = self.model.vqmodel.decode(image_tokens)
        decoded_image = decoded_image.permute(0, 2, 3, 1)
        return decoded_image

    @torch.no_grad()
    def generate(
        self,
        inputs: torch.Tensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        **kwargs,
    ):
        # 1. Handle generation config and model kwargs
        generation_config = kwargs.pop("generation_config", self.generation_config)
        generation_config = copy.deepcopy(generation_config)

        # Default to "text" generation if the mode isn't provided
        generation_mode = kwargs.pop("generation_mode", "text")
        if generation_mode == "text":
            # Unset `guidance_scale`: CFG is only used for image generation.
            return super().generate(
                inputs=inputs,
                attention_mask=attention_mask,
                generation_config=generation_config,
                guidance_scale=None,
                **kwargs,
            )

        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs

        # Validate the generation mode
        if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            raise ValueError(
                "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
                "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
            )

        # Validate the configuration and model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Initialize logits processors
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()

        # Set `use_cache=True` as we will be using input embeds for generation.
        model_kwargs["use_cache"] = True

        if generation_config.guidance_scale is None:
            logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
            generation_config.guidance_scale = 5
        model_kwargs["guidance_scale"] = generation_config.guidance_scale

        # 3. Prepare model inputs
        input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        dtype, device = input_ids.dtype, input_ids.device

        if len(input_ids.shape) != 2:
            raise ValueError(
                f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}"
                "Passing `inputs embeds` is not supported currently."
            )

        # Prepare special tokens which will be used internally by `generate`.
        kwargs_has_attention_mask = model_kwargs.get("attention_mask") is not None
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # 4. Add the CFG processor along with the user-passed logits processors.
        if generation_config.guidance_scale and generation_config.guidance_scale > 1:
            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
            generation_config.guidance_scale = None  # Reset to prevent processor duplication.

        # 5. Prepare logits processors
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids.shape[1],
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=None,
            logits_processor=logits_processor,
            device=device,
        )

        # 6. Expand inputs for multiple image generations per prompt.
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,
            **model_kwargs,
        )

        # 7. Prepare inputs and the model cache
        num_image_tokens = self.model.vision_model.config.num_image_tokens
        batch_size, seq_len = input_ids.shape

        input_tokens = input_ids.repeat(2, 1)  # Double the batch size for conditional/unconditional logits
        attention_mask = model_kwargs.pop("attention_mask", None)
        attention_mask = attention_mask.repeat(2, 1)
        model_kwargs["attention_mask"] = attention_mask

        # Mask all tokens that are neither BOS nor BOI in the unconditional half of the batch.
        mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
            input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
        )
        input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)

        inputs_embeds = self.get_input_embeddings()(input_tokens)

        model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)

        if model_kwargs.get("past_key_values", None) is None:
            # Prepare the cache if not provided.
            model_kwargs["past_key_values"] = self._get_cache(
                cache_implementation=generation_config.cache_implementation or "static",
                # The batch size should account for both conditional/unconditional inputs, hence multiplied by 2.
                batch_size=batch_size * 2,
                # The cache should hold at least seq_len + num_image_tokens entries.
                max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
                device=device,
                model_kwargs=model_kwargs,
            )

        # Placeholder for generated tokens.
        generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)

        # Config params for the generation output
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate

        raw_scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None

        for i in range(num_image_tokens):
            model_inputs = self.prepare_inputs_for_generation(
                input_tokens, inputs_embeds=inputs_embeds, **model_kwargs
            )
            model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
            model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)

            outputs = self.model.language_model(
                **model_inputs,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            # Update model_kwargs (e.g. cache_position) for the next step.
            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
            hidden_state = outputs.last_hidden_state[:, -1, :].clone()

            # Generate scores using the generation head (not the LM head defined above).
            scores = self.model.generation_head(hidden_state)
            next_token_scores = logits_processor(input_tokens, scores)

            # Sample the next token.
            if generation_config.do_sample:
                probs = torch.softmax(next_token_scores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_token = torch.argmax(next_token_scores, dim=-1)

            generated_tokens[:, i] = next_token

            # Prepare embeddings for the next step: duplicate tokens for the conditional/unconditional halves.
            next_token = torch.cat([next_token, next_token])
            next_token = next_token.unsqueeze(-1)
            inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)

            if return_dict_in_generate:
                if output_scores:
                    raw_scores += (scores,)
                if output_logits:
                    raw_logits += (hidden_state.float(),)
                if output_attentions:
                    decoder_attentions += outputs.attentions
                if output_hidden_states:
                    decoder_hidden_states += outputs.hidden_states

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=generated_tokens,
                scores=raw_scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return generated_tokens
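# Illustrative usage sketch for the two generation modes of
# `JanusForConditionalGeneration`. The processor call signatures are assumptions
# based on `JanusProcessor`; substitute an actual Janus checkpoint and processor.
def _demo_generate(model, processor, prompt: str):
    # Text mode (the default): behaves like regular autoregressive decoding.
    text_inputs = processor(text=prompt, return_tensors="pt")
    text_ids = model.generate(**text_inputs, generation_mode="text", max_new_tokens=32)

    # Image mode: samples VQ token ids, then decodes them to pixels.
    image_inputs = processor(text=prompt, generation_mode="image", return_tensors="pt")
    image_tokens = model.generate(**image_inputs, generation_mode="image", do_sample=True)
    pixel_values = model.decode_image_tokens(image_tokens)
    return text_ids, pixel_values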
__all__ = [
    "JanusPreTrainedModel",
    "JanusForConditionalGeneration",
    "JanusModel",
    "JanusVQVAE",
    "JanusVisionModel",
]