a
    hB                     @   s  d dl Z d dlZd dlZd dlZd dlmZmZ d dlmZ d dl	m
Z
 ddlmZmZ e rd dlZd dlmZ ej rd dlZdZnd	Zdd
lmZ eeZdd ZedZeedddZdd Zd,e
e e e
e dddZ!d-e
e e e
e dddZ"ej#dddZ$dd Z%edZ&dd Z'd d! Z(d"d# Z)d.e
e e
e  e d%d&d'Z*ed(d)ed/e
e e
e  e
e  d%d*d+Z+dS )0    N)contextmanagerredirect_stdout)StringIO)Optional   )is_torch_availablerequires)	save_fileTF)loggingc                   C   s    t rtj sdS tj dkS )z7Return True if rank=0 or we aren't running distributed.Tr   )_torch_distributed_availabletorchdistributedZis_initializedZget_rank r   r   ^/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/model_debugging_utils.py_is_rank_zero,   s    r   zobject at 0x[0-9A-Fa-f]+)x_strreturnc                 C   s   t d| S )z
    Replace memory addresses in an object's repr with a stable placeholder
    so that beautiful JSON diffs won't be ruined by ephemeral addresses.
    zobject at 0xXXXXXXXX)MEMORY_ADDRESS_REGEXsub)r   r   r   r   _sanitize_repr_for_diff6   s    r   c                 C   s   t  rdt| j S dS )z@Return a stable string representation for a DTensor-like object.zDTensor (rank0) -> zDTensor(non-rank0))r   repr_local_tensor)xr   r   r   _dtensor_repr>   s    r   
debug_pathuse_reprpath_to_valuec              	   C   s   t jdd |rt| }nh|rl|ds0|d7 }|rBtj||n|}td|  	 
 i| d| }ntd|d|dt| jt| j|d	}| jt jt jt jhv r|tt|  tt|  tt|  tt|  d
 |S )a  
    Converts Tensors and DTensors to a JSON-serializable dictionary representation.

    Args:
        value: Any Python object, often including torch Tensors, lists, dicts, etc.
        debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
        use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as the
            `value` property in the asscoiated FULL_TENSORS.json file, or to store the full tensors in separate
            SafeTensors file and store the relative path to that file in the `value` property in the dictionary.
        path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
            tensor value if `use_repr=False`.

    Returns:
        A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
    T)sci_modez.safetensorsdataz./z	use_repr=z and path_to_value=z cannot both be falsy.)shapedtypevalue)meanstdminmax)r   set_printoptions_repr_to_listendswithospathjoinr	   
contiguousdetachcpu
ValueErrorr   r    r!   Zfloat16Zfloat32Zbfloat16updater   r#   r$   r%   r&   )r"   r   r   r   Z	value_outfilepathoutr   r   r   _serialize_tensor_like_ioE   s.    

