a
    hYl                     @   s  d dl mZ d dlmZ d dlmZmZ d dlmZm	Z	m
Z
mZ d dlZddlmZmZmZ ddlmZmZmZmZmZ dd	lmZmZmZmZmZmZmZm Z m!Z!m"Z"m#Z# dd
l$m%Z% ddl&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z- ddl.m/Z/ e, rddlm0Z0 e) r
d dl1Z1e* rBddlm2Z2 e+ r4d dl3m4Z5 nd dl6m4Z5 ndZ2e-7e8Z9edddddddddddddddej:fe	e; e	e< e	e; e	ee<e=e< f  e	ee<e=e< f  e	e; e	e> e	e; e	e e	e; e	e e	d e	ee?e'f  e	e dddZ@d)de	e> ddddZAee e=e dddZBe=d eCe> dddZDeejEdf e>e=eejEdf  d d!d"ZFG d#d$ d$e
d%d&ZGe(G d'd( d(eZHdS )*    )Iterable)deepcopy)	lru_cachepartial)AnyOptional	TypedDictUnionN   )BaseImageProcessorBatchFeatureget_size_dict)convert_to_rgbget_resize_output_image_sizeget_size_with_aspect_ratiogroup_images_by_shapereorder_images)ChannelDimension
ImageInput	ImageTypeSizeDictget_image_size#get_image_size_for_max_height_widthget_image_typeinfer_channel_dimension_formatmake_flat_list_of_imagesvalidate_kwargsvalidate_preprocess_arguments)Unpack)
TensorTypeauto_docstringis_torch_availableis_torchvision_availableis_torchvision_v2_availableis_vision_availablelogging)is_rocm_platform)PILImageResampling)pil_torch_interpolation_mapping)
functional
   maxsizeF.InterpolationMode
do_rescalerescale_factordo_normalize
image_mean	image_stddo_padsize_divisibilitydo_center_crop	crop_size	do_resizesizeinterpolationreturn_tensorsdata_formatc                 C   sN   t | |||||||||	|
|d |dur8|dkr8td|tjkrJtddS )z
    Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    )r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   Nptz6Only returning PyTorch tensors is currently supported.z6Only channel first data format is currently supported.)r   
ValueErrorr   FIRSTr.    r@   d/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/image_processing_utils_fast.py"validate_fast_preprocess_argumentsK   s$    
rB   torch.Tensor)tensoraxisreturnc                 C   s:   |du r|   S z| j |dW S  ty4   |  Y S 0 dS )zF
    Squeezes a tensor, but only if the axis specified has dim 1.
    N)rE   )Zsqueezer>   )rD   rE   r@   r@   rA   safe_squeezev   s    rG   )valuesrF   c                 C   s   dd t |  D S )zO
    Return the maximum value across all indices of an iterable of values.
    c                 S   s   g | ]}t |qS r@   )max).0Zvalues_ir@   r@   rA   
<listcomp>       z&max_across_indices.<locals>.<listcomp>)zip)rH   r@   r@   rA   max_across_indices   s    rN   )imagesrF   c                 C   s    t dd | D \}}}||fS )zH
    Get the maximum height and width across all images in a batch.
    c                 S   s   g | ]
}|j qS r@   )shaperJ   Zimgr@   r@   rA   rK      rL   z(get_max_height_width.<locals>.<listcomp>)rN   )rO   _
max_height	max_widthr@   r@   rA   get_max_height_width   s    rU   )image
patch_sizerF   c                 C   sj   g }t | tjd\}}td||D ]B}td||D ]0}| dd||| ||| f }|| q2q"|S )a6  
    Divides an image into patches of a specified size.

    Args:
        image (`Union[np.array, "torch.Tensor"]`):
            The input image.
        patch_size (`int`):
            The size of each patch.
    Returns:
        list: A list of Union[np.array, "torch.Tensor"] representing the patches.
    )Zchannel_dimr   N)r   r   r?   rangeappend)rV   rW   Zpatchesheightwidthijpatchr@   r@   rA   divide_to_patches   s    "r_   c                   @   s  e Zd ZU ee ed< eeeef  ed< ee ed< ee	d  ed< ee ed< eeeef  ed< ee ed< ee	ee
f  ed	< ee ed
< ee	e
ee
 f  ed< ee	e
ee
 f  ed< ee ed< ee	eef  ed< ee ed< ee	eef  ed< ed ed< ee ed< dS )DefaultFastImageProcessorKwargsr8   r9   default_to_square)r'   r-   resampler6   r7   r/   r0   r1   r2   r3   do_convert_rgbr;   r<   input_data_formattorch.devicedevicedisable_groupingN)__name__
