a
    hq                    @   s*  d dl Z d dlmZ d dlZd dlZd dlZd dlmZ	 d dl
mZmZmZ d dlmZmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZ d d
lmZ d dlm Z m!Z!m"Z"m#Z# d dl$m%Z% d dl&m'Z' d dl(m)Z)m*Z*m+Z+m,Z, dd Z-G dd dej.Z/G dd dej0Z1G dd dej2Z3dd Z.dd Z0dd Z2dd Z4dd  Z5d!d" Z6e7d#d$gd$d$gd$d#gd%d%gd%d&gd&d%ggZ8g d'Z9e7d$d$gd&d&gd(d&ggZ:g d)Z;e7d$d%gd*d+gd,d-gd%d%gd.d+gd-d-gd$d$gd d/gd%d$gg	Z<d0gd( d1gd(  d2gd(  Z=e7d,d+gd%d&gd d#ggZ>g d3Z?e7g d4g d4g d5g d5g d6g d6g d7g d7gZ@e7g d8ZAe7g d9g d:g d;g d<g d=g d>g d?g d@gZBe7g d8ZCeD ZEe7d#d$gd$d$gd$d#gd%d%gd%d&gd&d%ggZFg d'ZGg dAZHddCdDZIdEdF ZJejKLdGe.e4e0e5gejKLdHg dIdJdK ZMejKLdGe.e4e0e5gdLdM ZNejKLdGe.e4e0e5gdNdO ZOejKLdGe.e4e0e5e2e6gdPdQ ZPejKLdGe.e4e0e5e2e6gdRdS ZQejKLdGe.e4e0e5gdTdU ZRejKLdGe.e4e0e5gdVdW ZSejKLdGe.e4e0e5gdXdY ZTejKLdGe.e4e0e5gdZd[ ZUejKLdGe.e4e0e5gd\d] ZVejKLdGe.e4e0e5gd^d_ ZWejKLdGe.e4gd`da ZXejKLdGe.e4e2e6gdbdc ZYejKLdde.deeZdfife4deeZdfife2dgeZdfife6dgeZdfifgdhdi Z[ejKLdGe.e4e0e5gdjdk Z\ejKLdde.ded ife4ded ife2dgd ife6dgd ifgdldm Z]ejKLdGe.e4gdndo Z^ejKLdGe.e4gdpdq Z_ejKLdGe.e4gdrds Z`ejKLdGe.e4gdtdu ZaejKLdGe.e4gdvdw ZbejKLdGe.e4gdxdy ZcejKLdGe.e4gdzd{ ZdejKLdGe.e4gd|d} ZeejKLdGe.e4gd~d ZfejKLdGe.e4gdd ZgejKLdGe.e4gdd ZhejKLdGe.e4gdd ZiejKLdGe.e4gdd ZjejKLdGe.e4gdd ZkejKLdGe.e4gdd ZlejKLdGe.e4gdd ZmejKLdGe.e4gdd ZnejKLdGe.e4gdd ZoejKLdGe.e4e2e6gdd ZpejKLdGe.e4gdd ZqejKLdGe.e4gdd ZrejKLdGe.e4gdd ZsejKLdGe.e4gdd ZtejKLdGe.e4gdd ZuejKLdGe.e4gejKLdHg dIdd ZvejKLdGe.e4gdd ZwejKLdGe.e4gdd ZxejKLdGe.e4gdd ZyejKLdGe0e5gdd ZzejKLdGe0e5gdd Z{ejKLdGe0e5gdd Z|ejKLdGe0e5gdd Z}ejKLdGe0e5gdd Z~ejKLdGe0e5gdd ZejKLdGe0e5gdd ZejKLdGe0e5gdd ZejKLdGe0e5gdd ZejKLdGe0e5gejKLdHg dIdd ZejKLdGe0e5gdd ZdddZejKLdGe2e6gdd ZejKLdGe2e6gejKLdHg dIdd ZejKLdGe2e6gddÄ ZejKLdGe2e6gddń ZejKLdGe2e6gejKLdHg dIddǄ ZejKLdGe2e6gddɄ ZejKLdGe2e6gdd˄ ZejKLdGe2e6gdd̈́ ZejKLdGe2e6gddτ Zddф Zddӄ ZddՄ Zddׄ Zddل ZejKLdg dۢdd݄ Zdd߄ Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd ZejKLdg ddd ZejKLdej.ej0gdd Zdd Zdd ZejKLde.e0gdd ZejKLde.e4e0e5e2e6gejKLdejejfdd  ZejKLde.e4e0e5e2e6gdd ZejKLde.e0e2gdd Zdd ZdS (	      N)Mock)datasetslinear_modelmetrics)cloneis_classifier)ConvergenceWarning)Nystroem)	_sgd_fast)_stochastic_gradient)RandomizedSearchCVShuffleSplitStratifiedShuffleSplit)make_pipeline)LabelEncoderMinMaxScalerStandardScalerscale)OneClassSVM)get_tags)assert_allcloseassert_almost_equalassert_array_almost_equalassert_array_equalc                 C   s4   d| vrd| d< d| vr d | d< d| vr0d| d< d S )Nrandom_state*   tolmax_iter    kwargsr   r   _/var/www/html/assistant/venv/lib/python3.9/site-packages/sklearn/linear_model/tests/test_sgd.py_update_kwargs    s    r#   c                       s@   e Zd Z fddZ fddZ fddZ fddZ  ZS )	_SparseSGDClassifierc                    s&   t |}t j||g|R i |S N)sp
csr_matrixsuperfitselfXyargskw	__class__r   r"   r)   +   s    
z_SparseSGDClassifier.fitc                    s&   t |}t j||g|R i |S r%   )r&   r'   r(   partial_fitr*   r0   r   r"   r2   /   s    
z _SparseSGDClassifier.partial_fitc                    s   t |}t |S r%   )r&   r'   r(   decision_functionr+   r,   r0   r   r"   r3   3   s    
z&_SparseSGDClassifier.decision_functionc                    s   t |}t |S r%   )r&   r'   r(   predict_probar4   r0   r   r"   r5   7   s    
z"_SparseSGDClassifier.predict_proba)__name__
__module____qualname__r)   r2   r3   r5   __classcell__r   r   r0   r"   r$   *   s   r$   c                   @   s$   e Zd Zdd Zdd Zdd ZdS )_SparseSGDRegressorc                 O   s(   t |}tjj| ||g|R i |S r%   )r&   r'   r   SGDRegressorr)   r*   r   r   r"   r)   =   s    
z_SparseSGDRegressor.fitc                 O   s(   t |}tjj| ||g|R i |S r%   )r&   r'   r   r;   r2   r*   r   r   r"   r2   A   s    
z_SparseSGDRegressor.partial_fitc                 O   s&   t |}tjj| |g|R i |S r%   )r&   r'   r   r;   r3   r+   r,   r.   r/   r   r   r"   r3   E   s    
z%_SparseSGDRegressor.decision_functionNr6   r7   r8   r)   r2   r3   r   r   r   r"   r:   <   s   r:   c                   @   s$   e Zd Zdd Zdd Zdd ZdS )_SparseSGDOneClassSVMc                 O   s&   t |}tjj| |g|R i |S r%   )r&   r'   r   SGDOneClassSVMr)   r<   r   r   r"   r)   L   s    
z_SparseSGDOneClassSVM.fitc                 O   s&   t |}tjj| |g|R i |S r%   )r&   r'   r   r?   r2   r<   r   r   r"   r2   P   s    
z!_SparseSGDOneClassSVM.partial_fitc                 O   s&   t |}tjj| |g|R i |S r%   )r&   r'   r   r?   r3   r<   r   r   r"   r3   T   s    
z'_SparseSGDOneClassSVM.decision_functionNr=   r   r   r   r"   r>   K   s   r>   c                  K   s   t |  tjf i | S r%   )r#   r   SGDClassifierr    r   r   r"   r@   Y   s    r@   c                  K   s   t |  tjf i | S r%   )r#   r   r;   r    r   r   r"   r;   ^   s    r;   c                  K   s   t |  tjf i | S r%   )r#   r   r?   r    r   r   r"   r?   c   s    r?   c                  K   s   t |  tf i | S r%   )r#   r$   r    r   r   r"   SparseSGDClassifierh   s    rA   c                  K   s   t |  tf i | S r%   )r#   r:   r    r   r   r"   SparseSGDRegressorm   s    rB   c                  K   s   t |  tf i | S r%   )r#   r>   r    r   r   r"   SparseSGDOneClassSVMr   s    rC         )rF   rF   rF   rG   rG   rG      )rF   rG   rG   g            ?g      g      ?g      ?      onetwothree)rK   rL   rM   )rF   rF   r   r   r   r   )r   r   rF   r   r   r   )r   r   r   r   rF   rF   )r   r   r   rF   r   r   )rF   rF   rF   rF   rG   rG   rG   rG   )rF   ?皙?r   r   r   )rF   zG?g\(\?r   r   r   )rF   Q?g)\(?r   r   r   )rF   Q?Gz?r   r   r   )r   r   r   g{Gz?rR   rF   )r   r   r   gHzG?rP   rF   )r   r   r   rR   gffffff?rF   )r   r   r   g(\?rF   rF   )r   rF   rF           c                 C   s   |d u rt |jd }n|}t |jd }|}	d}
d}| ttfv rJd}t|D ]\}}t ||}||	7 }|||  }|d||  9 }||| |  7 }|	||  | 7 }	||9 }||7 }||d  }|
|9 }
|
|	7 }
|
|d  }
qR||
fS )NrF   rT         ?{Gz?)npzerosshaperA   rB   	enumeratedot)klassr,   r-   etaalphaweight_initintercept_initweightsaverage_weights	interceptaverage_interceptdecayientrypgradientr   r   r"   asgd   s.    rj   c                 C   s   | ddd|d}| || | ddd|d}|j |||j |j d | dddd|d}| || |j|jksxJ t|j|j |jdd | || |j|jksJ t|j|j d S )	NrV   F)r^   eta0shufflelearning_rateMbP?	coef_initr`   T)r^   rk   rl   
warm_startrm   r^   )r)   coef_copy
intercept_t_r   
set_params)r\   r,   Ylrclfclf2clf3r   r   r"   _test_warm_start   s    
r}   r\   ry   )constantoptimalZ
invscalingadaptivec                 C   s   t | tt| d S r%   )r}   r,   rx   r\   ry   r   r   r"   test_warm_start   s    r   c                 C   sx   | ddd}| tt ttd d tjf }tj||f }tt	 | t| W d    n1 sj0    Y  d S )NrV   Fr^   rl   )
r)   r,   rx   rW   arrayZnewaxisZc_pytestraises
ValueError)r\   rz   Y_r   r   r"   test_input_format   s    r   c                 C   sV   | ddd}t |}|jdd |tt | ddd}|tt t|j|j d S )NrV   l1)r^   penaltyl2)r   )r   rw   r)   r,   rx   r   rs   r\   rz   r{   r   r   r"   
test_clone  s    r   c                 C   s   | ddd}| tt t|ds&J t|ds4J t|dsBJ t|dsPJ |  }| tt t|drpJ t|dr~J t|drJ t|drJ d S )NTrV   )averagerk   Z_average_coefZ_average_interceptZ_standard_interceptZ_standard_coef)r)   r,   rx   hasattrr\   rz   r   r   r"   test_plain_has_no_average_attr  s    r   c                 C   s   | dd}|  }t dD ]R}t|rR|jttttd |jttttd q|tt |tt qt|j|jdd | t	t
ttfv rt|j|jdd n| ttfv rt|j|j d S )NiX  r   d   classes   decimal)ranger   r2   r,   rx   rW   uniquer   rs   r@   rA   r;   rB   r   ru   r?   rC   r   offset_)r\   clf1r{   _r   r   r"   %test_late_onset_averaging_not_reached:  s    
r   c              	   C   s   d}d}t t}d||dk< d||dk< | ddd	||dd
d}| d
dd	||dd
d}|t| |t| t| t||||j |jd\}}t	|j | dd t
|j|dd d S )Nrn   -C6?      rF   rU   rG      r~   squared_errorF)r   rm   lossrk   r^   r   rl   )r_   r`   r   r   )rW   r   rx   r)   r,   rj   rs   ravelru   r   r   )r\   rk   r^   ZY_encoder   r{   rb   rd   r   r   r"   !test_late_onset_averaging_reachedW  sH    
	


r   c                 C   sV   t jt jdk }t jt jdk }dD ],}d}| |d|d||}|j|k s$J q$d S )Nr   TF  rn   )early_stoppingr   r   )irisdatatargetr)   n_iter_)r\   r,   rx   r   r   rz   r   r   r"   test_early_stopping  s    r   c                 C   sT   | ddddd}| tjtj | ddddd}| tjtj |j|jksPJ d S )Nr   rV   rn   r   )rm   rk   r   r   r~   )r)   r   r   r   r   )r\   r   r{   r   r   r"   "test_adaptive_longer_than_constant  s
    r   c              
   C   s   t jt j }}d}d}d}d}| dtj||ddd ||d}||| |j|ksXJ | dtj|ddd ||d	}t|rt	||d
}	nt
||d
}	t|	||\}
}t|
}
|||
 ||
  |j|ksJ t|j|j d S )N皙?r   F
   Tr~   rV   )r   r   validation_fractionrm   rk   r   r   rl   )r   r   rm   rk   r   r   rl   )Z	test_sizer   )r   r   r   rW   randomRandomStater)   r   r   r   r   nextsplitsortr   rs   )r\   r,   rx   r   seedrl   r   r   r{   ZcvZ	idx_trainZidx_valr   r   r"   )test_validation_set_not_used_for_training  sD    




r   c                    sB   t jt j  dD ]* fdddD }t|t| qd S )Nr   c                    s&   g | ]}|d dd  jqS )r   r   )r   n_iter_no_changer   r   )r)   r   ).0r   r,   rx   r   r\   r   r"   
<listcomp>  s   	z)test_n_iter_no_change.<locals>.<listcomp>)rG   rH   r   )r   r   r   r   sorted)r\   Zn_iter_listr   r   r"   test_n_iter_no_change  s    	r   c                 C   sF   | ddd}t t |tt W d    n1 s80    Y  d S )NTrS   )r   r   )r   r   r   r)   X3Y3r   r   r   r"   )test_not_enough_sample_for_early_stopping  s    r   c              	   C   s>   dD ]4}| ddd|ddd}| tt t|tt qd S )N)hingesquared_hingelog_lossmodified_huberr   rV   Tr   )r   r^   fit_interceptr   r   rl   )r)   r,   rx   r   predictTtrue_result)r\   r   rz   r   r   r"   test_sgd_clf  s    r   c                 C   sJ   t jtdd( |  jtttdd W d   n1 s<0    Y  dS )z1Check that the shape of `coef_init` is validated.z)Provided coef_init does not match datasetmatchrH   rp   N)r   r   r   r)   r,   rx   rW   rX   r\   r   r   r"   test_provide_coef  s    r   zklass, fit_paramsr`   r   offset_initc                 C   sL   |  }t jtdd$ |jttfi | W d   n1 s>0    Y  dS )z:Check that `intercept_init` or `offset_init` is validated.zdoes not match datasetr   Nr   r   r   r)   r,   rx   )r\   
fit_paramsZsgd_estimatorr   r   r"   test_set_intercept_offset  s    r   c                 C   sH   d}t jt|d" | ddtt W d   n1 s:0    Y  dS )zSCheck that we raise an error for `early_stopping` used with
    `partial_fit`.
    z/early_stopping should be False with partial_fitr   T)r   N)r   r   r   r2   r,   rx   )r\   err_msgr   r   r"   (test_sgd_early_stopping_with_partial_fit  s    r   c                 C   s   |  j ttfi | dS )zdCheck that we can pass a scaler with binary classification to
    `intercept_init` or `offset_init`.N)r)   X5Y5)r\   r   r   r   r"    test_set_intercept_offset_binary$  s    r   c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}| dd||d	d