r4   c                    s   t | ttfr( fddt| D S t | trL fdd|  D S t| drht| j dS t | t	j
rt|  dS tt| S )a  
    Recursively build a JSON-serializable Python structure from `value`.
    Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their
    relative paths are recorded in the returned Python structure.
    Lists/tuples/dicts are recursed into.
    All memory addresses are replaced with a stable placeholder.

    Args:
        value: Any Python object, often including torch Tensors, lists, dicts, etc.
        debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
        use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
            `value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
            files and store the relative path to that file in the `value` property.
        path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
            tensor value if `use_repr=False`.

    Returns:
        A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
    c              	      s*   g | ]"\}}t |  d | dqS _r   _serialize_io).0ivr   r   r   r   r   
<listcomp>   s   z!_serialize_io.<locals>.<listcomp>c              
      s,   i | ]$\}}|t |  d | dqS r5   r7   )r9   kr;   r<   r   r   
<dictcomp>   s   z!_serialize_io.<locals>.<dictcomp>r   r   )
isinstancelisttuple	enumeratedictitemshasattrr4   r   r   Tensorr   r   )r"   r   r   r   r   r<   r   r8   v   s    


r8   )r"   c              	   C   sx   t jddd t H}t|  t|  | }W d   n1 sD0    Y  W d   n1 sb0    Y  t| S )z
    Converts a tensor into a sanitized multi-line string representation.

    Args:
        value (`torch.Tensor`): The tensor to represent.

    Returns:
        `list[str]`: List of string lines representing the tensor.
    Tx   )r   Z	linewidthN)r   r'   r   r   printgetvaluer   
splitlines)r"   bufrawr   r   r   r(      s
    
Dr(   c                 C   s0   |  dr,| dd  | d D ]}t| qd S )Nchildrenoutputs)getpopprune_outputs_if_childrennodechildr   r   r   rR      s    
rR   z(.*)\.(\d+)$c                    sH   t | dd}|r | ds$dS |d t fdd| d D S )z
    Checks whether a node represents a layer block with submodules.

    Args:
        node (`dict`): A node from the call tree.

    Returns:
        `bool`: Whether the node is a layer block.
    module_path rN   F   c                 3   s&   | ]}d   d | ddv V  qdS ).rV   rW   NrP   )r9   rU   numberr   r   	<genexpr>       z!is_layer_block.<locals>.<genexpr>)LAYER_SUFFIX_REmatchrP   groupany)rT   r`   r   r[   r   is_layer_block   s
    

rc   c                    s~   |  dsdS dd t| d D }t|dkrddd |dd D   fd	dt| d D | d< | d D ]}t| qldS )
z
    Recursively removes intermediate layers from the tree to improve readability.
    Keeps at least the first and last layers if many consecutive layers are present.

    Args:
        node (`dict`): The root or subnode to prune recursively.
    rN   Nc                 S   s    g | ]\}}t |r||fqS r   )rc   r9   r:   rU   r   r   r   r=      r^   z-prune_intermediate_layers.<locals>.<listcomp>rX   c                 S   s   g | ]\}}|qS r   r   )r9   r:   r6   r   r   r   r=      r^   r   c                    s   g | ]\}}| vr|qS r   r   rd   Z	to_remover   r   r=      r^   )rP   rC   lenprune_intermediate_layers)rT   Zlayer_blocksrU   r   rf   r   rh      s    
rh   c              
      s<  | rdz&t j| dd t j| |jd }W qn ty` } ztd|  d|W Y d }~qnd }~0 0 n
|jd }td| d |d }|d	 }t	|j
 t|d
"}tj|j
|dd W d    n1 s0    Y   fdd tt|j
} | t|d
 }tj||dd W d    n1 s.0    Y  d S )NTexist_okZ_debug_tree"Unexpected or existing debug_path=rY   zWriting model trace at z.jsonz_FULL_TENSORS.jsonz_SUMMARY.jsonwrX   )indentc                    sJ    fdd  |  di   |  di  |  dg D ]}| q8d S )Nc                    sN   t | tr.| dd  |  D ]} | qnt | trJ| D ]} | q<d S )Nr"   )r@   rD   rQ   valuesrA   )valr;   itemcleanr   r   rr      s    

z:log_model_debug_trace.<locals>.strip_values.<locals>.cleaninputsrO   rN   rZ   rS   strip_valuesrq   r   ru      s
    	z+log_model_debug_trace.<locals>.strip_values)r*   makedirsr+   r,   _debugger_module_dump_name	Exceptionr0   loggerinforR   
_call_treeopenjsondumploadsdumps)r   modelbasee	full_pathZsummary_pathfZ	tree_copyr   rt   r   log_model_debug_trace   s$    *

0r   rY   )r   do_prune_layersr   c           	   
      s   j j  ddg d_g _ _rrztjdd W n6 typ } ztd d|W Y d}~n
d}~0 0 fdd}	 D ]&\}}|d	krq||  d|  qj
t fd
d}|_
dS )a  
    Attaches a debugging wrapper to every module in the model.

    This records structured inputs and outputs during the forward pass into a call tree.

    Args:
        model (`PreTrainedModel`, `nn.Module`): Model to wrap.
        debug_path (`str`): Optional directory to dump debug JSON files.
        do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers.
        use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
            `value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
            files and store the relative path to that file in the `value` property.
    NrV   rs   rO   rN   Tri   rk   rY   c                    s0   j t fdd}|_ d S )Nc                     s  t  rN| |d  fdd D  t  ddd g d}j| t  | i |}W d    n1 sz0    Y  t  rtdd  D d	krd |d
