a
    h                    @   s  d dl Z d dlmZ d dlmZ d dlmZmZmZ d dl	Z
d dlZd dlmZ d dlmZ ddlmZ dd	lmZ dd
lmZmZmZmZ ddlmZ ddlmZmZ ddlmZmZ ddl m!Z!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z' ddl(m)Z) ddl*m+Z+m,Z, ddl-m.Z. ddl/m0Z0m1Z1m2Z2m3Z3m4Z4m5Z5 ddl6m7Z7 ddl8m9Z9 ddl:m;Z; ddl<m=Z=m>Z>m?Z?m@Z@mAZA ddlBmCZCmDZD ddlEmFZF ddlGmHZH ddlImJZJmKZKmLZL e3 rd dlZd dlMmZ d dlNm  mOZP d dlQZe4 rd dlRZRddlSmTZT ddl6mUZUmVZV e5WeXZYG dd deHZZG d d! d!e;Z[G d"d# d#eTZ\e1G d$d% d%e,Z]ee1d&d'G d(d) d)e)Z^G d*d+ d+eCZ_G d,d- d-eDZ`G d.d/ d/eLZaG d0d1 d1ejbZcG d2d3 d3ejbZdG d4d5 d5eKZeG d6d7 d7eJZfG d8d9 d9e9ZgG d:d; d;ejbZhG d<d= d=eAZiG d>d? d?e@ZjG d@dA dAe>ZkG dBdC dCe?ZlG dDdE dEejbZmG dFdG dGejbZnG dHdI dIejbZoG dJdK dKejbZpG dLdM dMe=ZqG dNdO dOejbZrG dPdQ dQejbZse1dRd'G dSdT dTe]ZtG dUdV dVe]eZuG dWdX dXeZvg dYZwdS )Z    N)Iterable)	dataclass)CallableOptionalUnion)nn)BlipImageProcessor   )ACT2FN)Cache)%ClassifierFreeGuidanceLogitsProcessorGenerationMixinGenerationModeLogitsProcessorList)GenerateDecoderOnlyOutput)BatchFeatureget_size_dict)resizeto_channel_dimension_format)ChannelDimension
ImageInputPILImageResamplingget_image_sizeinfer_channel_dimension_formatmake_list_of_imagesto_numpy_array)ModelOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tupleis_torch_availableis_vision_availablelogging   )	AutoModel)Blip2VisionModel)ChameleonVQVAEConfig)ChameleonVQVAEChameleonVQVAEEncoderAttnBlock#ChameleonVQVAEEncoderConvDownsample ChameleonVQVAEEncoderResnetBlockChameleonVQVAEVectorQuantizer)IdeficsBaseModelOutputWithPastIdeficsCausalLMOutputWithPast)eager_attention_forward)SiglipVisionConfig)SiglipEncoderSiglipEncoderLayerSiglipVisionEmbeddings)PretrainedConfig)CONFIG_MAPPING
AutoConfigc                       s*   e Zd ZdZdZdZd fdd	Z  ZS )JanusVisionConfiga
  
    This is the configuration class to store the configuration of a [`JanusVisionModel`]. It is used to instantiate a
    `JanusVisionModel` according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 384):
            The size (resolution) of each image.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, and `"gelu_new"` are supported.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys, and values in the attention layers.
        hidden_dropout_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for fully connected layers in the encoder.
        projection_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the MLP projection head.
        projection_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for the projection layer.
        use_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether to normalize the query and key matrices.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated normal initializer for initializing all weight matrices.
        depth (`int`, *optional*, defaults to 2):
            Number of hidden layers in the aligner module.
        num_image_tokens (`int`, *optional*, defaults to 576):
            Number of image tokens.
    Zjanus_vision_modelvision_config         r	             ư>gelu      @T   F{Gz?r&   @  c                    sd   t  jf |||||||||	d	| | `|
| _|| _|| _|| _|| _|| _|| _	|| _
|| _d S )N)	hidden_sizenum_hidden_layersnum_attention_headsnum_channels
patch_size
image_sizeattention_dropoutlayer_norm_eps
hidden_act)super__init__intermediate_size	mlp_ratioattention_biashidden_dropout_rateprojection_dimprojection_dropoutuse_qk_norminitializer_rangedepthnum_image_tokens)selfrF   rG   rH   rI   rJ   rK   rL   rM   rN   rR   rS   rT   rU   rV   rW   rX   rY   rZ   kwargs	__class__ c/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/janus/modular_janus.pyrP      s.    
zJanusVisionConfig.__init__)r;   r<   r=   r	   r=   r>   r?   r@   rA   rB   Tr?   rC   r?   FrD   r&   rE   )__name__
__module____qualname____doc__
model_typeZbase_config_keyrP   __classcell__r_   r_   r]   r`   r9   W   s,   .                  r9   c                       sd   e Zd ZdZddddddddg d	d
dddd
ddfeeeeeeeeee eed fddZ  Z	S )JanusVQVAEConfiga:
  
    This is the configuration class to store the configuration of a [`JanusVQVAEModel`]. It is used to instantiate a
    `JanusVQVAEModel` according to the specified arguments, defining the model architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a
    configuration with the defaults will yield a similar configuration to the VQModel of the
    [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B).

    Args:
        embed_dim (`int`, *optional*, defaults to 8):
            Dimensionality of each embedding vector.
        num_embeddings (`int`, *optional*, defaults to 16384):
            Number of codebook embeddings.
        double_latent (`bool`, *optional*, defaults to `False`):
            Whether to use double z channels.
        latent_channels (`int`, *optional*, defaults to 256):
            Number of channels for the latent space.
        num_patches (`int`, *optional*, defaults to 32):
            Num of patches the input images can be divided into.
        in_channels (`int`, *optional*, defaults to 3):
            Number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            Number of out channels.
        base_channels (`int`, *optional*, defaults to 128):
            Base channel count.
        channel_multiplier (`list[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
            Channel multipliers for each resolution.
        num_res_blocks (`int`, *optional*, defaults to 2):
            Number of residual blocks.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        projection_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the MLP projection head.
        num_hidden_layers (`int`, *optional*, defaults to 2):
            Number of hidden layers in VAVAE MLP Connecter module.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        image_token_embed_dim (`int`, *optional*, defaults to 2048):
            Dimension of image embeddings. It should be same as the dimensionality of text embeddings.
       i @  F       r	      )   rl   r&   r&      r&   r?   rD   rC   rA   )	embed_dimnum_embeddingsdouble_latentlatent_channelsnum_patchesin_channelsout_channelsbase_channelschannel_multipliernum_res_blocksdropoutc                    s\   t  jf |||||||	|