d	dd}t ||}	t |	}	|||	 t| ||	||\}
}|
d
d}
t	|j
|
dd t|j|dd d S )N皙?       @   r   r   sizer   r~   TrF   Fr   rm   rk   r^   r   r   r   rl   rE      r   )rW   r   r   normalr[   signr)   rj   reshaper   rs   r   ru   )r\   r]   r^   	n_samples
n_featuresrngr,   wrz   r-   rb   rd   r   r   r"   &test_average_binary_computed_correctly3  s0    
r   c                 C   sH   |   tt}|  j tt|jd |   tt}|  j tt|jd d S )Nr`   )r)   r   r   ru   r,   rx   r   r   r   r"   test_set_intercept_to_interceptU  s    r   c                 C   sL   | ddd}t t" |ttd W d    n1 s>0    Y  d S )NrV   r   r^   r   	   )r   r   r   r)   X2rW   onesr   r   r   r"   test_sgd_at_least_two_labels_  s    r   c                 C   sR   d}t jt|d, | ddjttttd W d    n1 sD0    Y  d S )Na`  class_weight 'balanced' is not supported for partial_fit\. In order to use 'balanced' weights, use compute_class_weight\('balanced', classes=classes, y=y\). In place of y you can use a large enough sample of the full training set target to properly estimate the class frequency distributions\. Pass the resulting weights as the class_weight parameter\.r   balanced)class_weightr   )r   r   r   r2   r,   rx   rW   r   )r\   regexr   r   r"   &test_partial_fit_weight_class_balancedg  s    
r   c                 C   sf   | ddd tt}|jjdks$J |jjdks4J |ddggjdksNJ |t}t	|t
 d S )NrV   r   r   rH   rG   r   r   rF   rH   r)   r   Y2rs   rY   ru   r3   r   T2r   true_result2r\   rz   predr   r   r"   test_sgd_multiclassx  s    
