a
    hI                     @   s  d dl Z d dlmZ d dlmZ d dlZd dlmZ d dl	m
Z
mZmZmZmZmZ d dlmZmZ d dlmZ d dlmZ d d	lmZ d d
lmZmZmZ d dlmZ d1ddZe  fddZ!eeddZ"d2ddZ#dd Z$G dd de%Z&G dd dee
Z'G dd dee
Z(G dd dee
Z)G dd  d ee
Z*G d!d" d"ee
Z+G d#d$ d$e
Z,G d%d& d&eZ-G d'd( d(eeZ.G d)d* d*eee
Z/G d+d, d,eee
Z0G d-d. d.eee
Z1G d/d0 d0eee
Z2dS )3    N)defaultdict)partial)assert_array_equal)BaseEstimatorClassifierMixinMetaEstimatorMixinRegressorMixinTransformerMixinclone)_Scorermean_squared_error)BaseCrossValidator)GroupsConsumerMixin)SIMPLE_METHODS)MetadataRouterMethodMappingprocess_routing)_check_partial_fit_first_callTc                 K   sb   t  }|d j}|d j}t| ds4tdd | _|sJdd | D }| j| | | dS )	zUtility function to store passed metadata to a method of obj.

    If record_default is False, kwargs whose values are "default" are skipped.
    This is so that checks on keyword arguments whose default was not changed
    are skipped.

          _recordsc                   S   s   t tS N)r   list r   r   a/var/www/html/assistant/venv/lib/python3.9/site-packages/sklearn/tests/metadata_routing_common.py<lambda>*       z!record_metadata.<locals>.<lambda>c                 S   s(   i | ] \}}t |tr|d kr||qS )default
isinstancestr).0keyvalr   r   r   
<dictcomp>,   s   z#record_metadata.<locals>.<dictcomp>N)inspectstackfunctionhasattrr   r   itemsappend)objrecord_defaultkwargsr&   calleecallerr   r   r   record_metadata   s    


r0   c           
   	   K   s   t | dt |t |t }|D ]}t| t| ks^J d|  d|  | D ]p\}}|| }	||v r|	durt|	|	 sJ qft
|	tjrt|	| qf|	|u sfJ d|	 d| d| qfq&dS )a  Check whether the expected metadata is passed to the object's method.

    Parameters
    ----------
    obj : estimator object
        sub-estimator to check routed params for
    method : str
        sub-estimator's method where metadata is routed to, or otherwise in
        the context of metadata routing referred to as 'callee'
    parent : str
        the parent method which should have called `method`, or otherwise in
        the context of metadata routing referred to as 'caller'
    split_params : tuple, default=empty
        specifies any parameters which are to be checked as being a subset
        of the original values
    **kwargs : dict
        passed metadata
    r   z	Expected z vs Nz
. Method: )getattrdictgetr   setkeysr)   npisinallr   Zndarrayr   )
r+   methodparentZsplit_paramsr-   Zall_recordsrecordr"   valueZrecorded_valuer   r   r   check_recorded_metadata4   s$     r=   F)r,   c                 C   s   t | trH| D ]4\}}|dur0||v r0|| }nd}t|j|d qdS |du rTg n|}tD ]4}||v rjq\t| |}dd |j D }|r\J q\dS )a  Check if a metadata request dict is empty.

    One can exclude a method or a list of methods from the check using the
    ``exclude`` parameter. If metadata_request is a MetadataRouter, then
    ``exclude`` can be of the form ``{"object" : [method, ...]}``.
    N)excludec                 S   s&   g | ]\}}t |ts|d ur|qS r   r   )r!   propaliasr   r   r   
<listcomp>w   s   z+assert_request_is_empty.<locals>.<listcomp>)r   r   assert_request_is_emptyrouterr   r1   requestsr)   )Zmetadata_requestr>   nameZroute_mappingZ_excluder9   mmrpropsr   r   r   rB   b   s     


rB   c                    s^      D ] \}}t| |}|j|ksJ q fddtD }|D ]}tt| |jr@J q@d S )Nc                    s   g | ]}| vr|qS r   r   )r!   r9   
dictionaryr   r   rA      r   z(assert_request_equal.<locals>.<listcomp>)r)   r1   rD   r   len)requestrI   r9   rD   rF   Zempty_methodsr   rH   r   assert_request_equal   s    
rL   c                   @   s   e Zd Zdd Zdd ZdS )	_Registryc                 C   s   | S r   r   )selfmemor   r   r   __deepcopy__   s    z_Registry.__deepcopy__c                 C   s   | S r   r   rN   r   r   r   __copy__   s    z_Registry.__copy__N)__name__
__module____qualname__rP   rR   r   r   r   r   rM      s   rM   c                   @   sB   e Zd ZdZdddZdddZddd	Zdd
dZdddZdS )ConsumingRegressorac  A regressor consuming metadata.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.
    Nc                 C   s
   || _ d S r   registryrN   rX   r   r   r   __init__   s    zConsumingRegressor.__init__r   c                 C   s(   | j d ur| j |  t| ||d | S Nsample_weightmetadatarX   r*   record_metadata_not_defaultrN   Xyr]   r^   r   r   r   partial_fit   s    
