   @   s  d dl Z d dlmZ d dlmZmZmZ d dlZd dlm	Z	 d dl
m	  mZ ddlmZ ddlmZmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZmZ ddl m!Z!m"Z" ddl#m$Z$ ddl%m&Z&m'Z'm(Z( ddl)m*Z* ddl+m,Z, ddl-m.Z.m/Z/m0Z0 dd Z1d[ddZ2ej3e4ej3dddZ5d\e	j6ej3ej3ej3eej3 e7e7e$e& dddZ8G dd  d e	j6Z9ed!G d"d# d#e	j6Z:G d$d% d%e	j6Z;G d&d' d'eZ<G d(d) d)e	j6Z=G d*d+ d+e	j6Z>G d,d- d-e	j6Z?G d.d/ d/e	j6Z@G d0d1 d1e	j6ZAG d2d3 d3e	j6ZBG d4d5 d5e	j6ZCG d6d7 d7e	j6ZDG d8d9 d9e	j6ZEG d:d; d;e	j6ZFG d<d= d=e	jGZHG d>d? d?e	j6ZIG d@dA dAe	j6ZJG dBdC dCe	j6ZKG dDdE dEe	j6ZLG dFdG dGe	j6ZMe'dHdIG dJdK dKe"ZNG dLdM dMZOe'G dNdO dOe"ZPG dPdQ dQe	j6ZQe'G dRdS dSePZRe'G dTdU dUePeZSG dVdW dWePZTG dXdY dYePeZUg dZZVdS )]    N)cached_property)CallableOptionalUnion   )ACT2FN)CacheDynamicCache)GenerationMixin)use_kernel_forward_from_hub)create_causal_mask)GradientCheckpointingLayer)BaseModelOutputWithPastCausalLMOutputWithPast)ROPE_INIT_FUNCTIONSdynamic_rope_update)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tuple)deprecate_kwarg)check_model_inputs   )
Emu3ConfigEmu3TextConfigEmu3VQVAEConfigc                 C   sH   | dd| j d d f }| d| j d d df }tj| |fddS )z*Rotates half the hidden dims of the input..N   dim)shapetorchcat)xx1Zx2 r'   b/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/emu3/modeling_emu3.pyrotate_half/   s    r)   c                 C   sD   | |}| |}| | t| |  }|| t||  }||fS )a  Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
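

# Illustrative sketch, not part of the original module: the two helpers above keep
# shape contracts the attention layer relies on. `repeat_kv` broadcasts grouped KV
# heads up to the query head count, and `apply_rotary_pos_emb` never changes shapes.
# With made-up sizes (4 query heads per KV head):
#
#     q = torch.randn(2, 8, 5, 64)   # (batch, num_attention_heads, seq, head_dim)
#     kv = torch.randn(2, 2, 5, 64)  # (batch, num_key_value_heads, seq, head_dim)
#     assert repeat_kv(kv, 4).shape == (2, 8, 5, 64)
#     cos, sin = torch.ones(2, 5, 64), torch.zeros(2, 5, 64)  # identity rotation
#     q2, k2 = apply_rotary_pos_emb(q, repeat_kv(kv, 4), cos, sin)
#     assert q2.shape == q.shape and k2.shape == (2, 8, 5, 64)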


class Emu3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Emu3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


@use_kernel_forward_from_hub("RMSNorm")
class Emu3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Emu3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Emu3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class Emu3DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Emu3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Emu3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Emu3MLP(config)
        self.input_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.attention_dropout)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + self.dropout(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        return hidden_states


class Emu3VQVAEVectorQuantizer(nn.Module):
    """
    A module for vector quantization using learned embedding vectors.

    This module implements the quantization process similar to the one described in
    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
    input vectors into discrete codebook vectors, which are learned during training.
    Current implementation improves over previous ones by avoiding costly matrix multiplications
    and allowing for post-hoc remapping of indices.
    """

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)

    def forward(self, hidden_state: torch.Tensor):
        batch_size, temporal, channels, height, width = hidden_state.shape
        hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous()
        hidden_state_flattened = hidden_state.view(-1, channels)

        # distances from z to the codebook entries e: (z - e)^2 = z^2 + e^2 - 2 e * z
        hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
        embedding_sum = torch.sum(self.embedding.weight**2, dim=1)
        distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1))
        distances = hidden_state_sum + embedding_sum - distances

        min_encoding_indices = torch.argmin(distances, dim=1)
        min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width)
        return min_encoding_indices
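

# Illustrative sketch, not part of the original module: the argmin lookup above
# expands ||z - e||^2 into ||z||^2 + ||e||^2 - 2 z.e so that no
# (num_latents, codebook_size, embed_dim) intermediate is ever materialized.
# Equivalence check of the algebra, with made-up sizes:
#
#     z = torch.randn(8, 4)          # 8 flattened latents, embed_dim 4
#     codebook = torch.randn(16, 4)  # 16 codebook entries
#     dist = z.pow(2).sum(1, keepdim=True) + codebook.pow(2).sum(1) - 2 * z @ codebook.T
#     assert torch.allclose(dist, torch.cdist(z, codebook).pow(2), atol=1e-4)
#     indices = dist.argmin(dim=1)   # nearest codebook index per latent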


class Emu3VQVAEEncoderConvDownsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        # no asymmetric padding in torch conv, must do it ourselves
        hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAEEncoderConvUpsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAEConv3d(nn.Module):
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: tuple[int],
        stride: tuple[int],
    ):
        super().__init__()

        padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])]
        self.padding = ()
        for pad_size in padding_sizes[::-1]:
            self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2)
        self.padding += (2, 0)

        self.conv = nn.Conv3d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
        )

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = F.pad(hidden_states, self.padding)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAESpatialNorm(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.norm_layer = nn.GroupNorm(
            num_channels=out_channels,
            num_groups=32,
            eps=1e-6,
            affine=True,
        )
        self.conv_y = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
        quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest")
        hidden_states = self.norm_layer(hidden_states)
        hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states)
        return hidden_states


class Emu3VQVAETemporalUpsample(nn.Module):
    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.conv = Emu3VQVAEConv3d(in_channel, out_channel, kernel_size=(3, 3, 3), stride=(1, 1, 1))

    def forward(self, hidden_states: torch.Tensor):
        batch_size, channels, temporal, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal)
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous()
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAETemporalDownsample(nn.Module):
    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.conv = Emu3VQVAEConv3d(in_channel, out_channel, kernel_size=(4, 3, 3), stride=(2, 1, 1))

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAETemporalResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels

        self.norm1 = nn.BatchNorm3d(in_channels)
        self.conv1 = Emu3VQVAEConv3d(in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        self.norm2 = nn.BatchNorm3d(out_channels)
        self.conv2 = Emu3VQVAEConv3d(out_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            residual = self.nin_shortcut(residual)

        return residual + hidden_states
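

# Illustrative note, not part of the original module: Emu3VQVAEConv3d pads the
# spatial dims symmetrically but the temporal dim on the front only, so a frame
# never attends to future frames. Worked through for kernel (4, 3, 3), stride (2, 1, 1):
#
#     padding_sizes = [3 - 1, 3 - 1]           # height/width: kernel - stride = [2, 2]
#     # each spatial pad of 2 splits into (1, 1); then the fixed temporal pad (2, 0):
#     # self.padding == (1, 1, 1, 1, 2, 0), which F.pad consumes last-dim-first as
#     # (W_left, W_right, H_left, H_right, T_front, T_back).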


class Emu3VQVAEResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        quant_channels: Optional[int] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.quant_channels = quant_channels

        if quant_channels is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels)
            self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None):
        norm_args = () if self.quant_channels is None else (quant_channels,)

        residual = hidden_states
        hidden_states = self.norm1(hidden_states, *norm_args)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states, *norm_args)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            residual = self.nin_shortcut(residual)

        return residual + hidden_states


class Emu3VQVAEAttentionBlock(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.num_key_value_groups = 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Input shape: Batch x Time x Channel"""
        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights


class Emu3VQVAEGroupNorm(nn.GroupNorm):
    """
    Same as the torch GroupNorm, with the only difference that this one accepts
    an optional kwarg `quant_states` which is not used. This class makes it easier to
    use SpatialNorm or GroupNorm without conditionals.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, input, quant_states=None):
        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)


class Emu3VQVAEMiddleBlock(nn.Module):
    def __init__(self, config, in_channels, quant_channels=None):
        super().__init__()

        self.block_1 = Emu3VQVAEResnetBlock(
            in_channels=in_channels,
            out_channels=in_channels,
            quant_channels=quant_channels,
        )
        self.attn_1 = Emu3VQVAEAttentionBlock(config)
        if quant_channels is None:
            self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels)
        self.block_2 = Emu3VQVAEResnetBlock(
            in_channels=in_channels,
            out_channels=in_channels,
            quant_channels=quant_channels,
        )

    def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None):
        hidden_states = self.block_1(hidden_states, quant_states)
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states, quant_states)
        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
        hidden_states = self.attn_1(hidden_states)[0]
        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
        hidden_states = residual + hidden_states
        hidden_states = self.block_2(hidden_states, quant_states)
        return hidden_states


class Emu3VQVAEDownBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        channel_multiplier = config.channel_multiplier

        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            attn_norms = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    Emu3VQVAEResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                if config.attn_resolutions is not None and i_level in config.attn_resolutions:
                    attn.append(Emu3VQVAEAttentionBlock(config))
                    attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True))

            down = nn.Module()
            down.block = block
            down.attn = attn
            down.attn_norms = attn_norms
            if i_level != self.num_resolutions - 1:
                down.downsample = Emu3VQVAEEncoderConvDownsample(block_in)
            self.down.append(down)

    def forward(self, hidden_states: torch.FloatTensor):
        for i_level, blocks in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                hidden_states = blocks.block[i_block](hidden_states)
                if len(blocks.attn) > 0:
                    residual = hidden_states
                    hidden_states = blocks.attn_norms[i_block](hidden_states)
                    batch_size, channels, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
                    hidden_states = blocks.attn[i_block](hidden_states)[0]
                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
                    hidden_states = residual + hidden_states
            if i_level != self.num_resolutions - 1:
                hidden_states = blocks.downsample(hidden_states)
        return hidden_states


class Emu3VQVAEUpBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks

        quant_channels = config.embed_dim
        block_in = config.base_channels * config.channel_multiplier[-1]

        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            attn_norms = nn.ModuleList()
            block_out = config.base_channels * config.channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    Emu3VQVAEResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        quant_channels=quant_channels,
                    )
                )
                block_in = block_out
                if i_level in config.attn_resolutions:
                    attn.append(Emu3VQVAEAttentionBlock(config))
                    attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in))

            up = nn.Module()
            up.block = block
            up.attn = attn
            up.attn_norms = attn_norms
            if i_level != 0:
                up.upsample = Emu3VQVAEEncoderConvUpsample(block_in)

            self.up.insert(0, up)

    def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor):
        for i_level, blocks in enumerate(self.up[::-1]):
            for i_block in range(self.num_res_blocks + 1):
                hidden_states = blocks.block[i_block](hidden_states, quant_states)
                if len(blocks.attn) > 0:
                    residual = hidden_states
                    hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states)
                    batch_size, channels, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
                    hidden_states = blocks.attn[i_block](hidden_states)[0]
                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
                    hidden_states = residual + hidden_states
            if i_level != len(self.up) - 1:
                hidden_states = blocks.upsample(hidden_states)
        return hidden_states


class Emu3VQVAEEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        base_channels = config.base_channels
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier
        out_channels = 2 * latent_channels if double_latent else latent_channels
        block_in = base_channels * channel_multiplier[-1]

        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
        self.down_block = Emu3VQVAEDownBlock(config)
        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in)

        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

        temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
        self.time_conv = nn.ModuleList()
        self.time_res_stack = nn.ModuleList()

        for i in range(temporal_down_blocks):
            conv = Emu3VQVAETemporalDownsample(out_channels, out_channels)
            self.time_conv.append(conv)

        for _ in range(config.num_res_blocks):
            time_res_conv = Emu3VQVAETemporalResnetBlock(in_channels=out_channels, out_channels=out_channels)
            self.time_res_stack.append(time_res_conv)

    def forward(self, pixel_values: torch.Tensor):
        temporal_dim = pixel_values.shape[1]
        pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])

        # downsampling & middle
        hidden_states = self.conv_in(pixel_values)
        hidden_states = self.down_block(hidden_states)
        hidden_states = self.middle_block(hidden_states)

        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:])
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        # temporal convs
        for conv in self.time_conv:
            hidden_states = conv(hidden_states)
            hidden_states *= torch.sigmoid(hidden_states)

        for layer in self.time_res_stack:
            hidden_states = layer(hidden_states)

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        return hidden_states


class Emu3VQVAEDecoder(nn.Module):
    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()

        quant_channels = config.embed_dim
        block_in = config.base_channels * config.channel_multiplier[-1]
        self.time_res_stack = nn.ModuleList()
        for _ in range(config.num_res_blocks):
            time_res_conv = Emu3VQVAETemporalResnetBlock(
                in_channels=config.latent_channels, out_channels=config.latent_channels
            )
            self.time_res_stack.append(time_res_conv)

        temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
        self.time_conv = nn.ModuleList()
        for i in range(temp_upsample_block_num):
            conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels)
            self.time_conv.append(conv)

        self.conv_in = nn.Conv2d(config.latent_channels, block_in, kernel_size=3, stride=1, padding=1)
        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels)
        self.up_block = Emu3VQVAEUpBlock(config)

        block_in = config.base_channels * config.channel_multiplier[0]
        self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in)
        self.conv_out = nn.Conv2d(block_in, config.out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
        hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0)
        hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)

        # temporal convs
        for layer in self.time_res_stack:
            hidden_quant_states = layer(hidden_quant_states)

        for layer in self.time_conv:
            hidden_quant_states = layer(hidden_quant_states)
            hidden_quant_states *= torch.sigmoid(hidden_quant_states)

        hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
        hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0)
        hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:])
        quant_states = quant_states.reshape(-1, *quant_states.shape[2:])

        hidden_states = self.conv_in(hidden_states)
        hidden_states = self.middle_block(hidden_states, quant_states)
        hidden_states = self.up_block(hidden_states, quant_states)

        hidden_states = self.norm_out(hidden_states, quant_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


@auto_docstring(
    custom_intro="""
    The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
    Taigman](https://huggingface.co/papers/2203.13131).
    """
)
class Emu3VQVAE(PreTrainedModel):
    config: Emu3VQVAEConfig
    base_model_prefix = "emuvideovq"
    main_input_name = "pixel_values"
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _no_split_modules = [
        "Emu3VQVAETemporalResnetBlock",
        "Emu3VQVAEAttentionBlock",
        "Emu3VQVAEResnetBlock",
        "Emu3VQVAEVectorQuantizer",
    ]

    def _init_weights(self, module):
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_()
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__(config)

        self.config = config

        self.encoder = Emu3VQVAEEncoder(config)
        self.decoder = Emu3VQVAEDecoder(config)
        self.quantize = Emu3VQVAEVectorQuantizer(config)
        self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1)

        self.quant_conv = Emu3VQVAEConv3d(
            config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.post_quant_conv = Emu3VQVAEConv3d(
            config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1)
        self.eval()  # Emu3's VQ model is frozen

        self.post_init()

    def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor):
        is_image = pixel_values.ndim == 4
        if is_image:
            temporal = self.config.temporal_downsample_factor
            batch_size, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1)
        else:
            batch_size, temporal, channels, height, width = pixel_values.shape

        hidden_states = self.encoder(pixel_values)

        # b t c h w -> b c t h w
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        hidden_states = self.quant_conv(hidden_states)

        # b c t h w -> b t c h w
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        codes = self.quantize(hidden_states)

        image_tokens = codes.squeeze(1) if is_image else codes

        image_tokens = [
            single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)]
            for single_image, size in zip(image_tokens, image_sizes)
        ]

        return image_tokens

    def decode(self, hidden_states: torch.Tensor):
        is_image = hidden_states.ndim == 3
        if is_image:
            hidden_states = hidden_states.unsqueeze(1)

        batch_size, temporal, height, width = hidden_states.shape
        quant = self.quantize.embedding(hidden_states.flatten())

        channels = quant.shape[-1]
        quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous()
        post_quant = self.post_quant_conv(quant)

        quant = quant.permute(0, 2, 1, 3, 4)
        post_quant = post_quant.permute(0, 2, 1, 3, 4)

        video = self.decoder(post_quant, quant)
        video = video.reshape(
            batch_size,
            temporal * self.config.temporal_downsample_factor,
            self.config.out_channels,
            height * self.spatial_scale_factor,
            width * self.spatial_scale_factor,
        )
        return video[:, 0] if is_image else video


class Emu3ImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map):
        self.vocab_map = vocab_map
        self.eol_token_id = vocab_map.get("<|extra_200|>")
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def image_tokens(self):
        return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")])

    @cached_property
    def image_tokens_str(self):
        return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")])

    @cached_property
    def img2bpe(self):
        return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}

    @cached_property
    def bpe2img(self):
        return {v: k for k, v in self.img2bpe.items()}

    @cached_property
    def bpe2img_mapping_tensor(self):
        mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int)
        for k, v in self.bpe2img.items():
            mapping[k] = v
        return mapping

    @cached_property
    def img2bpe_mapping_tensor(self):
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        device = img_batch.device
        eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
        return img_tokens.to(device)

    def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
        device = img_batch.device
        img_batch = img_batch[..., :-1]  # remove last row of EOL tokens
        img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)


@auto_docstring
class Emu3PreTrainedModel(PreTrainedModel):
    config: Emu3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "Emu3DecoderLayer",
    ]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_param_buffer_assignment = False
    _supports_flex_attn = True
    _supports_attention_backend = True


class Emu3RotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Emu3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@auto_docstring
class Emu3TextModel(Emu3PreTrainedModel):
    _can_record_outputs = {
        "hidden_states": Emu3DecoderLayer,
        "attentions": Emu3Attention,
    }

    def __init__(self, config: Emu3TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Emu3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


@auto_docstring
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Emu3TextConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Emu3TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForCausalLM
        >>> import torch

        >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class Emu3Model(Emu3PreTrainedModel):
    _checkpoint_conversion_mapping = {"text_model.model": "text_model"}

    def __init__(self, config):
        super().__init__(config)
        self.text_model = Emu3TextModel._from_config(config.text_config)
        self.vqmodel = Emu3VQVAE._from_config(config.vq_config)
        self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)

        self.post_init()

    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.text_model = decoder

    def get_decoder(self):
        return self.text_model

    def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module. Converts
        obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
        special tokens.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
                The sizes of the images in the batch, being (height, width) for each image.
        """
        image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes)
        bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list]
        bpe_tokens = torch.cat(bpe_tokens_list)
        return bpe_tokens

    def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module and embeds
        them with text embeddings layer

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
        """
        image_tokens = self.get_image_tokens(pixel_values, image_sizes)
        split_sizes = [
            (height // self.vqmodel.vision_spatial_factor + 1) * (width // self.vqmodel.vision_spatial_factor)
            for height, width in image_sizes
        ]
        image_features = self.get_input_embeddings()(image_tokens)
        image_features = torch.split(image_features, split_sizes)
        return image_features

    def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
        """
        Decodes generated image tokens from language model to continuous pixel values
        with VQGAN module via upsampling.

        Args:
            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                The tensors corresponding to the input images.
            height (`int`):
                Height of the generated image before upsampling.
            width (`int`):
                Width of the generated image before upsampling.
        """
        sequences = image_tokens[:, :-3].view(-1, height, width + 1)
        image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences)
        image = self.vqmodel.decode(image_tokens)
        return image

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        n_image_features = image_features.shape[0] * image_features.shape[1]
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        return special_image_mask

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
            [`Emu3ImageProcessor`] for processing images).
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_sizes)
            image_embeds = torch.cat(image_embeds, dim=0)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

        outputs = self.text_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        return outputs


@auto_docstring
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
    base_model_prefix = ""
    _tied_weights_keys = ["lm_head.weight"]
    _checkpoint_conversion_mapping = {
        "^text_model.model": "model.text_model",
        "^vqmodel": "model.vqmodel",
        "^text_model.lm_head": "lm_head",
    }

    def __init__(self, config):
        super().__init__(config)
        self.model = Emu3Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    @property
    def text_model(self):
        return self.model.text_model

    @property
    def vqmodel(self):
        return self.model.vqmodel

    @property
    def vocabulary_mapping(self):
        return self.model.vocabulary_mapping

    def decode_image_tokens(self, **kwargs):
        return self.model.decode_image_tokens(**kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithPast]:
        r"""
        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
            [`Emu3ImageProcessor`] for processing images).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
        >>> import torch
        >>> import requests
        >>> from PIL import Image

        >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> conversation = [
        ...     {
        ...     "role": "system",
        ...     "content": [
        ...         {"type": "text", "text": "You are a helpful assistant."},
        ...         ],
        ...     },
        ...     {
        ...     "role": "user",
        ...     "content": [
        ...         {"type": "image"},
        ...         {"type": "text", "text": "Please describe the image."},
        ...         ],
        ...     },
        ... ]

        >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)

        >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            use_cache=use_cache,
            **kwargs,
        )

        if cache_position[0] != 0:
            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special
            # image tokens anymore
            model_inputs["pixel_values"] = None

        return model_inputs


__all__ = [
    "Emu3ForConditionalGeneration",
    "Emu3ForCausalLM",
    "Emu3TextModel",
    "Emu3Model",
    "Emu3PreTrainedModel",
    "Emu3VQVAE",
]