import math
from functools import cached_property
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from ..chameleon.modeling_chameleon import ChameleonPreTrainedModel, ChameleonVQVAEEncoderConvDownsample
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    TransformersKwargs,
)
from ..siglip.modeling_siglip import SiglipAttention
from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig


logger = logging.get_logger(__name__)


class Emu3Attention(LlamaAttention):
    pass


class Emu3DecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: Emu3TextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.dropout = nn.Dropout(config.attention_dropout)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + self.dropout(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        return hidden_states


class Emu3VQVAEVectorQuantizer(nn.Module):
    """
    A module for vector quantization using learned embedding vectors.

    This module implements the quantization process similar to the one described in
    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
    input vectors into discrete codebook vectors, which are learned during training.
    Current implementation improves over previous ones by avoiding costly matrix multiplications
    and allowing for post-hoc remapping of indices.
    """

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)

    def forward(self, hidden_state: torch.Tensor):
        batch_size, temporal, channels, height, width = hidden_state.shape
        hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous()
        hidden_state_flattened = hidden_state.view(-1, channels)

        # distances from the flattened states to each codebook embedding: (z - e)^2 = z^2 + e^2 - 2 e * z
        hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
        embedding_sum = torch.sum(self.embedding.weight**2, dim=1)
        distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1))
        distances = hidden_state_sum + embedding_sum - distances

        min_encoding_indices = torch.argmin(distances, dim=1)
        min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width)
        return min_encoding_indices


class Emu3VQVAEEncoderConvDownsample(ChameleonVQVAEEncoderConvDownsample):
    pass


class Emu3VQVAEEncoderConvUpsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAEConv3d(nn.Module):
    def __init__(self, in_channel: int, out_channel: int, kernel_size: tuple[int], stride: tuple[int]):
        super().__init__()

        padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])]
        self.padding = ()
        for pad_size in padding_sizes[::-1]:
            self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2)
        self.padding += (2, 0)

        self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride)

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = F.pad(hidden_states, self.padding)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAESpatialNorm(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
        quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest")
        hidden_states = self.norm_layer(hidden_states)
        hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states)
        return hidden_states


class Emu3VQVAETemporalUpsample(nn.Module):
    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.conv = Emu3VQVAEConv3d(in_channel, out_channel, kernel_size=(3, 3, 3), stride=(1, 1, 1))

    def forward(self, hidden_states: torch.Tensor):
        batch_size, channels, temporal, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal)
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous()
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAETemporalDownsample(nn.Module):
    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.conv = Emu3VQVAEConv3d(in_channel, out_channel, kernel_size=(4, 3, 3), stride=(2, 1, 1))

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.conv(hidden_states)
        return hidden_states


class Emu3VQVAETemporalResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels

        self.norm1 = nn.BatchNorm3d(in_channels)
        self.conv1 = Emu3VQVAEConv3d(in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        self.norm2 = nn.BatchNorm3d(out_channels)
        self.conv2 = Emu3VQVAEConv3d(out_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            residual = self.nin_shortcut(residual)

        return residual + hidden_states


class Emu3VQVAEResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: Optional[int] = None, quant_channels: Optional[int] = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.quant_channels = quant_channels

        if quant_channels is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels)
            self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None):
        norm_args = () if self.quant_channels is None else (quant_channels,)

        residual = hidden_states
        hidden_states = self.norm1(hidden_states, *norm_args)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states, *norm_args)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            residual = self.nin_shortcut(residual)

        return residual + hidden_states


class Emu3VQVAEAttentionBlock(SiglipAttention):
    def __init__(self, config):
        super().__init__(config)
        self.num_key_value_groups = 1


class Emu3VQVAEGroupNorm(nn.GroupNorm):
    """
    Same as the torch GroupNorm with the only difference that this one accepts
    an optional kwarg `quant_states` which is not used. This class makes it easier to
    use SpatialNorm or GroupNorm without conditionals
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, input, quant_states=None):
        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)


class Emu3VQVAEMiddleBlock(nn.Module):
    def __init__(self, config, in_channels, quant_channels=None):
        super().__init__()

        self.block_1 = Emu3VQVAEResnetBlock(
            in_channels=in_channels,
            out_channels=in_channels,
            quant_channels=quant_channels,
        )
        self.attn_1 = Emu3VQVAEAttentionBlock(config)
        if quant_channels is None:
            self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels)
        self.block_2 = Emu3VQVAEResnetBlock(
            in_channels=in_channels,
            out_channels=in_channels,
            quant_channels=quant_channels,
        )

    def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None):
        hidden_states = self.block_1(hidden_states, quant_states)
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states, quant_states)

        # flatten the spatial grid into a sequence for attention, then fold it back
        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
        hidden_states = self.attn_1(hidden_states)[0]
        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

        hidden_states = residual + hidden_states
        hidden_states = self.block_2(hidden_states, quant_states)
        return hidden_states


class Emu3VQVAEDownBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        channel_multiplier = config.channel_multiplier

        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            attn_norms = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(Emu3VQVAEResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
                if config.attn_resolutions is not None and i_level in config.attn_resolutions:
                    attn.append(Emu3VQVAEAttentionBlock(config))
                    attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True))

            down = nn.Module()
            down.block = block
            down.attn = attn
            down.attn_norms = attn_norms
            if i_level != self.num_resolutions - 1:
                down.downsample = Emu3VQVAEEncoderConvDownsample(block_in)
            self.down.append(down)

    def forward(self, hidden_states: torch.FloatTensor):
        for i_level, blocks in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                hidden_states = blocks.block[i_block](hidden_states)
                if len(blocks.attn) > 0:
                    residual = hidden_states
                    hidden_states = blocks.attn_norms[i_block](hidden_states)

                    batch_size, channels, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
                    hidden_states = blocks.attn[i_block](hidden_states)[0]
                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

                    hidden_states = residual + hidden_states

            if i_level != self.num_resolutions - 1:
                hidden_states = blocks.downsample(hidden_states)

        return hidden_states


class Emu3VQVAEUpBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks

        quant_channels = config.embed_dim
        block_in = config.base_channels * config.channel_multiplier[-1]

        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            attn_norms = nn.ModuleList()
            block_out = config.base_channels * config.channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    Emu3VQVAEResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        quant_channels=quant_channels,
                    )
                )
                block_in = block_out
                if i_level in config.attn_resolutions:
                    attn.append(Emu3VQVAEAttentionBlock(config))
                    attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in))

            up = nn.Module()
            up.block = block
            up.attn = attn
            up.attn_norms = attn_norms
            if i_level != 0:
                up.upsample = Emu3VQVAEEncoderConvUpsample(block_in)

            self.up.insert(0, up)

    def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor):
        for i_level, blocks in enumerate(self.up[::-1]):
            for i_block in range(self.num_res_blocks + 1):
                hidden_states = blocks.block[i_block](hidden_states, quant_states)
                if len(blocks.attn) > 0:
                    residual = hidden_states
                    hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states)

                    batch_size, channels, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
                    hidden_states = blocks.attn[i_block](hidden_states)[0]
                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

                    hidden_states = residual + hidden_states
            if i_level != len(self.up) - 1:
                hidden_states = blocks.upsample(hidden_states)

        return hidden_states


class Emu3VQVAEEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        base_channels = config.base_channels
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier
        out_channels = 2 * latent_channels if double_latent else latent_channels
        block_in = base_channels * channel_multiplier[-1]

        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
        self.down_block = Emu3VQVAEDownBlock(config)
        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in)

        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

        temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
        self.time_conv = nn.ModuleList()
        self.time_res_stack = nn.ModuleList()

        for i in range(temporal_down_blocks):
            conv = Emu3VQVAETemporalDownsample(out_channels, out_channels)
            self.time_conv.append(conv)

        for _ in range(config.num_res_blocks):
            time_res_conv = Emu3VQVAETemporalResnetBlock(in_channels=out_channels, out_channels=out_channels)
            self.time_res_stack.append(time_res_conv)

    def forward(self, pixel_values: torch.LongTensor):
        temporal_dim = pixel_values.shape[1]
        pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])

        # spatial downsampling and middle block
        hidden_states = self.conv_in(pixel_values)
        hidden_states = self.down_block(hidden_states)
        hidden_states = self.middle_block(hidden_states)

        hidden_states = self.norm_out(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        # unfold the temporal axis and run the temporal stages
        hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:])
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        for conv in self.time_conv:
            hidden_states = conv(hidden_states)
            hidden_states *= torch.sigmoid(hidden_states)

        for layer in self.time_res_stack:
            hidden_states = layer(hidden_states)

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        return hidden_states


class Emu3VQVAEDecoder(nn.Module):
    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()

        quant_channels = config.embed_dim
        block_in = config.base_channels * config.channel_multiplier[-1]
        self.time_res_stack = nn.ModuleList()
        for _ in range(config.num_res_blocks):
            time_res_conv = Emu3VQVAETemporalResnetBlock(
                in_channels=config.latent_channels, out_channels=config.latent_channels
            )
            self.time_res_stack.append(time_res_conv)

        temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
        self.time_conv = nn.ModuleList()
        for i in range(temp_upsample_block_num):
            conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels)
            self.time_conv.append(conv)

        self.conv_in = nn.Conv2d(config.latent_channels, block_in, kernel_size=3, stride=1, padding=1)

        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels)
        self.up_block = Emu3VQVAEUpBlock(config)

        block_in = config.base_channels * config.channel_multiplier[0]
        self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in)
        self.conv_out = nn.Conv2d(block_in, config.out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
        hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0)
        hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)

        # temporal stages run on hidden states and quantized states in one batch
        for layer in self.time_res_stack:
            hidden_quant_states = layer(hidden_quant_states)

        for layer in self.time_conv:
            hidden_quant_states = layer(hidden_quant_states)
            hidden_quant_states *= torch.sigmoid(hidden_quant_states)

        hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
        hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0)
        hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:])
        quant_states = quant_states.reshape(-1, *quant_states.shape[2:])

        hidden_states = self.conv_in(hidden_states)

        # middle block and spatial upsampling
        hidden_states = self.middle_block(hidden_states, quant_states)
        hidden_states = self.up_block(hidden_states, quant_states)

        hidden_states = self.norm_out(hidden_states, quant_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
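

# Layout note for the two modules above: the encoder folds the temporal axis into the batch for its
# 2D trunk, switches to (batch, channels, temporal, height, width) for the Conv3d stages and returns
# (batch, temporal, channels, height, width); the decoder mirrors this and concatenates hidden states
# with the quantized codes along the batch axis so both pass through the temporal stack in one call.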


@auto_docstring(
    custom_intro="""
    The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
    Taigman](https://huggingface.co/papers/2203.13131).
    """
)
class Emu3VQVAE(PreTrainedModel):
    config: Emu3VQVAEConfig
    base_model_prefix = "emuvideovq"
    main_input_name = "pixel_values"
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _no_split_modules = [
        "Emu3VQVAETemporalResnetBlock",
        "Emu3VQVAEAttentionBlock",
        "Emu3VQVAEResnetBlock",
        "Emu3VQVAEVectorQuantizer",
    ]

    def _init_weights(self, module):
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_()
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__(config)

        self.config = config

        self.encoder = Emu3VQVAEEncoder(config)
        self.decoder = Emu3VQVAEDecoder(config)
        self.quantize = Emu3VQVAEVectorQuantizer(config)
        self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1)

        self.quant_conv = Emu3VQVAEConv3d(
            config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.post_quant_conv = Emu3VQVAEConv3d(
            config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1)

        self.eval()  # the VQ model is kept frozen
        self.post_init()

    def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor):
        is_image = pixel_values.ndim == 4
        if is_image:
            temporal = self.config.temporal_downsample_factor
            batch_size, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1)
        else:
            batch_size, temporal, channels, height, width = pixel_values.shape

        hidden_states = self.encoder(pixel_values)

        # b t c h w -> b c t h w for the quant conv, then back
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        hidden_states = self.quant_conv(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        codes = self.quantize(hidden_states)

        image_tokens = codes.squeeze(1) if is_image else codes

        image_tokens = [
            single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)]
            for single_image, size in zip(image_tokens, image_sizes)
        ]

        return image_tokens

    def decode(self, hidden_states: torch.Tensor):
        is_image = hidden_states.ndim == 3
        if is_image:
            hidden_states = hidden_states.unsqueeze(1)

        batch_size, temporal, height, width = hidden_states.shape
        quant = self.quantize.embedding(hidden_states.flatten())

        channels = quant.shape[-1]
        quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous()
        post_quant = self.post_quant_conv(quant)

        quant = quant.permute(0, 2, 1, 3, 4)
        post_quant = post_quant.permute(0, 2, 1, 3, 4)

        video = self.decoder(post_quant, quant)
        video = video.reshape(
            batch_size,
            temporal * self.config.temporal_downsample_factor,
            self.config.out_channels,
            height * self.spatial_scale_factor,
            width * self.spatial_scale_factor,
        )
        return video[:, 0] if is_image else video


class Emu3ImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map):
        self.vocab_map = vocab_map
        self.eol_token_id = vocab_map.get("<|extra_200|>")
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def image_tokens(self):
        return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")])

    @cached_property
    def image_tokens_str(self):
        return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")])

    @cached_property
    def img2bpe(self):
        return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}

    @cached_property
    def bpe2img(self):
        return {v: k for k, v in self.img2bpe.items()}

    @cached_property
    def bpe2img_mapping_tensor(self):
        mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int)
        for k, v in self.bpe2img.items():
            mapping[k] = v
        return mapping

    @cached_property
    def img2bpe_mapping_tensor(self):
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: list[torch.Tensor]) -> torch.Tensor:
        device = img_batch.device
        eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
        return img_tokens.to(device)

    def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
        device = img_batch.device
        img_batch = img_batch[..., :-1]  # remove the trailing EOL column
        img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)


class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
    _no_split_modules = [
        "Emu3DecoderLayer",
    ]
    _supports_flex_attn = True
    _supports_attention_backend = True


class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
    _can_record_outputs = {
        "hidden_states": Emu3DecoderLayer,
        "attentions": Emu3Attention,
    }

    def __init__(self, config: Emu3TextConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )


class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
    config: Emu3TextConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Emu3TextModel(config)

    @can_return_tuple
    @auto_docstring
    def forward(self, **super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
        >>> import torch
        >>> import requests
        >>> from PIL import Image

        >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        super().forward(**super_kwargs)


class Emu3Model(Emu3PreTrainedModel):
    _checkpoint_conversion_mapping = {"text_model.model": "text_model"}

    def __init__(self, config):
        super().__init__(config)
        self.text_model = Emu3TextModel._from_config(config.text_config)
        self.vqmodel = Emu3VQVAE._from_config(config.vq_config)
        self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)

        self.post_init()

    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.text_model = decoder

    def get_decoder(self):
        return self.text_model

    def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module. Converts
        obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
        special tokens.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
                The sizes of the images in the batch, being (height, width) for each image.
        """
        image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes)
        bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list]
        bpe_tokens = torch.cat(bpe_tokens_list)
        return bpe_tokens

    def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module and embeds
        them with text embeddings layer

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
        """
        image_tokens = self.get_image_tokens(pixel_values, image_sizes)
        split_sizes = [
            (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1)
            for height, width in image_sizes
        ]
        image_features = self.get_input_embeddings()(image_tokens)
        image_features = torch.split(image_features, split_sizes)
        return image_features

    @torch.no_grad()
    def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
        """
        Decodes generated image tokens from language model to continuous pixel values
        with VQGAN module via upsampling.

        Args:
            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                The tensors corresponding to the input images.
            height (`int`):
                Height of the generated image before upsampling.
            width (`int`):
                Width of the generated image before upsampling.
        """
        sequences = image_tokens[:, :-3].view(-1, height, width + 1)
        image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences)
        image = self.vqmodel.decode(image_tokens)
        return image

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        n_image_features = image_features.shape[0] * image_features.shape[1]
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
            [`Emu3ImageProcessor`] for processing images).
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_sizes)
            image_embeds = torch.cat(image_embeds, dim=0)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

        outputs = self.text_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        return outputs


class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
    base_model_prefix = ""
    _tied_weights_keys = ["lm_head.weight"]
    _checkpoint_conversion_mapping = {
        "^text_model.model": "model.text_model",
        "^vqmodel": "model.vqmodel",
        "^text_model.lm_head": "lm_head",
    }

    def __init__(self, config):
        super().__init__(config)
        self.model = Emu3Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    # Make modules available through the conditional-generation class for backwards compatibility
    @property
    def text_model(self):
        return self.model.text_model

    @property
    def vqmodel(self):
        return self.model.vqmodel

    @property
    def vocabulary_mapping(self):
        return self.model.vocabulary_mapping

    def decode_image_tokens(self, **kwargs):
        return self.model.decode_image_tokens(**kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithPast]:
        r"""
        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
            [`Emu3ImageProcessor`] for processing images).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
        >>> import torch
        >>> import requests
        >>> from PIL import Image

        >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> conversation = [
        ...     {
        ...     "role": "system",
        ...     "content": [
        ...         {"type": "text", "text": "You are a helpful assistant."},
        ...         ],
        ...     },
        ...     {
        ...     "role": "user",
        ...     "content": [
        ...         {"type": "image"},
        ...         {"type": "text", "text": "Please describe the image."},
        ...         ],
        ...     },
        ... ]

        >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)

        >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute the logits that are needed, and do not upcast them unless the loss is computed
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            use_cache=use_cache,
            **kwargs,
        )

        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None

        return model_inputs


__all__ = [
    "Emu3ForConditionalGeneration",
    "Emu3ForCausalLM",
    "Emu3TextModel",
    "Emu3PreTrainedModel",
    "Emu3VQVAE",
    "Emu3Model",
]