zConsumingRegressor.partial_fitc                 C   s(   | j d ur| j |  t| ||d | S r[   r_   ra   r   r   r   fit   s    
zConsumingRegressor.fitc                 C   s    t | ||d tjt|fdS )Nr\   shaper`   r6   ZzerosrJ   ra   r   r   r   predict   s    zConsumingRegressor.predictc                 C   s   t | ||d dS Nr\   r   r`   ra   r   r   r   score   s    zConsumingRegressor.score)N)r   r   )r   r   )Nr   r   )r   r   )	rS   rT   rU   __doc__rZ   rd   re   ri   rl   r   r   r   r   rV      s   

	
	
rV   c                   @   sL   e Zd ZdZdddZdd Zddd	Zd
d Zdd Zdd Z	dd Z
dS )NonConsumingClassifier5A classifier which accepts no metadata on any method.        c                 C   s
   || _ d S r   )alpha)rN   rq   r   r   r   rZ      s    zNonConsumingClassifier.__init__c                 C   s   t || _t || _| S r   )r6   uniqueclasses_	ones_likecoef_rN   rb   rc   r   r   r   re      s    zNonConsumingClassifier.fitNc                 C   s   | S r   r   )rN   rb   rc   classesr   r   r   rd      s    z"NonConsumingClassifier.partial_fitc                 C   s
   |  |S r   )ri   rN   rb   r   r   r   decision_function   s    z(NonConsumingClassifier.decision_functionc                 C   s>   t jt|fd}d|d t|d < d|t|d d < |S )Nrf   r   r   r   )r6   emptyrJ   )rN   rb   Zy_predr   r   r   ri      s    zNonConsumingClassifier.predictc                 C   sd   t jt|dfd}t ddg|d t|d d d f< t ddg|t|d d d d f< |S )Nr   rf         ?rp   )r6   rz   rJ   asarray)rN   rb   y_probar   r   r   predict_proba   s    &&z$NonConsumingClassifier.predict_probac                 C   s
   |  |S r   )r~   rx   r   r   r   predict_log_proba   s    z(NonConsumingClassifier.predict_log_proba)rp   )N)rS   rT   rU   rm   rZ   re   rd   ry   ri   r~   r   r   r   r   r   rn      s   

rn   c                   @   s(   e Zd ZdZdd Zdd Zdd ZdS )	NonConsumingRegressorro   c                 C   s   | S r   r   rv   r   r   r   re      s    zNonConsumingRegressor.fitc                 C   s   | S r   r   rv   r   r   r   rd      s    z!NonConsumingRegressor.partial_fitc                 C   s   t t|S r   )r6   ZonesrJ   rx   r   r   r   ri      s    zNonConsumingRegressor.predictN)rS   rT   rU   rm   re   rd   ri   r   r   r   r   r      s   r   c                   @   s`   e Zd ZdZdddZdddZdd	d
ZdddZdddZdddZ	dddZ
dddZdS )ConsumingClassifiera  A classifier consuming metadata.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.

    alpha : float, default=0
        This parameter is only used to test the ``*SearchCV`` objects, and
        doesn't do anything.
    Nrp   c                 C   s   || _ || _d S r   )rq   rX   )rN   rX   rq   r   r   r   rZ     s    zConsumingClassifier.__init__r   c                 C   s2   | j d ur| j |  t| ||d t| | | S r[   )rX   r*   r`   r   )rN   rb   rc   rw   r]   r^   r   r   r   rd     s    