||d
| || _|| _|| _|| _|| _|| _| `| `	| `
d S )N)
rn   ro   rp   rq   rs   ru   rv   rw   rx   rX   )rO   rP   rr   rt   rU   rG   rN   image_token_embed_dim
resolutionZattn_resolutionsZ	attn_type)r[   rn   ro   rp   rq   rr   rs   rt   ru   rv   rw   rx   rX   rU   rG   rN   ry   r\   r]   r_   r`   rP      s.    zJanusVQVAEConfig.__init__)
ra   rb   rc   rd   intboollistfloatrP   rf   r_   r_   r]   r`   rg      s<   .rg   c                       s2   e Zd ZdZdZeeedZd fdd	Z	  Z
S )	JanusConfiga;  
    This is the configuration class to store the configuration of a [`JanusModel`]. It is used to instantiate an
    Janus model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Janus-1B or Janus-7B models.

    e.g. [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B) or
    [deepseek-community/Janus-Pro-7B](https://huggingface.co/deepseek-community/Janus-Pro-7B)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[AutoConfig, dict]`,  *optional*, defaults to `JanusVisionConfig`):
            The config object or dictionary of the vision backbone.
        vq_config (`Union[AutoConfig, dict]`,  *optional*, defaults to `JanusVQVAEConfig`):
            The config object or dictionary of the VQVAE backbone.
        image_token_id (`int`, *optional*, defaults to 100581):
            Token index of a placeholder image token.

    Example:

    ```python
    >>> from transformers import JanusForConditionalGeneration, JanusConfig, JanusVisionConfig, JanusVQVAEConfig, LlamaConfig

    >>> # Initializing a Janus vision config
    >>> vision_config = JanusVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a VQ config
    >>> vq_config = JanusVQVAEConfig()

    >>> # Initializing a Janus Pro 1B style configuration
    >>> configuration = JanusConfig(vision_config=vision_config, text_config=text_config, vq_config=vq_config)

    >>> # Initializing a model from the Janus Pro 1B style configuration
    >>> model = JanusForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```Zjanus)text_configr:   	vq_configN c                    sn  t |tr4|dd|d< t|d  f i || _nD|d u rTtd td  | _n$t |trf|| _ntdt	| |d u rtd t
 | _n@t |trt
f i || _n$t |t
r|| _ntdt	| |d u rtd t | _nDt |trtf i || _n&t |tr"|| _ntdt	| | jj| _| jj| jj | j_|| _t jf i | d S )	Nre   llamaz7`text_config` is None. Initializing with default valueszTInvalid type for `text_config`. Must be either `dict` or `LlamaConfig`. Type found: zK`vision_config` is None. Initializing with default JanusVisionConfig valuesz\Invalid type for `vision_config`. Must be either `dict` or `JanusVisionConfig`. Type found: zF`vq_config` is None. Initializing with default JanusVQVAEConfig valueszWInvalid type for `vq_config`. Must be either `dict` or `JanusVQVAEConfig`. Type found: )
isinstancedictgetr7   r   loggerinfor6   
ValueErrortyper9   r:   rg   r   rX   rK   rJ   rr   image_token_idrO   rP   )r[   r   r:   r   r   r\   r]   r_   r`   rP   G  sR    









zJanusConfig.__init__)NNNr   )ra   rb   rc   rd   re   r8   r9   rg   Zsub_configsrP   rf   r_   r_   r]   r`   r     s   -    r   c                   @   s>   e Zd ZU eed< dZdZddgZddgZdZ	dZ
dZdZd	S )
JanusPreTrainedModelconfigmodelTZLlamaDecoderLayerJanusVisionEncoderLayerpast_key_valuesZcausal_maskFN)ra   rb   rc   r   __annotations__Zbase_model_prefixZsupports_gradient_checkpointing_no_split_modulesZ_skip_keys_device_placementZ_supports_flash_attnZ_supports_sdpa_can_compile_fullgraphZ!_supports_param_buffer_assignmentr_   r_   r_   r`   r     s   
r   z9
    Base class for Janus VQ-VAE mode model outputs.
    )Zcustom_introc                   @   s2   e Zd ZU dZdZeej ed< dZ	ejed< dS )JanusVQVAEOutputz
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`):
        Embedding loss.
    Ndecoded_pixel_valuesembedding_loss)
