a
    hY                     @   s   d Z ddlmZ ddlmZmZ ddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZmZmZ dd
lmZ eeZe rddlZeddG dd deZdgZdS )z
Processor class for SAM2.
    deepcopy)OptionalUnionN   )
ImageInput)ProcessorMixin)BatchEncoding)
TensorTypeis_torch_availablelogging)requires)torch)backendsc                       sL  e Zd ZdZdgZdZd%ee ed fddZd&e	e	ee
eeeee    ejf  ee
eeee   ejf  ee
eeee   ejf  ee
eee  ejf  ee
eef  ed	d
dZd'edddddZd(ddZd)ddZd*ddZdd Zdd Zd+e
ejejef eeeee edddZd,dd Zd-d#d$Z  ZS ).Sam2Processora  
    Constructs a SAM2 processor which wraps a SAM2 image processor and an 2D points & Bounding boxes processor into a
    single processor.

    [`Sam2Processor`] offers all the functionalities of [`Sam2ImageProcessorFast`] and [`Sam2VideoProcessor`]. See the docstring of
    [`~Sam2ImageProcessorFast.__call__`] and [`~Sam2VideoProcessor.__call__`] for more information.

    Args:
        image_processor (`Sam2ImageProcessorFast`):
            An instance of [`Sam2ImageProcessorFast`].
        target_size (`int`, *optional*):
            The target size (target_size, target_size) to which the image will be resized.
        point_pad_value (`int`, *optional*, defaults to -10):
            The value used for padding input points.
    image_processorZSam2ImageProcessorFastN)target_sizepoint_pad_valuec                    s8   t  j|fi | || _|d ur&|n
| jjd | _d S )Nheight)super__init__r   r   sizer   )selfr   r   r   kwargs	__class__ d/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/sam2/processing_sam2.pyr   :   s    zSam2Processor.__init__)imagessegmentation_mapsinput_pointsinput_labelsinput_boxesoriginal_sizesreturn_tensorsreturnc                    sN  |dur"| j |f||d|}	n:|durTt|tjrB|  }td|i|d}	ntd|	d }|durt|dkrt|t|krtd|dus|dus|durJ| j	|dd	d
dd}
| j	|dddd}| j	|ddddd}|
dur| 
|
dd }|dur| 
|dd }|dur6| 
|dd  |
dur\|dur\||kr\td|durt|dkrt fdd|D rtd|
dur| |
|dg }tj|tjd}| j||dd |	d|i |dur| ||}tj|tjd}|	d|i |durJtj|tjd}| j||dd |	d|i |	S )a  
        This method uses [`Sam2ImageProcessorFast.__call__`] method to prepare image(s) for the model. It also prepares 2D
        points and bounding boxes for the model if they are provided.

        Args:
            images (`ImageInput`, *optional*):
                The image(s) to process.
            segmentation_maps (`ImageInput`, *optional*):
                The segmentation maps to process.
            input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
                The points to add to the frame.
            input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
                The labels for the points.
            input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
                The bounding boxes to add to the frame.
            original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*):
                The original sizes of the images.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return.
            **kwargs:
                Additional keyword arguments to pass to the image processor.

        Returns:
            A [`BatchEncoding`] with the following fields:
            - `pixel_values` (`torch.Tensor`): The processed image(s).
            - `original_sizes` (`list[list[float]]`): The original sizes of the images.
            - `reshaped_input_sizes` (`torch.Tensor`): The reshaped input sizes of the images.
            - `labels` (`torch.Tensor`): The processed segmentation maps (if provided).
            - `input_points` (`torch.Tensor`): The processed points.
            - `input_labels` (`torch.Tensor`): The processed labels.
            - `input_boxes` (`torch.Tensor`): The processed bounding boxes.
        N)r    r%   r$   )Ztensor_typez0Either images or original_sizes must be provided   z{original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size.   Zpointsz;[image level, object level, point level, point coordinates]   )expected_depth
