"""PyTorch MarkupLM model."""

import os
from typing import Callable, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_markuplm import MarkupLMConfig


logger = logging.get_logger(__name__)


class XPathEmbeddings(nn.Module):
    """Construct the embeddings from xpath tags and subscripts.

    We drop tree-id in this version, as its info can be covered by xpath.
    """

    def __init__(self, config):
        super().__init__()
        self.max_depth = config.max_depth

        self.xpath_unitseq2_embeddings = nn.Linear(config.xpath_unit_hidden_size * self.max_depth, config.hidden_size)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation = nn.ReLU()
        self.xpath_unitseq2_inner = nn.Linear(config.xpath_unit_hidden_size * self.max_depth, 4 * config.hidden_size)
        self.inner2emb = nn.Linear(4 * config.hidden_size, config.hidden_size)

        self.xpath_tag_sub_embeddings = nn.ModuleList(
            [
                nn.Embedding(config.max_xpath_tag_unit_embeddings, config.xpath_unit_hidden_size)
                for _ in range(self.max_depth)
            ]
        )

        self.xpath_subs_sub_embeddings = nn.ModuleList(
            [
                nn.Embedding(config.max_xpath_subs_unit_embeddings, config.xpath_unit_hidden_size)
                for _ in range(self.max_depth)
            ]
        )

    def forward(self, xpath_tags_seq=None, xpath_subs_seq=None):
        xpath_tags_embeddings = []
        xpath_subs_embeddings = []

        for i in range(self.max_depth):
            xpath_tags_embeddings.append(self.xpath_tag_sub_embeddings[i](xpath_tags_seq[:, :, i]))
            xpath_subs_embeddings.append(self.xpath_subs_sub_embeddings[i](xpath_subs_seq[:, :, i]))

        xpath_tags_embeddings = torch.cat(xpath_tags_embeddings, dim=-1)
        xpath_subs_embeddings = torch.cat(xpath_subs_embeddings, dim=-1)

        xpath_embeddings = xpath_tags_embeddings + xpath_subs_embeddings

        xpath_embeddings = self.inner2emb(self.dropout(self.activation(self.xpath_unitseq2_inner(xpath_embeddings))))

        return xpath_embeddings
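

# Added shape sketch (not part of the upstream file): with the released base configuration
# (max_depth=50, xpath_unit_hidden_size=32, hidden_size=768 -- assumed defaults), an xpath such as
# "/html/body/div/li[1]" is encoded as one tag id and one subscript id per depth level, padded to 50.
# Each level is embedded to 32 dims, the levels are concatenated to (batch, seq_len, 1600), and the
# ReLU bottleneck above projects that through 4 * hidden_size back down to (batch, seq_len, 768).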
__module____qualname____doc__r)   rE   __classcell__r   r   r9   r&   r   ,   s   r   c                 C   s6   |  | }tj|dd|| | }| | S )a  
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: torch.Tensor

    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully balanced to work with both ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx
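

# Worked example (added for clarity): with padding_idx = 1 and input_ids = [[0, 31414, 232, 2, 1, 1]],
# mask = [1, 1, 1, 1, 0, 0] and cumsum(mask) = [1, 2, 3, 4, 4, 4]; multiplying by the mask and adding
# padding_idx yields position ids [2, 3, 4, 5, 1, 1] -- real tokens count up from padding_idx + 1
# while padding positions keep the padding id itself.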
dd	Z  ZS )MarkupLMEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    || _tj|j|j|jd| _t|j	|j| _
|j| _t|| _t|j|j| _tj|j|jd| _t|j| _| jdt|j	ddd |j| _tj|j	|j| jd| _
d S )N)rO   epsposition_ids)r   r;   F)
persistent)r(   r)   r%   r   r   
vocab_sizer,   Zpad_token_idword_embeddingsZmax_position_embeddingsposition_embeddingsr*   r   rD   Ztype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsr-   r.   r/   Zregister_bufferr?   arangeexpandrO   r7   r9   r   r&   r)   r   s     

zMarkupLMEmbeddings.__init__c                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)

    def forward(
        self,
        input_ids=None,
        xpath_tags_seq=None,
        xpath_subs_seq=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # prepare xpath seq: fall back to all-padding xpaths when none are provided
        if xpath_tags_seq is None:
            xpath_tags_seq = self.config.tag_pad_id * torch.ones(
                tuple(list(input_shape) + [self.max_depth]), dtype=torch.long, device=device
            )
        if xpath_subs_seq is None:
            xpath_subs_seq = self.config.subs_pad_id * torch.ones(
                tuple(list(input_shape) + [self.max_depth]), dtype=torch.long, device=device
            )

        words_embeddings = inputs_embeds
        position_embeddings = self.position_embeddings(position_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        xpath_embeddings = self.xpath_embeddings(xpath_tags_seq, xpath_subs_seq)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings + xpath_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class MarkupLMSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class MarkupLMIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class MarkupLMOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class MarkupLMPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class MarkupLMPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class MarkupLMLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = MarkupLMPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class MarkupLMOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = MarkupLMLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
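

# Shape walkthrough (added commentary), assuming (batch, num_heads, seq_len, head_dim) inputs:
# query @ key^T gives (batch, num_heads, seq_len, seq_len) scores scaled by head_dim ** -0.5; the
# additive attention_mask holds 0 for visible positions and a large negative number for masked ones;
# softmax (in float32 for stability) then weights `value`, and the final transpose hands back
# (batch, seq_len, num_heads, head_dim) ready to be reshaped to (batch, seq_len, hidden_size).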
  ZS )	MarkupLMSelfAttentionc                    s   t    |j|j dkr>t|ds>td|j d|j d|| _|j| _t|j|j | _| j| j | _	t
|j| j	| _t
|j| j	| _t
|j| j	| _t
|j| _|j| _| jd | _d S )Nr   Zembedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()g      )r(   r)   r,   num_attention_headshasattr
ValueErrorr%   rL   attention_head_sizeall_head_sizer   r+   r   r   r   r-   Zattention_probs_dropout_probr/   attention_dropoutr   r7   r9   r   r&   r)   b  s"    

zMarkupLMSelfAttention.__init__NFrr   r   r   output_attentionsrt   c                 K   s   |j d d }g |d| jR }| ||dd}| ||dd}	| ||dd}
t}| jj	dkrt
| jj	 }|| ||	|
|f| jsdn| j| j|d|\}}|jg |dR   }|r||fn|f}|S )Nr;   r   r   eagerr   )r/   r   r   )r   r   r   r   r   r   r   r   r%   Z_attn_implementationr   r   r   r   Zreshaper   )r8   rr   r   r   r   r   rf   Zhidden_shapeZquery_statesZ
key_statesZvalue_statesZattention_interfacer   r   outputsr   r   r&   rE   w  s0    	
zMarkupLMSelfAttention.forward)NNF)rF   rG   rH   r)   r?   ry   r   FloatTensorboolrj   rE   rJ   r   r   r9   r&   r   a  s      r   c                       sT   e Zd Z fddZdd Zd
ejeej eej ee	 e
ej ddd	Z  ZS )MarkupLMAttentionc                    s*   t    t|| _t|| _t | _d S ru   )r(   r)   r   r8   rn   outputsetpruned_headsr7   r9   r   r&   r)     s    


zMarkupLMAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r<   )lenr   r8   r   r   r   r   r   r   r   r   rp   r   union)r8   headsindexr   r   r&   prune_heads  s    zMarkupLMAttention.prune_headsNFr   c           	      K   s@   | j |f|||d|}| |d |}|f|dd   }|S N)r   r   r   r   r   )r8   r   )	r8   rr   r   r   r   r   Zself_outputsattention_outputr   r   r   r&   rE     s    zMarkupLMAttention.forward)NNF)rF   rG   rH   r)   r   r?   ry   r   r   r   rj   rE   rJ   r   r   r9   r&   r     s      r   c                       sT   e Zd Z fddZd
ejeej eej ee e	ej dddZ
dd	 Z  ZS )MarkupLMLayerc                    s:   t    |j| _d| _t|| _t|| _t|| _	d S )Nr   )
r(   r)   chunk_size_feed_forwardseq_len_dimr   	attentionrz   intermediater   r   r7   r9   r   r&   r)     s    


zMarkupLMLayer.__init__NFr   c           
      K   sP   | j |f|||d|}|d }|dd  }t| j| j| j|}	|	f| }|S r   )r   r   feed_forward_chunkr   r   )
r8   rr   r   r   r   r   Zself_attention_outputsr   r   layer_outputr   r   r&   rE     s     
zMarkupLMLayer.forwardc                 C   s   |  |}| ||}|S ru   )r   r   )r8   r   Zintermediate_outputr   r   r   r&   r     s    
z MarkupLMLayer.feed_forward_chunk)NNF)rF   rG   rH   r)   r?   ry   r   r   r   rj   rE   r   rJ   r   r   r9   r&   r     s      r   c                       sd   e Zd Z fddZed	ejeej eej ee	 ee	 ee	 e
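

# Note on chunking (added commentary): apply_chunking_to_forward splits attention_output along the
# sequence dimension (seq_len_dim=1) into chunks of config.chunk_size_feed_forward tokens and runs
# feed_forward_chunk on each, trading speed for peak memory. A chunk size of 0 (the default) disables
# chunking; the result is identical either way because the feed-forward MLP acts position-wise.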


class MarkupLMEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([MarkupLMLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=layer_head_mask,
                output_attentions=output_attentions,
                **kwargs,
            )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@auto_docstring
class MarkupLMPreTrainedModel(PreTrainedModel):
    config: MarkupLMConfig
    base_model_prefix = "markuplm"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, MarkupLMLMPredictionHead):
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


@auto_docstring
class MarkupLMModel(MarkupLMPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        """
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = MarkupLMEmbeddings(config)
        self.encoder = MarkupLMEncoder(config)
        self.pooler = MarkupLMPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        xpath_tags_seq: Optional[torch.LongTensor] = None,
        xpath_subs_seq: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, MarkupLMModel

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")

        >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"

        >>> encoding = processor(html_string, return_tensors="pt")

        >>> outputs = model(**encoding)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 4, 768]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(
            input_ids=input_ids,
            xpath_tags_seq=xpath_tags_seq,
            xpath_subs_seq=xpath_subs_seq,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
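

# Added commentary on the additive mask built in MarkupLMModel.forward: a padding mask of [1, 1, 0]
# becomes (1.0 - [1, 1, 0]) * torch.finfo(dtype).min = [0, 0, min], broadcast to
# (batch, 1, 1, seq_len), so padded positions get a huge negative score before softmax and thus
# effectively zero attention weight.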


@auto_docstring
class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.markuplm = MarkupLMModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        xpath_tags_seq: Optional[torch.Tensor] = None,
        xpath_subs_seq: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, MarkupLMForQuestionAnswering
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
        >>> model = MarkupLMForQuestionAnswering.from_pretrained("microsoft/markuplm-base-finetuned-websrc")

        >>> html_string = "<html> <head> <title>My name is Niels</title> </head> </html>"
        >>> question = "What's his name?"

        >>> encoding = processor(html_string, questions=question, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()

        >>> predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
        >>> processor.decode(predict_answer_tokens).strip()
        'Niels'
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.markuplm(
            input_ids,
            xpath_tags_seq=xpath_tags_seq,
            xpath_subs_seq=xpath_subs_seq,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    MarkupLM Model with a `token_classification` head on top.
    """
)
class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.markuplm = MarkupLMModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        xpath_tags_seq: Optional[torch.Tensor] = None,
        xpath_subs_seq: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForTokenClassification
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> processor.parse_html = False
        >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)

        >>> nodes = ["hello", "world"]
        >>> xpaths = ["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"]
        >>> node_labels = [1, 2]
        >>> encoding = processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.markuplm(
            input_ids,
            xpath_tags_seq=xpath_tags_seq,
            xpath_subs_seq=xpath_subs_seq,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = outputs[0]
        prediction_scores = self.classifier(sequence_output)  # (batch_size, seq_length, node_type_size)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(prediction_scores.view(-1, self.config.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    MarkupLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.markuplm = MarkupLMModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        xpath_tags_seq: Optional[torch.Tensor] = None,
        xpath_subs_seq: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)

        >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"
        >>> encoding = processor(html_string, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.markuplm(
            input_ids,
            xpath_tags_seq=xpath_tags_seq,
            xpath_subs_seq=xpath_subs_seq,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "MarkupLMForQuestionAnswering",
    "MarkupLMForSequenceClassification",
    "MarkupLMForTokenClassification",
    "MarkupLMModel",
    "MarkupLMPreTrainedModel",
]