r   c              
   C   s   d}d}| dd||ddddd}t t}|t| t |}t|D ]`\}}t |jd	 }d
|||k< t	| t|||\}	}
t
|	|j| dd t|
|j| dd qHd S )Nrn   rV   r   r~   TrF   Fr   r   rE   r   r   )rW   r   r   r)   r   r   rZ   r   rY   rj   r   rs   r   ru   )r\   r]   r^   rz   Znp_Y2r   rf   clZy_iaverage_coefrd   r   r   r"   test_sgd_multiclass_average  s*    

r   c                 C   sb   | ddd}|j tttdtdd |jjdks:J |jjsJJ d|t	}t
|t d S )NrV   r   r   r   rH   ro   r   )r)   r   r   rW   rX   rs   rY   ru   r   r   r   r   r   r   r   r"   "test_sgd_multiclass_with_init_coef  s    
r  c                 C   sh   | dddd tt}|jjdks&J |jjdks6J |ddggjdksPJ |t}t	|t
 d S )	NrV   r   rG   )r^   r   n_jobsr   r   r   r   r   r   r   r   r"   test_sgd_multiclass_njobs  s    
r  c                 C   s   |  }t t& |jtttdd W d    n1 s<0    Y  |  jtttdd}|  }t t& |jtttdd W d    n1 s0    Y  |  jtttdd}d S )N)rG   rG   r   r   rF   r   r   )r   r   r   r)   r   r   rW   rX   r   r   r   r"   test_set_coef_multiclass  s    44r  c              	   C   s   t jjD ]}t|d}|dv r>t|ds.J t|ds<J qd|}t|drVJ t|drdJ tjtdd}|j W d    n1 s0    Y  t	|j
jtsJ |t|j
jv sJ tjtdd}|j W d    n1 s0    Y  t	|j
jtsJ |t|j
jv sJ qd S )	N)r   r   r   r5   predict_log_probaz5probability estimates are not available for loss={!r}z has no attribute 'predict_proba'r   z$has no attribute 'predict_log_proba')r   r@   loss_functionsr   formatr   r   AttributeErrorr5   
isinstancevalue	__cause__strr  )r\   r   rz   Z	inner_msgZ	exec_infor   r   r"   $test_sgd_predict_proba_method_access  s0    
$$r  c              	   C   s  t dddd dtt}t|dr&J t|dr4J dD ]}| |ddd}|tt |d	d
gg}|d dksvJ |ddgg}|d dk sJ tjddX |d	d
gg}|d |d ksJ |ddgg}|d |d k sJ W d    q81 s0    Y  q8| ddddt	t
}|ddgddgg}|ddgddgg}ttj|ddtj|dd t|d  d t|d dksJ |ddgg}|ddgg}tt|d t|d  |d	d
gg}|d	d
gg}tt|| |ddgg}|ddgg}tt|| | dddd}|t	t
 |d	d
gg}|d	d
gg}| tkrtj|ddtj|ddksJ n"tj|ddtj|ddksJ tjdd}||g}t|dk r||g}t|d dgd	  d S )Nr   rV   r   )r   r^   r   r   r5   r  r  )r   r^   r   rH   rG   r   rF   rI   rE   ignore)divide)r   r   r   r   皙333333?皙?rF   )Zaxisr   r   gUUUUUU?)r@   r)   r,   rx   r   r5   rW   errstater  r   r   r3   r   Zargmaxr   sumallZargsortr   logrA   Zargminmean)r\   rz   r   rh   dZlpxr   r   r"   test_sgd_proba  sT    6
$"r  c                 C   s   t t}tjd}t|}|| t|d d f }t| }| ddddd dd}||| t	|j
ddd	f td
 ||}t	|| |  t|j
sJ ||}t	|| tt|}t|j
sJ ||}t	|| d S )N   r   r  F  )r   r^   r   r   r   rl   r   rF   rE   )   )lenX4rW   r   r   arangerl   Y4r)   r   rs   rX   r   Zsparsifyr&   issparsepickleloadsdumps)r\   nr   idxr,   rx   rz   r   r   r   r"   test_sgd_l14  s4    






r+  c                 C   s   t ddgddgddgddgddgg}g d}| ddd	d d
}||| t|ddggt dg | ddd	ddid
}||| t|ddggt dg d S )Nr   r   皙rU   rT   rF   rF   rF   rE   rE   r   r   F)r^   r   r   r   r  rF   rn   rE   rW   r   r)   r   r   r\   r,   r-   rz   r   r   r"   test_class_weightsY  s    (r0  c                 C   s   ddgddgddgddgg}g d}| ddd d}| || ddgddgg}ddg}| dddddd}| || t|j|jd	d
 d S )NrF   r   )r   r   rF   rF   r   r   r^   r   r   rI   r  rG   r   )r)   r   rs   )r\   r,   r-   rz   Zclf_weightedr   r   r"   test_equal_class_weightl  s    r2  c                 C   sL   | ddddid}t t |tt W d    n1 s>0    Y  d S )Nr   r   r   rI   r1  r   r   r   r   r"   test_wrong_class_weight_label}  s    r3  c                 C   s   ddd}t jd}|tjd }t |}|tdk  |d 9  < |tdk  |d 9  < | dd|d	}| ddd
}|jtt|d |jtt|d t	|j
|j
 d S )Ng333333?r  )rF   rG   r   rF   rG   r   r   r1  r   sample_weight)rW   r   r   Zrandom_sampler$  rY   rt   r)   r"  r   rs   )r\   Zclass_weightsr   Zsample_weightsZmultiplied_togetherr   r{   r   r   r"   test_weights_multiplied  s    

r6  c                 C   s  t jt j }}t|}t|jd }tjd}|	| || }|| }| ddd dd
||}tj|||dd}t|d	d
d | ddddd
||}tj|||dd}t|d	d
d t|j|jd ||dkd d f }||dk }	t|g|gd  }
t|g|	gd  }| dd dd}|
|
| ||}tj||ddd	k s^J | dddd}|
|
| ||}tj||ddd	ksJ d S )Nr      r   r   F)r^   r   r   rl   Zweightedr   rQ   rF   r   r   r   )r   r   rl   )r   r   r   r   rW   r#  rY   r   r   rl   r)   r   Zf1_scorer   r   r   rs   Zvstackconcatenate)r\   r,   r-   r*  r   rz   f1Zclf_balancedZX_0Zy_0ZX_imbalancedZy_imbalancedy_predr   r   r"   test_balanced_weight  s<    


r;  c                 C   s   t ddgddgddgddgddgg}g d}| ddd	d
}||| t|ddggt dg |j||dgd dgd  d t|ddggt dg d S )Nr   r   r,  rU   rT   r-  r   r   Fr^   r   r   r  rF   rn   rH   rG   r4  rE   r.  r/  r   r   r"   test_sample_weights  s    ( r=  c                 C   sz   | t tfv r| dddd}n| ttfv r6| dddd}tt& |jtt	t
dd W d    n1 sl0    Y  d S )Nr   r   Fr<  )nur   r   r   r4  )r@   rA   r?   rC   r   r   r   r)   r,   rx   rW   r#  r   r   r   r"   test_wrong_sample_weights  s    r?  c                 C   sD   | dd}t t |tt W d    n1 s60    Y  d S )NrV   rr   )r   r   r   r2   r   r   r   r   r   r"   test_partial_fit_exception  s    
r@  c                 C   s   t jd d }| dd}tt}|jt d | td | |d |jjdt jd fks\J |jjdkslJ |ddggjdksJ t	|jj
}|t |d  t|d   t	|jj
}|sJ ||t}t|t d S )Nr   rH   rV   rr   r   rF   r  )r,   rY   rW   r   rx   r2   rs   ru   r3   idr   r   r   r   r   )r\   thirdrz   r   id1id2r:  r   r   r"   test_partial_fit_binary  s    

 
rE  c                 C   s   t jd d }| dd}tt}|jt d | td | |d |jjdt jd fks\J |jjdkslJ |ddggjdksJ t	|jj
}|t |d  t|d   t	|jj
}|sJ |d S )	Nr   rH   rV   rr   r   rF   r   r   )r   rY   rW   r   r   r2   rs   ru   r3   rA  r   )r\   rB  rz   r   rC  rD  r   r   r"   test_partial_fit_multiclass  s    

 rF  c                 C   s   t jd d }| dt jd d}tt}|jt d | td | |d |jjdt jd fksdJ |jjdkstJ |t |d  t|d   |jjdt jd fksJ |jjdksJ d S )Nr   rH   rV   )r^   r   r   rF   r   )r   rY   rW   r   r   r2   rs   ru   )r\   rB  rz   r   r   r   r"   #test_partial_fit_multiclass_average  s    
 rG  c                 C   s"   |  }| tt |tt d S r%   )r)   r   r   r2   r   r   r   r"   test_fit_then_partial_fit%  s    rH  c                 C   s   t ttftttffD ]\}}}| ddd|dd}||| ||}|j}t	
|}| dd|dd}tdD ]}	|j|||d qn||}
|j|ksJ t||
dd qd S )NrV   rG   F)r^   rk   r   rm   rl   r^   rk   rm   rl   r   r   )r,   rx   r   r   r   r   r)   r3   rv   rW   r   r   r2   r   )r\   ry   ZX_r   ZT_rz   r:  tr   rf   y_pred2r   r   r"   "test_partial_fit_equal_fit_classif/  s    


rL  c                 C   s   t jd}| dddd|d}|tt dt |ttkksFJ | dddd|d}|tt dt |ttkksJ | dd	|d
}|tt dt |ttkksJ | dddd|d}|tt dt |ttkksJ d S )NrF   rV   r~   r   epsilon_insensitive)r^   rm   rk   r   r   rU   Zsquared_epsilon_insensitivehuber)r^   r   r   r   )rW   r   r   r)   r,   rx   r  r   )r\   r   rz   r   r   r"   test_regression_lossesB  s>    rO  c                 C   s   t | ttd d S )Nr   )r}   r   r   r   r   r   r"   test_warm_start_multiclassh  s    rP  c                 C   s\   | ddd}| tt t|ds&J dd t tD }| td d d df | d S )NrV   Fr   rs   c                 S   s   g | ]}d dg| qS )ZhamZspamr   )r   rf   r   r   r"   r   u      z%test_multiple_fit.<locals>.<listcomp>rE   )r)   r,   rx   r   r   fit_transform)r\   rz   r-   r   r   r"   test_multiple_fitm  s
    rS  c                 C   sL   | dddd}| ddgddgddggg d |jd |jd ksHJ d S )Nr   rG   Fr<  r   rF   )r   rF   rG   )r)   rs   r   r   r   r"   test_sgd_reg}  s    "rT  c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}t ||}| dd||d	d
d	dd}	|	|| t| ||||\}
}t|	j|
dd t	|	j
|dd d S )Nrn   rV   r   r   r   r   r   r~   TrF   Fr   r   r   )rW   r   r   r   r[   r)   rj   r   rs   r   ru   r\   r]   r^   r   r   r   r,   r   r-   rz   rb   rd   r   r   r"   $test_sgd_averaged_computed_correctly  s,    rV  c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}t ||}| dd||d	d
d	dd}	|	|d t|d  d d  |d t|d   |	|t|d d  d d  |t|d d   t| ||||\}
}t|	j	|
dd t
|	jd |dd d S )Nrn   rV   r   r   r   r   r   r~   TrF   Fr   rG   r   r   )rW   r   r   r   r[   r2   intrj   r   rs   r   ru   rU  r   r   r"   test_sgd_averaged_partial_fit  s.    44rX  c              
   C   s   d}d}| dd||ddddd}t jd	 }|td t|d
  d d  t d t|d
   |tt|d
 d  d d  t t|d
 d   t| tt ||\}}t|j|dd t|j	|dd d S )Nrn   rV   r   r~   TrF   Fr   r   rG   r   r   )
r   rY   r2   r   rW  rj   r   rs   r   ru   )r\   r]   r^   rz   r   rb   rd   r   r   r"   test_average_sparse  s$    
44rY  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| dddd	d
}||| |||}|dksnJ d|  ||d  }| dddd	d
}||| |||}|dksJ d S )Nr   r   r   rF   rI   r   r   r   F)r   r^   r   r   rS   	rW   r   r   Zlinspacer   r   r)   scorerandn	r\   ZxminZxmaxr   r   r,   r-   rz   r]  r   r   r"   test_sgd_least_squares_fit  s    r`  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| dddd	d
d}||| |||}|dkspJ d|  ||d  }| dddd	d
d}||| |||}|dksJ d S )NrZ  r   r   rF   rI   rM  rV   r   r   Fr   epsilonr^   r   r   rS   r\  r_  r   r   r"   test_sgd_epsilon_insensitive  s4    rc  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| ddddd	d
}||| |||}|dkspJ d|  ||d  }| ddddd	d
}||| |||}|dksJ d S )NrZ  r   r   rF   rI   rN  r   r   Fra  rS   r\  r_  r   r   r"   test_sgd_huber_fit  s    rd  c              	   C   s   d\}}t jd}|||}||}t ||}dD ]h}dD ]^}tj||dd}	|	|| | dd||dd	}
|
|| d
||f }t|	j	|
j	d|d qBq:d S )N)r   r   r   )rV   rn   )rI   rO   rU   F)r^   l1_ratior   
elasticnet2   )r   r   r^   re  r   zNcd and sgd did not converge to comparable results for alpha=%f and l1_ratio=%frG   )r   r   )
rW   r   r   r^  r[   r   Z
ElasticNetr)   r   rs   )r\   r   r   r   r,   Zground_truth_coefr-   r^   re  cdZsgdr   r   r   r"   test_elasticnet_convergence6  s0    
ri  c                 C   s   t jd d }| dd}|t d | td |  |jjt jd fksLJ |jjdks\J |ddggjdksvJ t|jj}|t |d  t|d   t|jj}|sJ |d S )Nr   rH   rV   rr   rF   r  )	r,   rY   r2   rx   rs   ru   r   rA  r   )r\   rB  rz   rC  rD  r   r   r"   test_partial_fitX  s    
rj  c                 C   s   | ddd|dd}| tt |t}|j}| dd|dd}tdD ]}|tt qF|t}|j|kspJ t||dd d S )NrV   rG   F)r^   r   rk   rm   rl   rI  r   )	r)   r,   rx   r   r   rv   r   r2   r   )r\   ry   rz   r:  rJ  rf   rK  r   r   r"   test_partial_fit_equal_fiti  s    