zConsumingClassifier.partial_fitc                 C   s@   | j d ur| j |  t| ||d t|| _t|| _| S r[   )rX   r*   r`   r6   rr   rs   rt   ru   ra   r   r   r   re     s    
zConsumingClassifier.fitc                 C   sN   t | ||d tjt|fdd}d|t|d d < d|d t|d < |S )Nr\   Zint8)rg   Zdtyper   r   r   r`   r6   rz   rJ   rN   rb   r]   r^   Zy_scorer   r   r   ri      s    zConsumingClassifier.predictc                 C   sr   t | ||d tjt|dfd}tddg|d t|d d d f< tddg|t|d d d d f< |S )Nr\   r   rf   r{   rp   )r`   r6   rz   rJ   r|   )rN   rb   r]   r^   r}   r   r   r   r~   )  s    &&z!ConsumingClassifier.predict_probac                 C   s"   t | ||d tjt|dfdS )Nr\   r   rf   rh   rN   rb   r]   r^   r   r   r   r   2  s    z%ConsumingClassifier.predict_log_probac                 C   sL   t | ||d tjt|fd}d|t|d d < d|d t|d < |S )Nr\   rf   r   r   r   r   r   r   r   r   ry   8  s    z%ConsumingClassifier.decision_functionc                 C   s   t | ||d dS rj   rk   ra   r   r   r   rl   A  s    zConsumingClassifier.score)Nrp   )Nr   r   )r   r   )r   r   )r   r   )r   r   )r   r   )r   r   )rS   rT   rU   rm   rZ   rd   re   ri   r~   r   ry   rl   r   r   r   r   r      s   
 


	
	

	r   c                   @   sB   e Zd ZdZdddZdddZddd	Zdd
dZdddZdS )ConsumingTransformera~  A transformer which accepts metadata on fit and transform.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.
    Nc                 C   s
   || _ d S r   rW   rY   r   r   r   rZ   T  s    zConsumingTransformer.__init__r   c                 C   s.   | j d ur| j |  t| ||d d| _| S )Nr\   T)rX   r*   r`   Zfitted_ra   r   r   r   re   W  s    
zConsumingTransformer.fitc                 C   s   t | ||d |d S rj   rk   r   r   r   r   	transforma  s    zConsumingTransformer.transformc                 C   s,   t | ||d | j||||dj|||dS r[   )r`   re   r   ra   r   r   r   fit_transformg  s    z"ConsumingTransformer.fit_transformc                 C   s   t | ||d |d S rj   rk   r   r   r   r   inverse_transforms  s    z&ConsumingTransformer.inverse_transform)N)Nr   r   )r   r   )r   r   )NN)	rS   rT   rU   rm   rZ   re   r   r   r   r   r   r   r   r   H  s   




r   c                   @   s.   e Zd ZdZd	ddZd
ddZdddZdS )"ConsumingNoFitTransformTransformerzA metadata consuming transformer that doesn't inherit from
    TransformerMixin, and thus doesn't implement `fit_transform`. Note that
    TransformerMixin's `fit_transform` doesn't route metadata to `transform`.Nc                 C   s
   || _ d S r   rW   rY   r   r   r   rZ     s    z+ConsumingNoFitTransformTransformer.__init__c                 C   s(   | j d ur| j |  t| ||d | S r[   )rX   r*   r0   ra   r   r   r   re     s    
z&ConsumingNoFitTransformTransformer.fitc                 C   s   t | ||d |S r[   )r0   r   r   r   r   r     s    z,ConsumingNoFitTransformTransformer.transform)N)NNN)NN)rS   rT   rU   rm   rZ   re   r   r   r   r   r   r   z  s   

r   c                       s*   e Zd Zd fdd	Z fddZ  ZS )ConsumingScorerNc                    s   t  jtdi dd || _d S )Nr   ri   )Z
score_funcsignr-   Zresponse_method)superrZ   r   rX   rY   	__class__r   r   rZ     s    zConsumingScorer.__init__c                    sH   | j d ur| j |  t| fi | |dd }t j|||||dS )Nr]   r]   )rX   r*   r`   r3   r   _score)rN   Zmethod_callerZclfrb   rc   r-   r]   r   r   r   r     s
    