ra   rb   rc   rd   r   r   torchFloatTensorr   r   r_   r_   r_   r`   r     s   
r   c                   @   s   e Zd ZdS )JanusBaseModelOutputWithPastNra   rb   rc   r_   r_   r_   r`   r     s   r   c                   @   s   e Zd ZdS )JanusCausalLMOutputWithPastNr   r_   r_   r_   r`   r     s   r   c                   @   s$   e Zd ZdejeejdddZdS )JanusVisionEmbeddingsF)pixel_valuesinterpolate_pos_encodingreturnc           
      C   sh   |j \}}}}| jjj}| |j|d}|ddd}|rP| |||}	n| | j	}	||	 }|S )Ndtyper&   rl   )
shapeZpatch_embeddingweightr   toflatten	transposer   Zposition_embeddingposition_ids)
r[   r   r   _heightwidthZtarget_dtypeZpatch_embedsZ
embeddingsZ
pos_embedsr_   r_   r`   forward  s    
zJanusVisionEmbeddings.forwardN)F)ra   rb   rc   r   Tensorr|   r   r_   r_   r_   r`   r     s   r   c                       sF   e Zd ZdZed fddZd	ejeej e	e
 dddZ  ZS )
JanusVisionAttentionz(Attention Class for Janus Vision Encoderr   c                    sL  t    || _|j| _|j| _| j| j | _| j| j | jkrZtd| j d| j d| jd | _	|j
| _
|j}|j}d| _d| _tj| j| j| j |jd| _tj| j| j| j |jd| _tj| j| j| j |jd| _t| j| j| _|dkrt|nt | _|r"t| jnt | _|r>t| jnt | _d S )	Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      Frl   Zbiasr   )rO   rP   r   rF   rn   rH   	num_headshead_dimr   scalerL   rV   rW   	is_causalZnum_key_value_groupsr   LinearrS   q_projk_projv_projprojection_layerDropoutZIdentity	LayerNormq_normk_norm)r[   r   Zproj_dropoutZqk_normr]   r_   r`   rP     s0    

zJanusVisionAttention.__init__N)hidden_statesattention_maskr\   c                 K   s4  |  \}}}| |}| |}| |}	|d| j| j}| |}|d| j| j}| |}|||| j| j	dd}|||| j| j	dd}|	
||| j| j	dd}	t}
| jjdkrt| jj }
|
| |||	|f| jsdn| j| j| jd|\}}|||| j}| |}| |}||fS )Nrl   r&   eagerr?   )rx   Zscalingr   )sizer   r   r   reshaper   r   r   r   r   viewr1   r   Z_attn_implementationr   ZtrainingrL   r   r   rn   r   rV   )r[   r   r   r\   
batch_sizeseq_lenr   Zquery_statesZ
key_statesZvalue_statesZattention_interfaceZattn_outputZattn_weightsoutputr_   r_   r`   r     s>    




	


zJanusVisionAttention.forward)N)ra   rb   rc   rd   r9   rP   r   r   r   r   r    r   rf   r_   r_   r]   r`   r     s     r   c                       s6   e Zd Zed fddZejejdddZ  ZS )JanusVisionMLPr   c                    sr   t    || _t|j|j | _t|j | _	t
|j| j| _t
| j|j| _t
|j| _t
|j| _d S N)rO   rP   r   r{   rF   rR   rQ   r
   rN   activation_fnr   r   fc1fc2r   rT   dropout1dropout2r[   r   r]   r_   r`   rP     s    
zJanusVisionMLP.__init__r   r   c                 C   s6   |  |}| |}| |}| |}| |}|S r   )r   r   r   r   r   r[   r   r_   r_   r`   r     s    




zJanusVisionMLP.forward)	ra   rb   rc   r9   rP   r   r   r   rf   r_   r_   r]   r`   r     s   
r   c                       s"   e Zd Zed fddZ  ZS )r   r   c                    sZ   t  | || _|j| _t|| _tj| j|j	d| _
tj| j|j	d| _t|| _d S )N)eps)rO   rP   r   rF   rn   r   Z	self_attnr   r   rM   Zlayer_norm1Zlayer_norm2r   Zmlpr   r]   r_   r`   rP     s    
z JanusVisionEncoderLayer.__init__ra   rb   rc   r9   rP   rf   r_   r_   r]   r`   r     s   r   c                       s"   e Zd Zed fddZ  ZS )JanusVisionEncoderr   c                    s0   t    t fddt jD | _d S )Nc                    s   g | ]}t  qS r_   )r   .0r   r   r_   r`   
<listcomp>'      z/JanusVisionEncoder.__init__.<locals>.<listcomp>)rO   rP   r   
ModuleListrangerG   Zlayersr   r]   r   r`   rP   %  s    zJanusVisionEncoder.__init__r   r_   r_   r]   r`   r   $  s   r   c                       s"   e Zd Zed fddZ  ZS )JanusVisionModelr   c                    s   t  | t|| _d S r   )rO   rP   r   encoderr   r]   r_   r`   rP   +  s    zJanusVisionModel.__init__r   r_   r_   r]   r`   r   *  s   r   c                       s*   e Zd Zed fddZdd Z  ZS )JanusVisionAlignerMLPr   c                    sN   t    t j j| _t fddtd j	D | _
t j | _d S )Nc                    s   g | ]}t  j jqS r_   r   r   rU   r   r   r_   r`   r   6  r   z2JanusVisionAlignerMLP.__init__.<locals>.<listcomp>rl   )rO   rP   r   r   rF   rU   r   r   r   rY   hidden_layersr
   rN   r   r   r]   r   r`   rP   1  s    