input_nameexpected_formatexpected_coord_sizer   labelsz([image level, object level, point level])r*   r+   r,   Zboxesz)[image level, box level, box coordinates]zbInput points and labels have inconsistent dimensions. Please ensure they have the same dimensions.c                 3   s   | ]}t | d  k V  qdS r'   N)len).0Z	img_boxesZboxes_max_dimsr   r   	<genexpr>       z)Sam2Processor.__call__.<locals>.<genexpr>zInput boxes have inconsistent dimensions that would require padding, but boxes cannot be padded due to model limitations. Please ensure all images have the same number of boxes.)ZdtypeT)preserve_paddingr!   r"   is_bounding_boxr#   )r   
isinstancer   Tensorcputolistr	   
ValueErrorr0   _validate_single_input_get_nested_dimensionsany_pad_nested_listtensorZfloat32_normalize_tensor_coordinatesupdateZint64)r   r   r    r!   r"   r#   r$   r%   r   Zencoding_image_processorZprocessed_pointsZprocessed_labelsZprocessed_boxesZpoints_max_dimsZlabels_max_dimsZpadded_pointsZfinal_pointsZpadded_labelsZfinal_labelsZfinal_boxesr   r2   r   __call__?   s    +$	





zSam2Processor.__call__Fztorch.Tensor)r   coordsr&   c           	      C   sl   |\}}|| }}t | }|r0|ddd}|d ||  |d< |d ||  |d< |rh|dd}|S )a  
        Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.

        Args:
            target_size (`int`):
                The target size of the image.
            coords (`torch.Tensor`):
                The coordinates to be normalized.
            original_size (`tuple`):
                The original size of the image.
            is_bounding_box (`bool`, *optional*, defaults to `False`):
                Whether the coordinates are bounding boxes.
        r)   ).r   ).r'   r(   )r   floatZreshape)	r   r   rE   original_sizer7   Zold_hZold_wZnew_hZnew_wr   r   r   _normalize_coordinates   s    
z$Sam2Processor._normalize_coordinatesr   c                    s   |du rdS t |tjrV d ks2t|jdkr>|  S  fdd|D S nt |tjr d ks|t|jdkr| S  fdd|D S nRt |t	rʈ kr|S  fdd|D S n$t |t
tfr|S tdt| dS )aS  
        Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists.

        Args:
            data: Input data in any format
            expected_depth: Expected nesting depth
            current_depth: Current depth in recursion

        Returns:
            Nested list representation of the data
        Nr)   c                    s   g | ]} | d  qS r'   _convert_to_nested_listr1   itemcurrent_depthr*   r   r   r   
<listcomp>   r4   z9Sam2Processor._convert_to_nested_list.<locals>.<listcomp>c                    s   g | ]} | d  qS rJ   rK   rM   rO   r   r   rQ      r4   c                    s   g | ]} | d  qS rJ   rK   rM   rO   r   r   rQ     r4   zUnsupported data type: )r8   r   r9   r0   shapenumpyr;   npndarraylistintrG   r<   type)r   datar*   rP   r   rO   r   rL      s"    
z%Sam2Processor._convert_to_nested_listc                 C   s   |du rg }t |ts|S t|dkr6|t| nt|d t||d< t|dkr|D ]`}t |tr\| |}t|D ]>\}}|d t|kr|| q|t||d  |||d < q|q\|S )a`  
        Get the maximum dimensions at each level of nesting.

        Args:
            nested_list (`list`):
                Nested list structure.
            max_dims (`list`, *optional*):
                Current maximum dimensions (for recursion).

        Returns:
            `list`: A list of maximum dimensions for each nesting level.
        Nr   r'   )r8   rV   r0   appendmaxr>   	enumerate)r   nested_listZmax_dimsrN   Zsub_dimsidimr   r   r   r>     s     


z$Sam2Processor._get_nested_dimensionsc           	         s\  |du r| j }|t|kr|S t|ts.|g}t|}|| }|t|d krd||g||   n|dkr|t|d k r||d d }| || n|g||d    | fddt|| D  n8||d d }| || | fddt|D  |t|d k rXtt|D ]2}t|| tr$| || ||d |||< q$|S )a  
        Recursively pad a nested list to match target dimensions.

        Args:
            nested_list (`list`):
                Nested list to pad.
            target_dims (`list`):
                Target dimensions for each level.
            current_level (`int`, *optional*, defaults to 0):
                Current nesting level.
            pad_value (`int`, *optional*):
                Value to use for padding.

        Returns:
            `list`: The padded nested list.
        Nr'   r   r)   c                    s   g | ]}t  qS r   r   r1   _templater   r   rQ   [  r4   z2Sam2Processor._pad_nested_list.<locals>.<listcomp>c                    s   g | ]}t  qS r   r   r`   rb   r   r   rQ   `  r4   )r   r0   r8   rV   extend_create_empty_nested_structureranger@   )	r   r]   Ztarget_dimsZcurrent_level	pad_valueZcurrent_sizer   Ztemplate_dimsr^   r   rb   r   r@   ,  s0    