rk  c                 C   s0   | dd}|j dd |jd d dks,J d S )NrN   )rb  r   rN  rF   )rw   r  r   r   r   r"   test_loss_function_epsilonz  s    
rl  c                 C   s  |d u rt |jd }n|}t |jd }|}d| }	d}
d}| tkrNd}t|D ]\}}t ||}||	7 }|dkrd}nd}|tdd|| d  9 }||| |  7 }|	|||   | 7 }	||9 }||7 }||d  }|
|9 }
|
|	7 }
|
|d  }
qV|d|
 fS )NrF   rT   rU   rV   rE   r   rG   )rW   rX   rY   rC   rZ   r[   max)r\   r,   r]   r>  rp   r   coefr   offsetrc   rd   re   rf   rg   rh   ri   r   r   r"   asgd_oneclass  s4    rp  c                 C   s   | ddd|d}| | | ddd|d}|j ||j |j d | dddd|d}| | |j|jksrJ t|j|j |jdd	 | | |j|jksJ t|j|j d S )
NrI   rV   F)r>  rk   rl   rm   r   rp   r   T)r>  rk   rl   rq   rm   r>  )r)   rs   rt   r   rv   r   rw   )r\   r,   ry   rz   r{   r|   r   r   r"   _test_warm_start_oneclass  s    


rs  c                 C   s   t | t| d S r%   )rs  r,   r   r   r   r"   test_warm_start_oneclass  s    rt  c                 C   sN   | dd}t |}|jdd |t | dd}|t t|j|j d S )NrI   rr  r   )r   rw   r)   r,   r   rs   r   r   r   r"   test_clone_oneclass  s    