zJanusVisionAlignerMLP.__init__c                 C   s,   |  |}| jD ]}| |}||}q|S r   r   r   r   r[   r   layerr_   r_   r`   r   :  s
    



zJanusVisionAlignerMLP.forward)ra   rb   rc   r9   rP   r   rf   r_   r_   r]   r`   r   0  s   	r   c                       s6   e Zd Zed fddZejejdddZ  Z	S )JanusVQVAEVectorQuantizerr   c                    s   t  | |jgd | _d S )Nr&   )rO   rP   rr   quant_state_dimsr   r]   r_   r`   rP   C  s    z"JanusVQVAEVectorQuantizer.__init__image_tokensr   c                 C   sb   |j d }| jjj d }| |}tj|ddd}||g| j|R }|dddd }|S )Nr   r   r&   )pdimr	   rl   )	r   Z	embeddingr   F	normalizer   r   permute
contiguous)r[   r   r   Zemb_dimZhidden_state_quantr_   r_   r`   get_codebook_entryG  s    

z,JanusVQVAEVectorQuantizer.get_codebook_entry)
ra   rb   rc   rg   rP   r   
LongTensorr   r   rf   r_   r_   r]   r`   r   B  s   r   c                   @   s   e Zd ZdS )JanusVQVAEResnetBlockNr   r_   r_   r_   r`   r   W  s   r   c                   @   s   e Zd ZdS )JanusVQVAEAttnBlockNr   r_   r_   r_   r`   r   [  s   r   c                   @   s   e Zd ZdS )JanusVQVAEConvDownsampleNr   r_   r_   r_   r`   r   _  s   r   c                       s$   e Zd Z fddZdd Z  ZS )JanusVQVAEConvUpsamplec                    s&   t    tjj||dddd| _d S )Nr	   rl   Zkernel_sizeZstridepadding)rO   rP   r   r   Conv2dconv)r[   rs   r]   r_   r`   rP   d  s    
zJanusVQVAEConvUpsample.__init__c                 C   s   t j|ddd}| |}|S )Ng       @Znearest)Zscale_factormode)r   Zinterpolater   r   r_   r_   r`   r   h  s    
zJanusVQVAEConvUpsample.forward)ra   rb   rc   rP   r   rf   r_   r_   r]   r`   r   c  s   r   c                       s8   e Zd Zeed fddZejejdddZ  Z	S )JanusVQVAEMidBlock)r   channelsc                    s8   t    t|||d| _t|| _t|||d| _d S )Nr   rs   rt   )rO   rP   r   block_1r   attn_1block_2)r[   r   r   r]   r_   r`   rP   o  s    

zJanusVQVAEMidBlock.__init__r   c                 C   s"   |  |}| |}| |}|S r   )r   r   r   r   r_   r_   r`   r   }  s    


zJanusVQVAEMidBlock.forward)
ra   rb   rc   rg   r{   rP   r   r   r   rf   r_   r_   r]   r`   r   n  s   r   c                       s,   e Zd Z fddZejdddZ  ZS )JanusVQVAEEncoderc              	      sr  t    t|j| _|j| _|j}|j}|j}|j	}|j}t
jj||dddd| _dt| }|| _t | _t| jD ]}t }	t }
|||  }|||  }t| jD ]8}|	t|||d |}|| jd kr|
t| qt }|	|_|
|_|| jd krt||_| j| qzt||| _t
jjd|ddd	| _t
jj||r^d
| n|dddd| _d S )Nr	   rl   r   )rl   r   rj   r@   TZ
num_groupsrI   r   Zaffiner&   ) rO   rP   lenrv   num_resolutionsrw   ru   rs   rp   rq   r   r   r   conv_intuplein_channel_multiplierr   downr   appendr   r   Moduleblockattnr   
downsampler   mid	GroupNormnorm_outconv_out)r[   r   ru   rs   rp   rq   rv   r  i_levelr  r  block_in	block_outi_blockr  r]   r_   r`   rP     sV    


zJanusVQVAEEncoder.__init__)r   c                 C   s   |  |g}t| jD ]}t| jD ]N}| j| j| |d }t| j| jdkrh| j| j| |}|| q$|| jd kr|| j| 	|d  q|d }| 
|}| |}|t|9 }| |}|S )Nr   r   rl   )r  r   r   rw   r  r  r   r  r  r	  r
  r  r   sigmoidr  )r[   r   r   r  r  hidden_statelast_hidden_stater_   r_   r`   r     s"    


zJanusVQVAEEncoder.forward)ra   rb   rc   rP   r   r   r   rf   r_   r_   r]   r`   r     s   3r   c                       s0   e Zd Z fddZejejdddZ  ZS )JanusVQVAEDecoderc              	      sR  t    t|j| _|j| _|j}|j}|j}||j| jd   }t	j
j||dddd| _t||| _t
 | _tt| jD ]}t
 }t
 }||j|  }	t| jd D ]8}
