"""PyTorch OpenAI ImageGPT model."""

# NOTE: this module was recovered from a compiled (`.pyc`, CPython 3.9) copy of
# `transformers/models/imagegpt/modeling_imagegpt.py`. The structure, class and
# function names, docstrings, and string constants come directly from the
# bytecode; statement-level details are a best-effort reconstruction and may
# differ in minor ways from the original source.

import math
import os
from typing import Any, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import auto_docstring, logging, torch_float
from .configuration_imagegpt import ImageGPTConfig


logger = logging.get_logger(__name__)


def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path):
    """
    Load TF checkpoints into a PyTorch model.
    """
    try:
        import re

        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see"
            " https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(imagegpt_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []

    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")

        # Skip optimizer state and step counters
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ) or name[-1] in ["_step"]:
            logger.info("Skipping {}".format("/".join(name)))
            continue

        pointer = model
        if name[-1] not in ["wtet"]:
            pointer = getattr(pointer, "transformer")

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            elif scope_names[0] in ["q_proj", "k_proj", "v_proj"]:
                pointer = getattr(pointer, "c_attn")
                pointer = getattr(pointer, "weight")
            elif len(name) == 3 and name[1] == "attn" and scope_names[0] == "c_proj":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "wtet":
                pointer = getattr(pointer, "lm_head")
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "sos":
                pointer = getattr(pointer, "wte")
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if (len(name) > 1 and name[1] == "attn") or name[-1] == "wtet" or name[-1] == "sos" or name[-1] == "wte":
            # These arrays initialize only part of the pointer, so the shapes won't match
            pass
        else:
            try:
                assert pointer.shape == array.shape
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise

        logger.info(f"Initialize PyTorch weight {name}")

        if name[-1] == "q_proj":
            pointer.data[:, : config.n_embd] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T
        elif name[-1] == "k_proj":
            pointer.data[:, config.n_embd : 2 * config.n_embd] = torch.from_numpy(
                array.reshape(config.n_embd, config.n_embd)
            ).T
        elif name[-1] == "v_proj":
            pointer.data[:, 2 * config.n_embd :] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T
        elif len(name) == 3 and name[1] == "attn" and name[2] == "c_proj":
            pointer.data = torch.from_numpy(array.reshape(config.n_embd, config.n_embd))
        elif name[-1] == "wtet":
            pointer.data = torch.from_numpy(array)
        elif name[-1] == "wte":
            pointer.data[: config.vocab_size - 1, :] = torch.from_numpy(array)
        elif name[-1] == "sos":
            pointer.data[-1] = torch.from_numpy(array)
        else:
            pointer.data = torch.from_numpy(array)

    return model


class ImageGPTLayerNorm(nn.Module):
    def __init__(self, hidden_size: tuple[int], eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.Tensor(hidden_size))

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Unlike nn.LayerNorm, the input is not mean-centered: only RMS normalization is applied
        tensor = tensor / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
        tensor = tensor * self.weight
        return tensor


class ImageGPTAttention(nn.Module):
    def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx: Optional[int] = None):
        super().__init__()

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # Only the "normal" attention layer implements the causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Must be a tensor on the same device/dtype as attn_weights for torch.where
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- no-op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Use `torch.baddbmm` (more efficient with an alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute scale factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5
        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (scale K by 1 / sqrt(dk))
        with torch.autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # Only the "normal" attention layer implements the causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision)
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple:
        is_cross_attention = encoder_hidden_states is not None
        batch_size, seq_len, _ = hidden_states.shape

        is_updated = False
        curr_past_key_value = None
        if layer_past is not None:
            if isinstance(layer_past, EncoderDecoderCache):
                is_updated = layer_past.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # After the first generated id, we can re-use all key/value states from the cache
                    curr_past_key_value = layer_past.cross_attention_cache
                else:
                    curr_past_key_value = layer_past.self_attention_cache
            else:
                curr_past_key_value = layer_past

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure"
                    " to instantiate class with `ImageGPTAttention(..., is_cross_attention=True)`."
                )
            query_states = self.q_attn(hidden_states)
            attention_mask = encoder_attention_mask

            # Try to get key/value states from the cache if possible
            if layer_past is not None and is_updated:
                key_states = curr_past_key_value.layers[self.layer_idx].keys
                value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                key_states, value_states = self.c_attn(current_states).split(self.split_size, dim=2)
                key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
                value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
            key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        query_states = query_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if layer_past is not None and (not is_cross_attention or not is_updated):
            # Save all key/value states to the cache for fast auto-regressive generation
            cache_position = cache_position if not is_cross_attention else None
            key_states, value_states = curr_past_key_value.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )
            # Mark the cross-attention states as written so subsequent calls re-use them
            if is_cross_attention:
                layer_past.is_updated[self.layer_idx] = True

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query_states, key_states, value_states, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = self._attn(query_states, key_states, value_states, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, attn_weights


class ImageGPTMLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class ImageGPTBlock(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = ImageGPTAttention(config, layer_idx=layer_idx)
        self.ln_2 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = ImageGPTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = ImageGPTMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        attn_output = attn_outputs[0]
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention"
                    " layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            cross_attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + cross_attn_output
            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        return (hidden_states,) + outputs


@auto_docstring
class ImageGPTPreTrainedModel(PreTrainedModel):
    config: ImageGPTConfig
    load_tf_weights = load_tf_weights_in_imagegpt
    base_model_prefix = "transformer"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ImageGPTBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, ImageGPTLayerNorm):
            module.weight.data.fill_(1.0)

        # Reinitialize selected weights subject to the OpenAI GPT-2 paper scheme: scale the weights of
        # residual layers at initialization by 1/sqrt(N) where N is the number of residual layers.
        for name, p in module.named_parameters():
            if "c_proj" in name and "weight" in name:
                # Special scaled initialization --> there are 2 layer norms per transformer block
                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))