__module____qualname__r   bool__annotations__dictstrintr	   floatlistr   r   r@   r@   r@   rA   r`      s"   
r`   F)totalc                       s  e Zd ZdZdZdZdZdZdZdZ	dZ
dZdZdZdZdZejZdZdZdgZeZdZee d fddZeedd	d
ZdHdededdddZe dIde!e"e"f e#d eddddZ$de%ddddZ&de'e%e(e% f e'e%e(e% f ddddZ)e*dddJe#e e#e'e%e+e% f  e#e'e%e+e% f  e#e e#e% e#d e!dddZ,dee%ee'e%e+e% f e'e%e+e% f ddd d!Z-de.e/e"f dd"d#d$Z0e1e1d%d&d'Z2e.dd(d)Z3dKe1e"e1d+d,d-Z4dLe1e#e e#e'e/ef  e#d dd.d/d0Z5dMe1e#e e#e'e/ef  e#d e"e+d d1d2d3Z6dNe#e e#e e#e e#e'e%e+e% f  e#e'e%e+e% f  e#e e.d4d5d6Z7dOe#e e#e% e#e e#e'e%e!e% f  e#e'e%e!e% f  e#e e#e e#e e#e e#d e#e'e/e8f  e#e d7d8d9Z9e1ee e:d:d;d<Z;e<e1ee e:d:d=d>Z=dd?e1eee#e'e/df  ee e:d@dAdBZ>e+d eee#d eeee%ee#e'e%e+e% f  e#e'e%e+e% f  e#e e#e'e/e8f  e:dCdDdEZ? fdFdGZ@  ZAS )PBaseImageProcessorFastNTgp?pixel_values)kwargsc              	      s   t  jf i | | |}|d| j}|d urHt||d| jdnd | _|d| j}|d urpt|ddnd | _| jj	D ]>}||d }|d urt
| || q~t
| |tt| |d  q~t| jj	 | _d S )Nr9   ra   r9   ra   r7   
param_name)super__init__filter_out_unused_kwargspopr9   r   ra   r7   valid_kwargsrl   setattrr   getattrrq   keys_valid_kwargs_names)selfru   r9   r7   keykwarg	__class__r@   rA   rz      s    
zBaseImageProcessorFast.__init__)rF   c                 C   s   dS )zv
        `bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision).
        Tr@   )r   r@   r@   rA   is_fast   s    zBaseImageProcessorFast.is_fastrC   r-   )rV   r9   r:   	antialiasrF   c                 K   s   |dur|nt jj}|jr>|jr>t| dd |j|j}np|jrZt||jdtj	d}nT|j
r|jrt| dd |j
|j}n*|jr|jr|j|jf}ntd| dtj rt r| ||||S t j||||dS )a@  
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.

        Returns:
            `torch.Tensor`: The resized image.
        NF)r9   ra   rd   zjSize must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got .r:   r   )FZInterpolationModeZBILINEARZshortest_edgeZlongest_edger   r9   r   r   r?   rS   rT   r   rZ   r[   r>   torchcompilerZis_compilingr&   compile_friendly_resizeresize)r   rV   r9   r:   r   ru   new_sizer@   r@   rA   r      s4    zBaseImageProcessorFast.resize)rV   r   r:   r   rF   c                 C   s~   | j tjkrh|  d } tj| |||d} | d } t| dkd| } t| dk d| } |  tj} ntj| |||d} | S )z{
        A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
           r   r   )	dtyper   Zuint8rp   r   r   whereroundto)rV   r   r:   r   r@   r@   rA   r   (  s    
z.BaseImageProcessorFast.compile_friendly_resize)rV   scalerF   c                 K   s   || S )a?  
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`torch.Tensor`):
                Image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.

        Returns:
            `torch.Tensor`: The rescaled image.
        r@   )r   rV   r   ru   r@   r@   rA   rescale=  s    zBaseImageProcessorFast.rescale)rV   meanstdrF   c                 K   s   t |||S )a  
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`torch.Tensor`):
                Image to normalize.
            mean (`torch.Tensor`, `float` or `Iterable[float]`):
                Image mean to use for normalization.
            std (`torch.Tensor`, `float` or `Iterable[float]`):
                Image standard deviation to use for normalization.

        Returns:
            `torch.Tensor`: The normalized image.
        )r   	normalize)r   rV   r   r   ru   r@   r@   rA   r   Q  s    z BaseImageProcessorFast.normalizer*   r+   re   )r1   r2   r3   r/   r0   rf   rF   c                 C   sB   |r8|r8t j||dd|  }t j||dd|  }d}|||fS )Nrf   g      ?F)r   rD   )r   r1   r2   r3   r/   r0   rf   r@   r@   rA   !_fuse_mean_std_and_rescale_factorh  s
    
z8BaseImageProcessorFast._fuse_mean_std_and_rescale_factor)rO   r/   r0   r1   r2   r3   rF   c                 C   sP   | j ||||||jd\}}}|r<| |jtjd||}n|rL| ||}|S )z/
        Rescale and normalize images.
        )r1   r2   r3   r/   r0   rf   )r   )r   rf   r   r   r   Zfloat32r   )r   rO   r/   r0   r1   r2   r3   r@   r@   rA   rescale_and_normalizey  s    	z,BaseImageProcessorFast.rescale_and_normalize)rV   r9   rF   c                 K   s>   |j du s|jdu r&td|  t||d |d fS )a  
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`"torch.Tensor"`):
                Image to center crop.
            size (`dict[str, int]`):
                Size of the output image.

        Returns:
            `torch.Tensor`: The center cropped image.
        Nz=The size dictionary must have keys 'height' and 'width'. Got rZ   r[   )rZ   r[   r>   r   r   center_crop)r   rV   r9   ru   r@   r@   rA   r     s    z"BaseImageProcessorFast.center_crop)rV   rF   c                 C   s   t |S )a'  
        Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
        as is.
        Args:
            image (ImageInput):
                The image to convert.

        Returns:
            ImageInput: The converted image.
        )r   )r   rV   r@   r@   rA   r     s    z%BaseImageProcessorFast.convert_to_rgbc                 C   sB   | j du r|S | j D ](}||v rtd| d || q|S )zJ
        Filter out the unused kwargs from the kwargs dictionary.
        Nz!This processor does not use the `z ` parameter. It will be ignored.)unused_kwargsloggerZwarning_oncer|   )r   ru   
kwarg_namer@   r@   rA   r{     s    

z/BaseImageProcessorFast.filter_out_unused_kwargs   )rO   expected_ndimsrF   c                 C   s   |  |}t||dS )z
        Prepare the images structure for processing.

        Args:
            images (`ImageInput`):
                The input images to process.

        Returns:
            `ImageInput`: The images with a valid nesting.
        r   )Zfetch_imagesr   )r   rO   r   r@   r@   rA   _prepare_images_structure  s    
z0BaseImageProcessorFast._prepare_images_structure)rV   rc   rd   rf   rF   c                 C   s   t |}|tjtjtjfvr*td| |r8| |}|tjkrNt|}n|tjkrft	
| }|jdkrz|d}|d u rt|}|tjkr|ddd }|d ur||}|S )NzUnsupported input image type    r   r
   )r   r   ZPILZTORCHZNUMPYr>   r   r   Zpil_to_tensorr   Z
from_numpy
contiguousndimZ	unsqueezer   r   ZLASTZpermuter   )r   rV   rc   rd   rf   Z
image_typer@   r@   rA   _process_image  s$    






z%BaseImageProcessorFast._process_image)rO   rc   rd   rf   r   rF   c                    sl   | j ||d}t| j|||d t|dko<t|d ttf}|rV fdd|D }n fdd|D }|S )a  
        Prepare image-like inputs for processing.

        Args:
            images (`ImageInput`):
                The image-like inputs to process.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the images to RGB.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The input data format of the images.
            device (`torch.device`, *optional*):
                The device to put the processed images on.
            expected_ndims (`int`, *optional*):
                The expected number of dimensions for the images. (can be 2 for segmentation maps etc.)

        Returns:
            List[`torch.Tensor`]: The processed images.
        r   rc   rd   rf   r   c                    s   g | ]} fd d|D qS )c                    s   g | ]} |qS r@   r@   rQ   Zprocess_image_partialr@   rA   rK   (  rL   zPBaseImageProcessorFast._prepare_image_like_inputs.<locals>.<listcomp>.<listcomp>r@   )rJ   Znested_listr   r@   rA   rK   (  rL   zEBaseImageProcessorFast._prepare_image_like_inputs.<locals>.<listcomp>c                    s   g | ]} |qS r@   r@   rQ   r   r@   rA   rK   *  rL   )r   r   r   len
isinstancerq   tuple)r   rO   rc   rd   rf   r   Zhas_nested_structureprocessed_imagesr@   r   rA   _prepare_image_like_inputs  s    
z1BaseImageProcessorFast._prepare_image_like_inputs)r9   r7   ra   r2   r3   r<   rF   c           	      K   s   |du ri }|dur*t f i t||d}|durHt f i t|dd}t|trZt|}t|trlt|}|du rztj}||d< ||d< ||d< ||d< ||d< |d	}t|tt	frt
| n||d
< |S )z
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.
        Nrv   r7   rw   r9   r2   r3   r<   rb   r:   )r   r   r   rq   r   r   r?   r|   r'   ro   r(   )	r   r9   r7   ra   r2   r3   r<   ru   rb   r@   r@   rA   _further_process_kwargs.  s*    


z.BaseImageProcessorFast._further_process_kwargsr/   r0   r1   r2   r3   r8   r9   r6   r7   r:   r;   r<   c                 K   s$   t |||||||||	|
||d dS )z@
        validate the kwargs for the preprocess method.
        r   N)rB   )r   r/   r0   r1   r2   r3   r8   r9   r6   r7   r:   r;   r<   ru   r@   r@   rA   _validate_preprocess_kwargsZ  s    z2BaseImageProcessorFast._validate_preprocess_kwargs)rO   ru   rF   c                 O   s   | j |g|R i |S )N)
preprocess)r   rO   argsru   r@   r@   rA   __call__|  s    zBaseImageProcessorFast.__call__c                 O   s   t | | jd | jD ]}||t| |d  q|d}|d}|d}| jf i |}| jf i | |d | j|g|R |||d|S )N)Zcaptured_kwargsZvalid_processor_keysrc   rd   rf   r<   r   )	r   r   r   
setdefaultr   r|   r   r   _preprocess_image_like_inputs)r   rO   r   ru   r   rc   rd   rf   r@   r@   rA   r     s$    




z!BaseImageProcessorFast.preprocessr   )rO   rc   rd   rf   ru   rF   c                O   s*   | j ||||d}| j|g|R i |S )z
        Preprocess image-like inputs.
        To be overridden by subclasses when image-like inputs other than images should be processed.
        It can be used for segmentation maps, depth maps, etc.
        )rO   rc   rd   rf   )r   _preprocess)r   rO   rc   rd   rf   r   ru   r@   r@   rA   r     s    z4BaseImageProcessorFast._preprocess_image_like_inputs)rO   r8   r9   r:   r6   r7   r/   r0   r1   r2   r3   rg   r;   rF   c              	   K   s   t ||d\}}i }| D ]$\}}|r8| j|||d}|||< qt||}t ||d\}}i }| D ]4\}}|r| ||}| ||||	|
|}|||< qht||}|rtj|ddn|}td|i|dS )N)rg   )rV   r9   r:   r   )dimrt   )dataZtensor_type)	r   itemsr   r   r   r   r   stackr   )r   rO   r8   r9   r:   r6   r7   r/   r0   r1   r2   r3   rg   r;   ru   Zgrouped_imagesZgrouped_images_indexZresized_images_groupedrP   Zstacked_imagesZresized_imagesZprocessed_images_groupedr   r@   r@   rA   r     s&    



z"BaseImageProcessorFast._preprocessc                    s&   t   }|dd  |dd  |S )NZ_valid_processor_keysr   )ry   to_dictr|   )r   Zencoder_dictr   r@   rA   r     s    
zBaseImageProcessorFast.to_dict)NT)NT)NNNNNN)r   )NNN)NNNr   )NNNNNN)NNNNNNNNNNNN)Brh   ri   rj   rb   r2   r3   r9   ra   r7   r8   r6   r/   r0   r1   rc   r;   r   r?   r<   rd   rf   Zmodel_input_namesr`   r}   r   r   rz   propertyrk   r   r   r   staticmethodr   ro   r   r   rp   r   r	   r   r   r   rq   r   r   rm   rn   r   r   r   r{   r   r   r   r   r   r   r   r   r    r   r   r   r   __classcell__r@   r@   r   rA   rs      sf  
  6  
      
    )    .      .            " ,rs   )N)Icollections.abcr   copyr   	functoolsr   r   typingr   r   r   r	   numpynpZimage_processing_utilsr   r   r   Zimage_transformsr   r   r   r   r   Zimage_utilsr   r   r   r   r   r   r   r   r   r   r   Zprocessing_utilsr   utilsr   r    r!   r"   r#   r$   r%   Zutils.import_utilsr&   r'   r   r(   Ztorchvision.transforms.v2r)   r   Ztorchvision.transformsZ
get_loggerrh   r   r?   rk   rp   rq   ro   rn   rB   rG   rN   r   rU   arrayr_   r`   rs   r@   r@   r@   rA   <module>   sz   4$	
*