|t|||	d |	}|| jd kr|t| qt
 }||_||_|dkrt||_| j| qt	j
jd|ddd	| _t	j
j||dddd| _d S )
Nrl   r	   r   r   r   rj   r@   Tr   )rO   rP   r   rv   r   rw   ru   rq   rt   r   r   r   r  r   r
  r   upreversedr   r  r   r   r  r  r  r   upsampler  r  r  )r[   r   ru   rq   rt   r  r  r  r  r  r  r  r]   r_   r`   rP     sB    



zJanusVQVAEDecoder.__init__)r  r   c                 C   s   |  |}| |}t| jD ]r}t| jd D ]@}| j| j| |}t| j| jdkr0| j| j| |}q0|| jd kr| j| 	|}q| 
|}|t|9 }| |}|S )Nrl   r   )r  r
  r   r   rw   r  r  r   r  r  r  r   r  r  )r[   r  r  r  r_   r_   r`   r      s    



zJanusVQVAEDecoder.forward)ra   rb   rc   rP   r   r   r   rf   r_   r_   r]   r`   r    s   .r  c                       sh   e Zd Zg dZdZed fddZejej	dddZ
eeej	eej	ej	f d	d
dZ  ZS )
JanusVQVAE)r   r   r   r   r   c                    s(   t  | t|| _d| _|   d S )NF)rO   rP   r  decodergradient_checkpointing	post_initr   r]   r_   r`   rP     s    
zJanusVQVAE.__init__r   c                 C   sr   |j d | jjd | jjd  krNtd| jjd | jjd   d|j  d| j|}| |}| |}|S )aG  
        Decodes quantized token IDs into pixel values.
        Args:
            image_tokens (torch.LongTensor): Batch of token IDs.
        Returns:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                Pixel values decoded from the token IDs.
        rl   r   z4Expected `image_tokens` to have shape `(batch_size, z)`, but got shape `z`.)r   quantizer   r   r   Zpost_quant_convr  )r[   r   Zcodebook_entryr   r   r_   r_   r`   decode%  s    	"

zJanusVQVAE.decode)r   r   c                 C   s6   |j d }| |\}}}| ||d}t||S )Nr   r   )r   encoder  r   r   )r[   r   r   Zquantr   indicesr   r_   r_   r`   r   8  s    
zJanusVQVAE.forward)ra   rb   rc   r   Zmain_input_namerg   rP   r   r   r   r  r"   r!   r  r   rf   r_   r_   r]   r`   r    s   r  c                       s*   e Zd Zed fddZdd Z  ZS )JanusVQVAEAlignerMLPr   c                    sN   t    t j j| _t fddtd j	D | _
t j | _d S )Nc                    s   g | ]}t  j jqS r_   r   r   r   r_   r`   r   K  r   z1JanusVQVAEAlignerMLP.__init__.<locals>.<listcomp>rl   )rO   rP   r   r   rn   rU   r   r   r   rG   r   r
   rN   r   r   r]   r   r`   rP   F  s    
zJanusVQVAEAlignerMLP.__init__c                 C   s,   |  |}| jD ]}| |}||}q|S r   r   r   r_   r_   r`   r   O  s
    



zJanusVQVAEAlignerMLP.forward)ra   rb   rc   rg   rP   r   rf   r_   r_   r]   r`   r!  E  s   	r!  c                       s:   e Zd ZdZed fddZejejdddZ	  Z
S )JanusVQVAEHeadzOHead used for sampling tokens in image generation, replacing the usual lm head.r   c                    s>   t    t|j|j| _t|j | _	t|j|j
| _d S r   )rO   rP   r   r   ry   rU   proj_outr
   rN   r   ro   vision_headr   r]   r_   r`   rP   Z  s    
zJanusVQVAEHead.__init__r   c                 C   s"   |  |}| |}| |}|S r   )r#  r   r$  r   r_   r_   r`   r   `  s    


zJanusVQVAEHead.forward)ra   rb   rc   rd   rg   rP   r   r   tensorr   rf   r_   r_   r]   r`   r"  W  s   r"  zl
    The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.
    c                       s   e Zd Zed fddZdd Zdd Zdd	 Zej	ej
ej
d
ddZeedej	ej
eej eej	 ee eej	 eej
 ee eeejf d	ddZ  ZS )