ru  c                 C   s   t jd d }| dd}|t d |  |jjt jd fksBJ |jjdksRJ |ddggjdkslJ |j}|t |d   |j|u sJ tt& |t d d df  W d    n1 s0    Y  d S )Nr   rH   r   rr  rF   r  )	r,   rY   r2   rs   r   r   r   r   r   )r\   rB  rz   Zprevious_coefsr   r   r"   test_partial_fit_oneclass  s    
rv  c           	      C   s   | ddd|dd}| t |t}|j}|j}|j}| ddd|dd}tdD ]}|t qR|t}|j|kszJ t	|| t	|j| t	|j| d S )N皙?rG   rV   F)r>  r   rk   rm   rl   rF   )r>  rk   r   rm   rl   )
r)   r,   r3   r   rv   rs   r   r   r2   r   )	r\   ry   rz   Zy_scoresrJ  rn  ro  r   Z	y_scores2r   r   r"   #test_partial_fit_equal_fit_oneclass  s    



rx  c                 C   s   d}d}| dd||ddd}| dd||ddd}| t | t t| t|||j |jd	\}}t|j |  t|j| d S )
Nrn   rw  r   r~   rG   F)r   rm   rk   r>  r   rl   rF   rq  )r)   r,   rp  rs   r   r   r   )r\   rk   r>  r   r{   r   average_offsetr   r   r"   *test_late_onset_averaging_reached_oneclass   s(    	


rz  c           
   	   C   sz   d}d}d}d}t jd}|j||fd}| d||dd	dd