" zSam2Processor._pad_nested_listc                    s<   t  dkrg d  S  fddt d D S dS )a  
        Create an empty nested structure with given dimensions filled with pad_value.

        Args:
            dims (`list`):
                The dimensions of the nested structure.
            pad_value (`int`):
                The value to fill the structure with.
        r'   r   c                    s    g | ]}  d d qS r/   )re   r`   dimsrg   r   r   r   rQ   w  r4   z@Sam2Processor._create_empty_nested_structure.<locals>.<listcomp>N)r0   rf   )r   ri   rg   r   rh   r   re   j  s    
z,Sam2Processor._create_empty_nested_structurec                 C   sL   t |tr,t|dkrdS d| |d  S t |tjtjfrHt|jS dS )z
        Get the nesting level of a list structure.

        Args:
            input_list (`list`):
                The list to get the nesting level of.
        r   r'   )	r8   rV   r0   _get_nesting_levelrT   rU   r   r9   rR   )r   Z
input_listr   r   r   rj   y  s    

z Sam2Processor._get_nesting_level)rY   r*   r+   r,   r-   r&   c              
   C   s   |du rdS t |tjtjfr|j|krNtd| d| d| d|j d	n8|dur|jd |krtd| d| d|jd  d	| ||S t |t	r| 
|}||krtd| d
| d| d| d	| ||S dS )a  
                Validate a single input by ensuring proper nesting and raising an error if the input is not valid.

                Args:
                    data (`torch.Tensor`, `np.ndarray`, or `list`):
                        Input data to process.
                    expected_depth (`int`):
                        Expected nesting depth.
                    input_name (`str`):
                        Name of the input for error messages.
                    expected_format (`str`):
                        The expected format of the input.
                    expected_coord_size (`int`, *optional*):
                        Expected coordinate size (2 for points, 4 for boxes, None for labels).
        .
        NzInput z must be a tensor/array with z, dimensions. The expected nesting format is z. Got z dimensions.rF   z as the last dimension, got .z must be a nested list with z( levels. The expected nesting format is z levels.)r8   r   r9   rT   rU   ndimr<   rR   rL   rV   rj   )r   rY   r*   r+   r,   r-   rP   r   r   r   r=     s(    


z$Sam2Processor._validate_single_inputc                 C   s   |r|| j k}|jddd}tt|D ]|}||jd k r(|t|k rN|| n|d }| j| j|| ||d}	|r|| }
t|
	|| |	|| ||< q(|	||< q(dS )a  
        Helper method to normalize coordinates in a tensor across multiple images.

        Args:
            tensor (`torch.Tensor`):
                Input tensor with coordinates.
            original_sizes (`list`):
                Original image sizes.
            is_bounding_box (`bool`, *optional*, defaults to `False`):
                Whether coordinates are bounding boxes.
            preserve_padding (`bool`, *optional*, defaults to `False`):
                Whether to preserve padding values (for points).
        rF   T)r_   Zkeepdimr   r6   N)
r   allrf   r0   rR   rI   r   r   whereZ	expand_as)r   rA   r$   r7   r5   maskZ
coord_maskZimg_idxrH   Znormalized_coordsZimg_maskr   r   r   rB     s    

z+Sam2Processor._normalize_tensor_coordinates        Tc           	      K   s    | j j|||||||fi |S )a-  
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in (height,
                width) format.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold for binarization and post-processing operations.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            max_hole_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a hole to fill.
            max_sprinkle_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a sprinkle to fill.
            apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
                Whether to apply non-overlapping constraints to the masks.

        Returns:
            (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
            is given by original_size.
        )r   post_process_masks)	r   Zmasksr$   Zmask_thresholdZbinarizeZmax_hole_areaZmax_sprinkle_areaZ!apply_non_overlapping_constraintsr   r   r   r   rq     s    #z Sam2Processor.post_process_masks)Nr   )NNNNNNN)F)r   )N)r   N)N)FF)rp   Trp   rp   F)__name__
__module____qualname____doc__
attributesZimage_processor_classr   rW   r   r   r   rV   rG   r   r9   strr
   r	   rD   rI   rL   r>   r@   re   rj   rT   rU   r=   rB   rq   __classcell__r   r   r   r   r   %   s`             
&
%
> 2
'     r   )ru   copyr   typingr   r   rS   rT   Zimage_utilsr   Zprocessing_utilsr   Ztokenization_utils_baser	   utilsr
   r   r   Zutils.import_utilsr   Z
get_loggerrr   loggerr   r   __all__r   r   r   r   <module>   s"   
   k