JanusModelr   c                    s   t  | || _t|j| _t| jj| _t	|j
| _t| jjj| jjj| _t| jj| _t| jj| _tj|jd| _d| _|   d S )Nr   F)rO   rP   r   r   _from_configr:   vision_modelr   alignerr  r   vqmodelr   Z	Embeddingro   rn   generation_embeddingsr!  generation_alignerr"  generation_headr'   from_configr   language_modelr  r  r   r]   r_   r`   rP   m  s    zJanusModel.__init__c                 C   s
   | j  S r   )r/  get_input_embeddingsr[   r_   r_   r`   r0    s    zJanusModel.get_input_embeddingsc                 C   s   | j | d S r   )r/  set_input_embeddingsr[   valuer_   r_   r`   r2    s    zJanusModel.set_input_embeddingsc                 C   s   |  |}| |j}|S r   )r(  r)  r  )r[   r   image_embedsr_   r_   r`   get_image_features  s    
zJanusModel.get_image_features)	input_idsinputs_embedsimage_featuresc                 C   s   |du r8||   tj| jjtj|jdk}|d}n|| jjk}| }|	d
||j}||  | kr|jd |jd  }td| d| |S )z
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        Nr   devicer   r   rl   z6Image features and image tokens do not match: tokens: z, features )r0  r   r%  r   r   longr;  allsum	unsqueezeZ	expand_asr   Znumelr   r   )r[   r7  r8  r9  Zspecial_image_maskZn_image_tokensZn_image_featuresr_   r_   r`   get_placeholder_mask  s    zJanusModel.get_placeholder_maskNr   )	r7  r   r   r   r   cache_positionr8  	use_cachelogits_to_keepc
              
   K   s   |d u |d uA rt d|d u r,|  |}|d ur|| |}|d|jd }||j|j}| j|||d}|	||}| j