d}|| t| |||\}}	t|j| t|j|	 d S )Nrn   rw  r   r   r   r   r~   TrF   Frm   rk   r>  r   r   r   rl   )	rW   r   r   r   r)   rp  r   rs   r   
r\   r]   r>  r   r   r   r,   rz   r   ry  r   r   r"   -test_sgd_averaged_computed_correctly_oneclass!  s&    

r}  c           
   	   C   s   d}d}d}d}t jd}|j||fd}| d||dd	dd
d}||d t|d  d d   ||t|d d  d d   t| |||\}}	t|j| t|j	|	 d S )Nrn   rw  r   r   r   r   r~   TrF   Fr{  rG   )
rW   r   r   r   r2   rW  rp  r   rs   r   r|  r   r   r"   &test_sgd_averaged_partial_fit_oneclass<  s(    
""r~  c              	   C   s   d}d}| d||ddddd}t jd }|t d t|d	   |t t|d	 d   t| t ||\}}t|j| t|j| d S )
Nrn   rV   r~   TrF   Fr{  r   rG   )r   rY   r2   rW  rp  r   rs   r   )r\   r]   r>  rz   r   r   ry  r   r   r"   test_average_sparse_oneclassX  s"    

r  c                  C   s   t ddgddgddgg} t ddgddgg}tdddddd}||  t|jt d	d
g |jd dksvJ ||}t|t ddg |||j }t||| |	|}t
|t ddg d S )NrD   rE   rF   rI   rG   r~   F)r>  rk   rm   rl   r   g      g      ?r   rJ   g      g      ?)rW   r   r?   r)   r   rs   r   Zscore_samplesr3   r   r   )X_trainX_testrz   Zscoresdecr   r   r   r"   test_sgd_oneclassq  s    