@auto_docstring
class ImageGPTModel(ImageGPTPreTrainedModel):
    def __init__(self, config: ImageGPTConfig):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([ImageGPTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = ImageGPTLayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ImageGPTModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
        >>> model = ImageGPTModel.from_pretrained("openai/imagegpt-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
        if use_cache and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        # ImageGPTAttention mask: build a broadcastable additive mask from the 2D padding mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            # Sizes are [batch_size, 1, 1, to_seq_length] so we can broadcast to
            # [batch_size, num_heads, from_seq_length, to_seq_length]
            attention_mask = attention_mask[:, None, None, :]
            # 1.0 for positions to attend, dtype-min for masked positions; added to the raw
            # scores before the softmax, this effectively removes masked positions.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # If a 2D or 3D attention mask is provided for the cross-attention,
        # make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed: 1.0 in head_mask means keep the head;
        # final shape is n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.h):
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask and head_mask are on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                layer_past=past_key_values,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[2],)

            # Model parallel: if it's the last layer for that device, move to the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    The ImageGPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ImageGPTConfig):
        super().__init__(config)
        self.transformer = ImageGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ImageGPTForCausalImageModeling
        >>> import torch
        >>> import matplotlib.pyplot as plt
        >>> import numpy as np

        >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
        >>> model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small")
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> # unconditional generation of 4 images
        >>> batch_size = 4
        >>> context = torch.full((batch_size, 1), model.config.vocab_size - 1)  # initialize with SOS token
        >>> context = context.to(device)
        >>> output = model.generate(
        ...     input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40
        ... )

        >>> clusters = image_processor.clusters
        >>> height = image_processor.size["height"]
        >>> width = image_processor.size["width"]

        >>> samples = output[:, 1:].detach().cpu().numpy()
        >>> samples_img = [
        ...     np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
        ... ]  # convert color cluster tokens back to pixels
        >>> f, axes = plt.subplots(1, batch_size, dpi=300)

        >>> for img, ax in zip(samples_img, axes):  # doctest: +IGNORE_RESULT
        ...     ax.axis("off")
        ...     ax.imshow(img)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    The ImageGPT Model transformer with an image classification head on top (linear layer).
    [`ImageGPTForImageClassification`] average-pools the hidden states in order to do the classification.
    """
)
class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
    def __init__(self, config: ImageGPTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = ImageGPTModel(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Union[tuple, SequenceClassifierOutputWithPast]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ImageGPTForImageClassification
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
        >>> model = ImageGPTForImageClassification.from_pretrained("openai/imagegpt-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # average-pool the hidden states along the sequence dimension
        pooled_hidden_states = hidden_states.mean(dim=1)
        # project from (batch_size, hidden_size) to (batch_size, num_labels)
        logits = self.score(pooled_hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


__all__ = [
    "ImageGPTForCausalImageModeling",
    "ImageGPTForImageClassification",
    "ImageGPTModel",
    "ImageGPTPreTrainedModel",
    "load_tf_weights_in_imagegpt",
]