f |||||||	d|
}t|j|j|j|j|d ur|nd dS )NzaYou cannot specify both input_ids and inputs_embeds at the same time, and must specify either oner   )r8  r9  )r8  r   r   r   rB  rA  rC  )r  r   r   
attentionsimage_hidden_states)r   r0  r6  r   r   r   r;  r   r@  Zmasked_scatterr/  r   r  r   r   rD  )r[   r7  r   r   r   r   rA  r8  rB  rC  r\   r5  r9  Zimage_attention_maskZ	lm_outputr_   r_   r`   r     s@    
zJanusModel.forward)	NNNNNNNNr   )ra   rb   rc   r   rP   r0  r2  r6  r   r   r   r@  r"   r!   r   r   r   r|   r   r{   r   rf   r_   r_   r]   r`   r&  g  s8            r&  c                       s   e Zd ZddgZdZed fddZdd Zd	d
 Ze	j
e	j
dddZeede	je	jee	j
 ee	j ee ee	j ee	j ee	j ee eee	j
f ee dddZd fdd	Ze	j
dddZe	jde	j
ee	j ee d fddZ  ZS )JanusForConditionalGenerationz(model.language_model.embed_tokens.weightzlm_head.weightTr   c                    sB   t  | || _t|| _tj|jj|jj	dd| _
|   d S )NFr   )rO   rP   r   r&  r   r   r   r   rF   
vocab_sizelm_headr  r   r]   r_   r`   rP     s
    
z&JanusForConditionalGeneration.__init__c                 C   s   | j j S r   )r   r/  r0  r1  r_   r_   r`   r0    s    z2JanusForConditionalGeneration.get_input_embeddingsc                 C   s   | j j| d S r   )r   r/  r2  r3  r_   r_   r`   r2    s    z2JanusForConditionalGeneration.set_input_embeddings)inputsr   c                 C   s   | j |}| j |}|S r   )r   r+  r,  )r[   rI  r  r_   r_   r`   'prepare_embeddings_for_image_generation  s    zEJanusForConditionalGeneration.prepare_embeddings_for_image_generationNr   )r7  r   r   r   r   rA  r8  labelsrB  rC  r\   c                 K   s   | j f |||||||	|d|}|j}t|
tr>t|
 dn|
}| |dd|ddf }d}|dur| jf ||| jjj	d|}t
|||j|j|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        )r7  r   r   r   r   r8  rB  rA  N)logitsrK  rG  )lossrL  r   r   rD  rE  )r   r  r   r{   slicerH  Zloss_functionr   r   rG  r   r   r   rD  rE  )r[   r7  r   r   r   r   rA  r8  rK  rB  rC  r\   outputsr   Zslice_indicesrL  rM  r_   r_   r`   r     s<    	z%JanusForConditionalGeneration.forwardc           
         s8   t  j|f|||||d|}	|d dkr4||	d< |	S )N)r   r8  r   rA  rC  r   r   )rO   prepare_inputs_for_generation)
r[   r7  r   r   r   r8  rA  rC  r\   model_inputsr]   r_   r`   rP  %  s    z;JanusForConditionalGeneration.prepare_inputs_for_generation)r   c                 C   s"   | j j|}|dddd}|S )a,  
        Decodes generated image tokens from language model to continuous pixel values
        with VQGAN module via upsampling.
        Args:
            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                The tensors corresponding to the input images.
        r   r&   r	   rl   )r   r*  r  r   )r[   r   Zdecoded_imager_   r_   r`   decode_image_tokensC  s    z1JanusForConditionalGeneration.decode_image_tokens)rI  r   logits_processorc           %         s  | d| j}t|}| dd}|dkrHt jf |||d d|S |jf i |}| tj	tj
fvrttd|  | |  |d ur|nt }d|d< |jd u rtd d	|_|j|d
< | ||j|\}}	}|j|j }
}t|jdkrtd|j d|d u}| j|||jd |jrR|jdkrR|t|j d |_| j||jd |d ||d}| jf |||jd|\}}| jjj j!}|j\}}|"dd}| dd }|"dd}||d< ||d d d f |jk||d d d f |j#d k@ }||d d d f $||j% | & |}| '|||}|(dd d u r~| j)|j*p^d|d t+|j,|| |d|d< t-j.||f|
|d}|j/}|j0}|j1}|j2}|j3}|r|rdnd }|r|rdnd }|r|rdnd }|r|rdnd }t4|D ]
}| j5f ||d|}|d 6|j|d< |d 6|j|d< | jj7f i |||d}| 8||}|j9d d dd d f : } | j;| }!|||!}"|j<rt-j=|"dd}#t-j>|#dd?d}$nt-j@|"dd}$|$|d d |f< t-A|$|$g}$|$Bd}$| C|$}q|r`|r,||!f7 }|r@|| D f7 }|rP||jE7 }|r`||jF7 }|r|tG||!||||jHdS |S d S ) Ngeneration_configgeneration_modetext)rI  r   rT  guidance_scalezGot incompatible mode for Image Generation, should be one of greedy or sampling. Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`.TrB  zU`guidance_scale` is required for CFG but not provided. Setting to default value of 5.   rW  r&   z;Expected input ids of shape (batch_size, seq_len), but got z3Passing `inputs embeds` is not supported currently.)r;  rl   )rT  Zinput_ids_seq_lengthZencoder_input_idsZprefix_allowed_tokens_fnrS  r;  )r7  r   Zexpand_sizer   Zboi_token_idr   Zstatic)cache_implementationr   Zmax_cache_lenmodel_kwargsr:  r_   )r8  r7  rA  )output_attentionsoutput_hidden_statesr   )r   )Znum_samples)	sequencesscoresrL  rD  r   r   )IpoprT  copydeepcopyrO   generateupdateZget_generation_moder   ZSAMPLEZGREEDY_SEARCHr   validateZ_validate_model_kwargsr   rW  r   warningZ_prepare_model_inputsZbos_token_idr   r;  r   r   Z_prepare_special_tokensr  r   Z_get_logits_processorZ_expand_inputs_for_generationZnum_return_sequencesr   r(  r   rZ   repeatZgeneration_kwargsZmasked_fill_Zpad_token_idr0  Z_get_initial_cache_positionr   Z
_get_cacherY  max
max_lengthr   zerosr[  r\  output_scoresoutput_logitsreturn_dict_in_generater   rP  r   r/  Z#_update_model_kwargs_for_generationr  cloner-  Z	do_sampleZsoftmaxZmultinomialZsqueezeZargmaxcatr?  rJ  r~   rD  r   r   r   )%r[   rI  r   rS  r\   rT  rU  rZ  r7  Zmodel_input_namer   r;  Zkwargs_has_attention_maskrZ   r   r   Zinput_tokensmaskr8  Zgenerated_tokensr[  r\  rj  rk  rl  Z
raw_scoresZ
raw_logitsZdecoder_hidden_statesZdecoder_attentionsirQ  rO  r  r^  Znext_token_scoresZprobs
next_tokenr]   r_   r`   rb  O  s    	

















	z&JanusForConditionalGeneration.generate)
NNNNNNNNNr   )NNNNNN)NNN)ra   rb   rc   Z_tied_weights_keysr   r   rP   r0  r2  r   r   rJ  r"   r!   r   r   r   r   r|   r   r{   r   r    r   rP  rR  Zno_gradr   rb  rf   r_   r_   r]   r`   rF    s`   	          6         rF  c                       s  e Zd ZdZdddejddddddf
eeee	e
f  e
eeee
ef eeeeee f  eeeee f  ee d
 fddZdejee
ee
e
e
f f eee	ef  eee	ef  ejd
ddZdejddfejeee	e
f e
f eee
e
e
f  eeee	ef  eee	ef  ejdddZdeee ee ee eee  eee  ee	 ee	 dddZdejeeee f eeee f eee	ef  ejdddZ  ZS )JanusImageProcessora
  
    Constructs a JANUS image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        min_size (`int`, *optional*, defaults to 14):
            The minimum allowed size for the resized image. Ensures that neither the height nor width
            falls below this value after resizing.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
            overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
            Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    TN   gp?)
	do_resizer   min_sizeresample
do_rescalerescale_factordo_normalize
image_mean	image_stddo_convert_rgbc                    s@   t  jf i | || _|d u r(d| _ntdd |D | _d S )N)   r}  r}  c                 s   s   | ]}t |d  V  qdS )   N)r{   )r   xr_   r_   r`   	<genexpr>I  r   z/JanusImageProcessor.__init__.<locals>.<genexpr>)rO   rP   ru  background_colorr  )r[   rt  r   ru  rv  rw  rx  ry  rz  r{  r|  r\   r]   r_   r`   rP   5  s
    zJanusImageProcessor.__init__r   )imager  data_formatinput_data_formatr   c                 C   s  t ||\}}|tjkr"|jd n|jd }||krP|durHt|||n|}|S t||}t|trl|g}nt||krt	d| d|tjkr4t
j|||f|jd}	t|D ]\}
}||	|
ddddf< q||kr|| d }||	dd||| ddf< n*|| d }||	dddd||| f< nt
j|||f|jd}	t|D ] \}
}||	dddd|
f< qR||kr|| d }||	||| ddddf< n*|| d }||	dd||| ddf< |	S )a}  
        Pads an image to a square based on the longest edge.

        Args:
            image (`np.ndarray`):
                The image to pad.
            background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
                The color to use for the padding. Can be an integer for single channel or a
                tuple of integers representing for multi-channel images. If passed as integer
                in mutli-channel mode, it will default to `0` in subsequent channels.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use same as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The padded image.
        r   r   Nz(background_color must have no more than z) elements to match the number of channelsr   r&   )r   r   ZFIRSTr   r   rg  r   r{   r   r   npri  r   	enumerate)r[   r  r  r  r  r   r   rI   Zmax_dimresultrp  colorstartr_   r_   r`   pad_to_squareK  sB    



  
 z!JanusImageProcessor.pad_to_square)r  r   r  rv  r  r  r   c                 K   s   |dur|n| j }|du r"t|}t||\}}	t||	}