r  c                  C   s.  d} d}d}t j|}d|dd }t j|d |d f }d|dd }t j|d |d f }t|d| d	}|| ||}||	d
d}	d}
t
||d}t| dd|
|d d}t||}|| ||}||	d
d}t ||kdksJ t t |	|fd }|dks*J d S )Nrw  r   r   r    rG   r   Zrbf)gammaZkernelr>  rF   rE      )r  r   T)r>  rl   r   r   r   r   rS   r  rN   )rW   r   r   r^  Zr_r   r)   r   r3   r   r	   r?   r   r  corrcoefr8  )r>  r  r   r   r,   r  r  rz   Zy_pred_ocsvmZ	dec_ocsvmr   Z	transformZclf_sgdZpipe_sgdZy_pred_sgdocsvmZdec_sgdocsvmr  r   r   r"   test_ocsvm_vs_sgdocsvm  s:    




r  c                  C   s   t jddddd\} }tddd dd	d
d| |}tdddd
d d| |}t|j|j tddd ddd
d| |}tdddd
d d| |}t|j|j d S )Nr   r   r   i  )r   r   Zn_informativer   rn   rf  r7  gA?r   )r^   r   r   r   re  r   r   )r^   r   r   r   r   g|=r   )r   make_classificationr@   r)   r   rs   )r,   r-   Zest_enZest_l1Zest_l2r   r   r"   test_l1_ratio  sF    


r  c            	   	   C   sJ  t jdd& t jd} d}d}| j||fd}|d d d df  d9  < t | sbJ t |}t | sJ | j|d}t 	||d	k
t j}tt |dd
g tdddd}||| t |j sJ d}tjt|d ||| W d    n1 s0    Y  W d    n1 s<0    Y  d S )Nraiser  r   r   r   r   rG   gu <7~rT   rF   r   r   r  )r^   r   r   zwFloating-point under-/overflow occurred at epoch #.* Scaling input data with StandardScaler or MinMaxScaler might help.r   )rW   r  r   r   r   isfiniter  r   rR  r[   astypeZint32r   r   r@   r)   rs   r   r   r   )	r   r   r   r,   ZX_scaledZground_truthr-   modelZ	msg_regxpr   r   r"   test_underflow_or_overlow  s&    r  c                  C   sn   t ddddddddd d		} tjd
d  | tjtj W d    n1 sL0    Y  t| j	 sjJ d S )Nr   r   Trf  r  rV   rn   r   )	r   r   rl   r   re  r^   rk   r   r   r  r  )
r@   rW   r  r)   r   r   r   r  rs   r  )r  r   r   r"   'test_numerical_stability_large_gradient  s    .r  r   )r   r   rf  c              	   C   sj   t ddd| dd dd}tjdd  |tjtj W d    n1 sH0    Y  t|jt	|j d S )	Ng     j@r~   r   Fr7  )r^   rm   rk   r   rl   r   r   r  r  )
r@   rW   r  r)   r   r   r   r   rs   
zeros_like)r   r  r   r   r"   test_large_regularization  s    	.r  c                  C   s  t  tj} tjdk}d}td d|d}|| | ||jksDJ d}tdd|d}|| | ||jkspJ |jdks~J tdd|d}|| | |j|jksJ |jdksJ tdd	dd
}d}tj	t
|d || | W d    n1 s0    Y  |jdksJ d S )NrF   r   r   )r   r   r   r  r   r   rH   rn   )r   r   r   zhMaximum number of iteration reached before convergence. Consider increasing max_iter to improve the fit.r   )r   rR  r   r   r   r@   r)   r   r   warnsr   )r,   r-   r   Zmodel_0Zmodel_1Zmodel_2Zmodel_3Zwarning_messager   r   r"   test_tol_parameter  s*    
*r  c                 C   s:   |D ]0\}}}}t | ||| t | ||| qd S r%   )r   Zpy_lossZpy_dloss)Zloss_functioncasesrh   r-   Zexpected_lossZexpected_dlossr   r   r"   _test_loss_commonB  s    r  c                  C   s<   t d} g d}t| | t d} g d}t| | d S )NrU   ))g?rU   rT   rT          r   rT   rT   )rU   rU   rT   r   )r   r   rT   rU   )rI   rU   rI   r   )r   r         @rU   )rJ   r   rI   rU   )rT   rU   rF   r   rT   )rU   rU   rT   rT   )r  r   rT   rT   )rT   rU   rT   r   )rT   r   rT   rU   )rI   r   rI   rU   )r   r   r   rU   )rJ   rU   rI   r   )r   rU   rU   r   )sgd_fastZHinger  r   r  r   r   r"   test_loss_hingeJ  s    


r  c                  C   s    t d} g d}t| | d S )NrU   )r  r  )rU   r         @r  r   rU   r        )rI   rU   g      ?r   rI   r   g      @r  )r  ZSquaredHinger  r  r   r   r"   test_gradient_squared_hingek  s    
	r  c                  C   s   t  } g d}t| | d S )N)r  )r   r   rT   rT   )r   rU   rT   rT   )rT   rU   rU   r  r  r  )r  rU      r  )g      rU      r  )r  ZModifiedHuberr  r  r   r   r"   test_loss_modified_huberz  s    r  c                  C   s    t d} g d}t| | d S )Nr   )rT   rT   rT   rT   r   rT   rT   rT   gffffff r  rT   rT   gffffff@r  rT   rT   )皙@r   r   rU   )r   r   333333@rU   )r   r  r   r   )r  rU   r  r   )r  ZEpsilonInsensitiver  r  r   r   r"   test_loss_epsilon_insensitive  s    