zConsumingScorer._score)N)rS   rT   rU   rZ   r   __classcell__r   r   r   r   r     s   r   c                   @   s4   e Zd ZdddZdddZdddZdd	d
ZdS )ConsumingSplitterNc                 C   s
   || _ d S r   rW   rY   r   r   r   rZ     s    zConsumingSplitter.__init__r   c                 c   sh   | j d ur| j |  t| ||d t|d }ttd|}tt|t|}||fV  ||fV  d S )N)groupsr^   r   r   )rX   r*   r`   rJ   r   range)rN   rb   rc   r   r^   split_indextrain_indicestest_indicesr   r   r   split  s    

zConsumingSplitter.splitc                 C   s   dS )Nr   r   )rN   rb   rc   r   r^   r   r   r   get_n_splits  s    zConsumingSplitter.get_n_splitsc                 c   s<   t |d }ttd|}tt|t |}|V  |V  d S )Nr   r   )rJ   r   r   )rN   rb   rc   r   r   r   r   r   r   r   _iter_test_indices  s
    z$ConsumingSplitter._iter_test_indices)N)Nr   r   )NNNN)NNN)rS   rT   rU   rZ   r   r   r   r   r   r   r   r     s   


r   c                   @   s(   e Zd ZdZdd Zdd Zdd ZdS )	MetaRegressorz(A meta-regressor which is only a router.c                 C   s
   || _ d S r   )	estimator)rN   r   r   r   r   rZ     s    zMetaRegressor.__init__c                 K   s6   t | dfi |}t| jj||fi |jj| _d S Nre   )r   r
   r   re   
estimator_rN   rb   rc   
fit_paramsparamsr   r   r   re     s    zMetaRegressor.fitc                 C   s*   t | jjdj| jt jdddd}|S Nownerre   r/   r.   r   method_mapping)r   r   rS   addr   r   rN   rC   r   r   r   get_metadata_routing  s
    z"MetaRegressor.get_metadata_routingNrS   rT   rU   rm   rZ   re   r   r   r   r   r   r     s   r   c                   @   s4   e Zd ZdZdddZdddZdd Zd	d
 ZdS )WeightedMetaRegressorz*A meta-regressor which is also a consumer.Nc                 C   s   || _ || _d S r   r   rX   rN   r   rX   r   r   r   rZ     s    zWeightedMetaRegressor.__init__c                 K   s\   | j d ur| j |  t| |d t| dfd|i|}t| jj||fi |jj| _| S Nr   re   r]   rX   r*   r0   r   r
   r   re   r   )rN   rb   rc   r]   r   r   r   r   r   re     s    
 zWeightedMetaRegressor.fitc                 K   s*   t | dfi |}| jj|fi |jjS )Nri   )r   r   ri   r   )rN   rb   Zpredict_paramsr   r   r   r   ri     s    zWeightedMetaRegressor.predictc                 C   s:   t | jjd| j| jt jdddjdddd}|S )Nr   re   r   ri   r   r   r   rS   Zadd_self_requestr   r   r   r   r   r   r   r     s    
z*WeightedMetaRegressor.get_metadata_routing)N)N)rS   rT   rU   rm   rZ   re   ri   r   r   r   r   r   r     s
   

	r   c                   @   s,   e Zd ZdZd	ddZd
ddZdd ZdS )WeightedMetaClassifierzEA meta-estimator which also consumes sample_weight itself in ``fit``.Nc                 C   s   || _ || _d S r   r   r   r   r   r   rZ     s    zWeightedMetaClassifier.__init__c                 K   s\   | j d ur| j |  t| |d t| dfd|i|}t| jj||fi |jj| _| S r   r   )rN   rb   rc   r]   r-   r   r   r   r   re     s    
 zWeightedMetaClassifier.fitc                 C   s0   t | jjd| j| jt jdddd}|S r   r   r   r   r   r   r     s    z+WeightedMetaClassifier.get_metadata_routing)N)Nr   r   r   r   r   r     s   

	r   c                   @   s4   e Zd ZdZdd ZdddZdddZd	d
 ZdS )MetaTransformerzA simple meta-transformer.c                 C   s
   || _ d S r   )transformer)rN   r   r   r   r   rZ     s    zMetaTransformer.__init__Nc                 K   s6   t | dfi |}t| jj||fi |jj| _| S r   )r   r
   r   re   transformer_r   r   r   r   re     s     zMetaTransformer.fitc                 K   s*   t | dfi |}| jj|fi |jjS )Nr   )r   r   r   r   )rN   rb   rc   Ztransform_paramsr   r   r   r   r     s    zMetaTransformer.transformc                 C   s0   t | jjdj| jt jdddjddddS )Nr   re   r   r   )r   r   )r   r   rS   r   r   r   rQ   r   r   r   r     s    z$MetaTransformer.get_metadata_routing)N)N)rS   rT   rU   rm   rZ   re   r   r   r   r   r   r   r     s
   

r   )T)N)3r%   collectionsr   	functoolsr   numpyr6   Znumpy.testingr   Zsklearn.baser   r   r   r   r	   r
   Zsklearn.metrics._scorerr   r   Zsklearn.model_selectionr   Zsklearn.model_selection._splitr   Z sklearn.utils._metadata_requestsr   Zsklearn.utils.metadata_routingr   r   r   Zsklearn.utils.multiclassr   r0   tupler=   r`   rB   rL   r   rM   rV   rn   r   r   r   r   r   r   r   r   r   r   r   r   r   r   <module>   s:    
+

.#T2"