t|dd}|d |d krrtd|d  d|d  |d }||
 }tt|| | jtt|	| | jg}t|f||||d|}| j	|||d	}|S )
a  
        Resize an image to dynamically calculated size.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`dict[str, int]` or `int`):
                The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
            background_color (`tuple[int, int, int]`):
                The background color to use for the padding.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `None`: will be inferred from input
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        NT)Zdefault_to_squarer   r   z5Output height and width must be the same. Got height=z and width=)r   rv  r  r  )r  r  r  )
r  r   r   rg  r   r   r{   ru  r   r  )r[   r  r   r  rv  r  r  r\   r   r   max_sizedeltaZoutput_size_nonpaddedr_   r_   r`   r     s>    &
	zJanusImageProcessor.resize)imagesrw  rx  ry  rz  r{  r  return_tensorsc	                 C   sZ  |dur|n| j }|du r$d| j n|}|dur4|n| j}|durF|n| j}|durX|n| j}t|}t|d tjjrt	|dkr|S |d S |du rt
|d }g }	|D ]}
t|
}
|r| j|
|||d}
|r| j|
||d}
|
ddtj}
|r(|r(|dkr(t|
tj|d	}
tj|
}
|	|
 qd
|	i}|dkrJ|nd}t||dS )znApplies post-processing to the decoded image tokens by reversing transformations applied during preprocessing.Ng      ?r   rl   )r  rz  r{  r  )r   r  r~  zPIL.Image.Image)Zinput_channel_dimr   )dataZtensor_type)rw  rx  ry  rz  r{  r   r   PILZImager   r   r   unnormalizeZrescaleZclipZastyper  Zuint8r   r   ZLASTZ	fromarrayr  r   )r[   r  rw  rx  ry  rz  r{  r  r  r   r  r  r_   r_   r`   postprocess  s6    zJanusImageProcessor.postprocess)r  rz  r{  r  r   c                 C   s   d}t |tr4t||kr>td| dt| n
|g| }t |trnt||krxtd| dt| n
|g| }tdd t||D }tdd |D }| j||||d}|S )	a~  
        Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`.
        image = (image * image_std) + image_mean
        Args:
            image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`):
                Batch of pixel values to postprocess.
            image_mean (`float` or `Iterable[float]`):
                The mean to use for unnormalization.
            image_std (`float` or `Iterable[float]`):
                The standard deviation to use for unnormalization.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        r	   zmean must have z$ elements if it is an iterable, got zstd must have c                 s   s   | ]\}}| | V  qd S r   r_   )r   meanstdr_   r_   r`   r  9  r   z2JanusImageProcessor.unnormalize.<locals>.<genexpr>c                 s   s   | ]}d | V  qdS )rl   Nr_   )r   r  r_   r_   r`   r  :  r   )r  r  r  r  )r   r   r   r   r  zipr   )r[   r  rz  r{  r  rI   Zrev_image_meanZrev_image_stdr_   r_   r`   r    s    



zJanusImageProcessor.unnormalize)r   NN)NNNNNNN)N)ra   rb   rc   rd   r   ZBICUBICr|   r   r   strr{   r   r~   r}   rP   r  Zndarrayr  r   arrayr  r   r   r  r   r  rf   r_   r_   r]   r`   rr    s   '
   NN       

8 rr  )	rr  r   rF  r&  r  r   rg   r9   r   )xr`  collections.abcr   dataclassesr   typingr   r   r   numpyr  r   r   Z.transformers.models.blip.image_processing_blipr   Zactivationsr
   Zcache_utilsr   Z
generationr   r   r   r   Zgeneration.utilsr   Zimage_processing_utilsr   r   Zimage_transformsr   r   Zimage_utilsr   r   r   r   r   r   r   Zmodeling_outputsr   Zmodeling_utilsr   r   Zprocessing_utilsr   utilsr    r!   r"   r#   r$   r%   autor'   Zblip_2.modeling_blip_2r(   Z!chameleon.configuration_chameleonr)   Zchameleon.modeling_chameleonr*   r+   r,   r-   r.   Zidefics.modeling_ideficsr/   r0   Zllama.modeling_llamar1   Zsiglip.configuration_siglipr2   Zsiglip.modeling_siglipr3   r4   r5   Ztorch.nnZtorch.nn.functionalZ
functionalr   Ztorch.utils.checkpointr  Zconfiguration_utilsr6   r7   r8   Z
get_loggerra   r   r9   rg   r   r   r   r   r   r   r  r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r!  r"  r&  rF  rr  __all__r_   r_   r_   r`   <module>   s   $	 
aZnLMD0l  9  4