r  c                  C   s    t d} g d}t| | d S )Nr   )r  r  r  r  )r  r   rV   r  )r   r   R @g333333@)r   r  rV   gɿ)r  rU   r  g333333)r  ZSquaredEpsilonInsensitiver  r  r   r   r"   %test_loss_squared_epsilon_insensitive  s    
r  c               	   C   sf   t dddddddd} | tjtj | j| jks6J | j| jd k sJJ | tjtjd	ksbJ d S )
Nrn   r   Tr   r   rG   )r^   r   r   r   r   r   r  r   rO   )r@   r)   r   r   r   r   r   r]  )rz   r   r   r"   0test_multi_thread_multi_class_and_early_stopping  s    	r  c                  C   s\   t dddg dd} tdddd	d
}t|| ddd	d}|tjtj |jdksXJ d S )Nr   r   )r   r   rg  )r^   r   rV   r   Tr   )r   r   r   r   r   rG   )Zn_iterr  r   rO   )	rW   Zlogspacer@   r   r)   r   r   r   Zbest_score_)Z
param_gridrz   searchr   r   r"   -test_multi_core_gridsearch_and_early_stopping  s    r  backend)Zlokymultiprocessing	threadingc                 C   s   t jd}tjdddd|d}|dd}tdd	dd
}||| tdddd
}tj| d ||| W d    n1 s0    Y  t	|j
|j
 d S )Nr   r  r  g{Gz?Zcsr)Zdensityr	  r   r   r   rF   )r   r  r   r   )r  )rW   r   r   r&   choicer@   r)   joblibZparallel_backendr   rs   )r  r   r,   r-   Zclf_sequentialZclf_parallelr   r   r"   'test_SGDClassifier_fit_for_all_backends  s    *r  	Estimatorc                 C   sT  | t jkrtj|d\}}ntj|d\}}| |dd}tt, |||j	}|j
dks`J W d    n1 st0    Y  | |dd}tt, |||j	}|j
dksJ W d    n1 s0    Y  t|| | |d dd}tt. |||j	}|j
dksJ W d    n1 s*0    Y  t||  dksPJ d S )N)r   rF   )r   r   rU   )r   r;   r   Zmake_regressionr  r   r  r   r)   rs   r   r   rW   absrm  )r  Zglobal_random_seedr,   r-   estZcoef_same_seed_aZcoef_same_seed_bZcoef_other_seedr   r   r"   test_sgd_random_state  s"    
,,
0r  c           	      C   s   t jt j }}|jd }d}tjddd|d}ttjd}| 	td| |
|| |jd d	d
 \}}|jd t|| ksJ |jd t|| ksJ dS )ziTest that data passed to validation callback correctly subsets.

    Non-regression test for #23255.
    r   r  Trn   r   )r   r   r   r   )Zside_effect_ValidationScoreCallbackrF   rH   N)r   r   r   rY   r   r@   r   r   r  setattrr)   Z	call_argsrW  )	Zmonkeypatchr,   rx   r   r   rz   ZmockZX_valZy_valr   r   r"   &test_validation_mask_correctly_subsets  s    
r  c                  C   sr   t jt j } }t|}d}tjd|dd}d}tjt	|d  |j
| ||d W d    n1 sd0    Y  d S )Nr   Tr   )r   r   r   z\The sample weights for validation set are all zero, consider using a different random state.r   r4  )r   r   r   rW   r  r   r@   r   r   r   r)   )r,   rx   r5  r   rz   error_messager   r   r"   (test_sgd_error_on_zero_validation_weight0  s    
r  c                 C   s   | dd tt dS )z!non-regression test for gh #25249rF   )verboseN)r)   r,   rx   )r  r   r   r"   test_sgd_verboseC  s    r  SGDEstimator	data_typec                 C   s>   t |}tjt|d}|  }||| |jj|ks:J d S )Ndtype)r,   r  rW   r   rx   r)   rs   r  )r  r  Z_XZ_YZ	sgd_modelr   r   r"   test_sgd_dtype_matchI  s
    
r  c                 C   sz   t jtjd}tjttjd}t jtjd}tjttjd}| dd}||| | dd}||| t|j	|j	 d S )Nr  r   )r   )
r,   r  rW   float64r   rx   float32r)   r   rs   )r  ZX_64ZY_64ZX_32ZY_32Zsgd_64Zsgd_32r   r   r"   test_sgd_numerical_consistency]  s    

r  c                 C   sH   | dd}t jtdd |tt W d    n1 s:0    Y  d S )Nr   r   z	average=0r   )r   r  FutureWarningr)   r,   rx   )r  r  r   r   r"   *test_passive_aggressive_deprecated_averagey  s    
r  c                  C   s   t  } t| jdksJ dS )z}Check that SGDOneClassSVM has the correct estimator type.

    Non-regression test for if the mixin was not on the left.
    Zoutlier_detectorN)r?   r   Zestimator_type)Z	sgd_ocsvmr   r   r"   %test_sgd_one_class_svm_estimator_type  s    r  )NrT   )NrT   )r&  Zunittest.mockr   r  numpyrW   r   Zscipy.sparsesparser&   Zsklearnr   r   r   Zsklearn.baser   r   Zsklearn.exceptionsr   Zsklearn.kernel_approximationr	   Zsklearn.linear_modelr
   r  r   Zsklearn.model_selectionr   r   r   Zsklearn.pipeliner   Zsklearn.preprocessingr   r   r   r   Zsklearn.svmr   Zsklearn.utilsr   Zsklearn.utils._testingr   r   r   r   r#   r@   r$   r;   r:   r?   r>   rA   rB   rC   r   r,   rx   r   r   r   r   r   r   r   r   r"  r$  Z	load_irisr   r   r   Ztrue_result5rj   r}   markZparametrizer   r   r   r   r   r   r   r   r   r   r   r   r   rX   r   r   r   r   r   r   r   r   r   r  r  r  r  r  r+  r0  r2  r3  r6  r;  r=  r?  r@  rE  rF  rG  rH  rL  rO  rP  rS  rT  rV  rX  rY  r`  rc  rd  ri  rj  rk  rl  rp  rs  rt  ru  rv  rx  rz  r}  r~  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r   r   r   r"   <module>   s|  
..	"




+


)



	

	



	

!
	










!
G
$




.






	
%



 
 


#

!

&



 


)#&
$!
#
#