< nt| dd|d
< j }|d s|d jrjd d | |S )Nargskwargsc                    s&   i | ]}t  | d kr| | qS )r   )rg   r9   r>   Zdict_inputsr   r   r?   5  r^   zY_attach_debugger_logic.<locals>.wrap_forward.<locals>.wrapped_forward.<locals>.<dictcomp>_inputsr   r   c                 s   s   | ]
}d V  qdS )r   Nr   )r9   r6   r   r   r   r]   F  r^   zX_attach_debugger_logic.<locals>.wrap_forward.<locals>.wrapped_forward.<locals>.<genexpr>r   rO   _outputsrN   re   )	r   r8   _debugger_model_call_stackappendr   Zno_gradsumZnamed_childrenrQ   )inpskwsrT   r3   finished)r   r   r   moduleorig_forwardr   r   r   wrapped_forward1  s>    

,



zE_attach_debugger_logic.<locals>.wrap_forward.<locals>.wrapped_forward)forward	functoolswraps)r   r   r   )r   r   r   )r   r   r   r   wrap_forward.  s    'z,_attach_debugger_logic.<locals>.wrap_forwardrW   c                     s   t  r>  dt| |d  ddd g d}j| | i |}t  r܈jrt|  dd|d< j }|d jd< |d jd< |d	 jd	< fd
dtj D  rtj t	d |S )Nz (top-level)r   r   r   r   r   rO   rs   rN   c                    s$   g | ]} j | s j |d qS )N)r{   rQ   r   )r   r   r   r=     r^   zG_attach_debugger_logic.<locals>.top_wrapped_forward.<locals>.<listcomp>)r   r   )
r   r8   r   r   rQ   r{   rA   keysrh   r   )r   r   Ztop_noder3   r   
class_namer   r   r   Zreal_top_forwardr   r   r   top_wrapped_forwardd  s:    


z3_attach_debugger_logic.<locals>.top_wrapped_forward)	__class____name__r{   r   rw   r*   rv   rx   r0   named_modulesr   r   r   )	r   r   r   r   r   r   name	submoduler   r   r   r   _attach_debugger_logic  s$    (.%r   )r   )backendsc              	   c   sl   dd |   D }| j|| < t| ||| z"| V  W | D ]\}}||_q<n| D ]\}}||_qV0 dS )a  
    # Model addition debugger - context manager for model adders
    This context manager is a power user tool intended for model adders.

    It tracks all forward calls within a model forward and logs a slice of each input and output on a nested JSON file.
    If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of
    strings. If `use_repr=False`, the full tensors will be stored in separate SafeTensors files and the JSON file will
    provide a relative path to that file.

    To note, this context manager enforces `torch.no_grad()`.

    ## Usage

    add the context manager to a model to debug

    ```python
    import torch

    from PIL import Image
    from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context

    torch.random.manual_seed(673)

    # load pretrained model and processor
    model_id = "llava-hf/llava-1.5-7b-hf"
    processor = LlavaProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(model_id)

    # create random image input
    random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())

    # prompt
    prompt = "<image>Describe this image."

    # process inputs
    inputs = processor(text=prompt, images=random_image, return_tensors="pt")

    # call forward method (not .generate!)
    with model_addition_debugger_context(model, debug_path="Your_debug_path", do_prune_layers=False):
        output = model.forward(**inputs)
    ```

    c                 S   s   i | ]\}}||j qS r   )r   )r9   r6   mr   r   r   r?     r^   z3model_addition_debugger_context.<locals>.<dictcomp>N)r   r   r   rE   )r   r   r   r   Zorig_forwardsZmodule_instanceZforward_methodr   r   r   model_addition_debugger_context  s    3

r   )NTN)NTN)rY   TT)NTT),r   r}   r*   re
contextlibr   r   ior   typingr   Zutils.import_utilsr   r   r   Zsafetensors.torchr	   r   Zis_availableZtorch.distributed.tensorr   utilsr
   Z
get_loggerr   ry   r   compiler   strr   r   boolr4   r8   rG   r(   rR   r_   rc   rh   r   r   r   r   r   r   r   <module>   sb   


 1+	
-      