a
    hd                     @   s  d dl Z d dlZd dlZd dlmZmZmZ d dlm	Z	m
Z
mZmZ d dlmZmZmZmZ d dlmZmZ d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZ dd Zdd Z dd Z!dd Z"dd Z#dd Z$dd Z%ej&'de
eefdd Z(ej&'deee	fdd Z)ej&'deee	e
fdd Z*d d! Z+ej&'de	eee
fej&'d"e+ d#d$ Z,ej&'d%e
eee	fd&d' Z-d(d) Z.ej&'d*d+d,gd-d. Z/d/d0 Z0d1d2 Z1d3d4 Z2d5d6 Z3ej&'d7eee	gd8d9 Z4ej&'d:d;d<gej&'d7eee	gd=d> Z5ej&'d?e	e
eegd@dA Z6ej&'d?e	e
eegdBdC Z7dDdE Z8dFdG Z9ej&'d?ee	e
egdHdI Z:ej&'d?ee	e
egdJdK Z;ej&'d?ee	egdLdM Z<dS )N    N)assert_allcloseassert_array_almost_equalassert_array_equal)CCAPLSSVDPLSCanonicalPLSRegression)_center_scale_xy(_get_first_singular_vectors_power_method_get_first_singular_vectors_svd_svd_flip_1d)load_linnerudmake_regression)VotingRegressor)ConvergenceWarning)LinearRegression)check_random_state)svd_flipc                 C   s(   t | j| }t|t t | d S )N)npdotTr   Zdiag)MK r   f/var/www/html/assistant/venv/lib/python3.9/site-packages/sklearn/cross_decomposition/tests/test_pls.pyassert_matrix_orthogonal   s    r   c                  C   s(  t  } | j}| j}t|jd d}||| t|j t|j t|j	 t|j
 |j	}|j}|j
}|j}t| | dd\}}	}
}}}t|t||j t|	t||j ||}t||j	 |||\}}t||j	 t||j
 ||}t|| |||\}}t|| d S )N   n_componentsTscale)r   datatargetr   shapefitr   
x_weights_
y_weights_	_x_scores	_y_scoresx_loadings_y_loadings_r	   copyr   r   r   r   	transforminverse_transform)dXYplsr   PUQZXcZYcZx_meany_meanZx_stdZy_stdZXtZYtZX_back_ZY_backr   r   r   test_pls_canonical_basics   s6    






r7   c                  C   sf  t  } | j}| j}t|jd d}|||\}}t||j t	g dg dg dg}t	g dg dg dg}t	g d	g d
g dg}t	g d	g d
g dg}	t
t|jt| t
t|jt| t
t|jt|	 t
t|jt| t|j| }
t|j| }t|j| }t|j|	 }t
|
| t
|| d S )Nr   r   ),6gbx+rNF?);0g&Կf_)@mпg<-bL?ȣȿ)gHgtϿr9   )gE` gt[Wm¿r;   )ggLM3?r=   )g?g+E!?g4Ӝ@?)gsYO)?g`{?gA'?)g;Ծgпgſ)r   r!   r"   r   r#   fit_transformr   Z	x_scores_r   arrayr   absr)   r%   r*   r&   sign)r.   r/   r0   r1   X_transr6   expected_x_weightsexpected_x_loadingsexpected_y_weightsexpected_y_loadingsx_loadings_sign_flipx_weights_sign_flipy_weights_sign_flipy_loadings_sign_flipr   r   r    test_sanity_check_pls_regressionC   sP    
rK   c            
      C   sR  t  } | j}| j}d|d d df< t|jd d}||| tg dg dg dg}tg dg dg d	g}tg d
g dg dg}tt	|t	|j
 tt	|t	|j tt	|jt	| tt	|jt	| t||j }t||j
 }t|dd  |jdd   }	t|| t|dd  |	 d S )Nr   r   r   )g͝Og(}?:F?)gqgqdvgѿ|N<)g1, ˿g7Ƚ?\ƿ)gCgBg<&.̿rL   )gBg5_/ErM   )gQggr9?rN   )        rO   rO   )g ?gXZ?ghC%d?)gVSg{sɂϿg$(E,ǿ)r   r!   r"   r   r#   r$   r   r?   r   r@   r%   r)   r*   r&   rA   r   )
r.   r/   r0   r1   rC   rD   rF   rG   rH   rJ   r   r   r   2test_sanity_check_pls_regression_constant_column_Y   sB     
rP   c                  C   s~  t  } | j}| j}t|jd d}||| tg dg dg dg}tg dg dg dg}tg d	g d
g dg}tg dg dg dg}tt	|j
t	| tt	|jt	| tt	|jt	| tt	|jt	| t|j
| }t|j| }	t|j| }
t|j| }t||	 t|
| t|j t|j t|j t|j d S )Nr   r   )r8   g{cd?gr	)r:   g?g>c?)r<   gP,"Pgͺ@)r8   gCj?g#i)r:   g2Щ?gr?)r<   go _g<:ο)c?gD}Ȇ??g5?)UҮ?gOgөeJo?).a#οgbM4gYV?)rQ   gͱ?g[K?)rR   g=mBgo1S?)rS   gP.%lgq!?)r   r!   r"   r   r#   r$   r   r?   r   r@   x_rotations_r%   Zy_rotations_r&   rA   r   r'   r(   )r.   r/   r0   r1   rC   Zexpected_x_rotationsrE   Zexpected_y_rotationsZx_rotations_sign_fliprH   Zy_rotations_sign_fliprI   r   r   r   test_sanity_check_pls_canonical   sV    




rU   c                  C   s  d} d}d}t d}|j| d}|j| d}t||||gj}||jd|  d| df }||jd|  d| df }tj||j||  d| |fdd}tj||j||  d| |fdd}td	d
}	|	|| tg dg dg dg dg dg dg dg dg dg dg dg dg dg dg}
tg dg dg dg dg dg dg dg d g d!g d"g d#g d$g d%g d&g}tg d'g d(g d)g d*g d+g d,g d-g d.g d/g	}tg d0g d1g d2g d3g d4g d5g d6g d7g d8g	}t	t
|	jt
| t	t
|	jt
|
 t	t
|	jt
| t	t
|	jt
| t|	j| }t|	j|
 }t|	j| }t|	j| }t	|| t	|| t|	j t|	j t|	j t|	j d S )9N  
         size   r   Zaxis   r   )gqAS?ģƒ?g	K?)g܈m?gr[q?g֎ ÿ)gոqjP?gͱgS?)g$$?g('G_g.k^)g~gsg
?)gjh?gfrg>uRz?)g$¯&?g		lgpO/?)g}W[g~glìǿ)gWX>egj8H@Zg˔Br?)g).egw4DgoP^?)gvzgqg1GZg}r5.?)gzϳJg1?g?)gMI?g,)Ɣg [u)gei?g⊬[gQ>Oƿ)gmƫ?gE^?g0?)gB+
?g,?gi)gi*?g_(gb#k4?)g*Vh{O?g׍o}sg+Kl)gݩFgJ,c")g#'v?)g,a?g9qbgSLRW?)g]@[?gO~gkE?)ggpBgR ?g;ȿ)g-ݿp?g;O<gxgGtK?)g$U\ngE	g?)g@~_V?g,8(g.^?)g)^D_jg2i?gs6Cm?)g.f2?ggRu)g9Me?gX㰿g <ۿ)gONz '?gsVF?gul-a7?)g0?g]4?goБο)g:8%?g!Hgl?)g5z?gy0/gofy&,)g
C?g͢A}?g_%_?)g޵?gHֆ/gL:ܿ)gcIȂg$E!?gfD¹?)gϫg35ϧ?g`"ĕs?)g$t?gWe?gY)ݟ?)g#bJ$?gtdn?gx/RѸ?)g
Jn?g׷?gʏSϽ)g= ?gNfg7jN?)g]w?g(.g^i׌%)g}P
?gO3IogD'?)ghE-(g?g0֢p?gn)gzgIT4g7Ʃ1|?)gO)֠gkp2F$?gA-c?)g&?gkh?g!L?)r   normalr   r?   r   reshapeZconcatenater   r$   r   r@   r)   r%   r*   r&   rA   r   r'   r(   )nZp_noiseZq_noiserngl1l2Zlatentsr/   r0   r1   rC   rD   rE   rF   rG   rH   rI   rJ   r   r   r   &test_sanity_check_pls_canonical_random   s    &&





re   c                  C   s^   t  } | j}| j}t|jd dd}tt ||| W d    n1 sP0    Y  d S )Nr      r   Zmax_iter)	r   r!   r"   r   r#   pytestwarnsr   r$   )r.   r/   r0   Z
pls_nipalsr   r   r   test_convergence_failZ  s    rj   Estc                    sR   t  }|j}|j}d |  d}||| t fdd|j|jfD sNJ d S )Nrf   r   c                 3   s   | ]}|j d   kV  qdS )r   N)r#   ).0attrr   r   r   	<genexpr>m  s   z(test_attibutes_shapes.<locals>.<genexpr>)r   r!   r"   r$   allr%   r&   )rk   r.   r/   r0   r1   r   r   r   test_attibutes_shapesd  s    

rp   c                 C   sr   t  }|j}|j}| dd}|||d d df j}|||d d d df j}|j|jksdJ t|| d S )Nr   r   r   )r   r!   r"   r$   coef_r#   r   )rk   r.   r/   r0   estZone_d_coeffZtwo_d_coeffr   r   r   test_univariate_equivalencer  s    
rs   c                 C   st  t  }|j}|j}| }| dd||}t|| tt, | dd|| t	|| W d    n1 sr0    Y  | t
u rd S | }tt, |j||ddf t	|| W d    n1 s0    Y  | }tt* |j|ddf t	|| W d    n1 s0    Y  t	|j||dd|j| | dd t	|j|dd|j| dd d S )NTr+   F)r   r!   r"   r+   r$   r   rh   raisesAssertionErrorr   r   r,   predict)rk   r.   r/   r0   ZX_origr1   r   r   r   	test_copy  s2    
((*$rx   c            	      c   s  t jd} d}d}d}| ||}| ||}t ||d| ||  d }|d9 }||fV  tdd\}}d	|d
d
df< ||fV  t g dg dg dg dg}t ddgddgddgddgg}||fV  ddg}|D ]2}t j|} | dd}| dd}||fV  qd
S )z-Generate dataset for test_scale_and_stabilityr   i  rX   rW   rf   r   T
return_X_y      ?N)rO   rO   r{   )r{   rO   rO   )       @r}   r}   )g      @g      @g      @g?gɿg?g?g@g@g'@g(@i  i  r\   r^   )r   randomRandomStaterandnr   r   r?   )	rb   	n_samples	n_targets
n_featuresr4   r0   r/   Zseedsseedr   r   r   +_generate_test_scale_and_stability_datasets  s*     

""
r   zX, Yc           
      C   s\   t ||^}}}| dd||\}}| dd||\}}	t||dd t|	|dd dS )zscale=True is equivalent to scale=False on centered/scaled data
    This allows to check numerical stability over platforms as wellTr   Fg-C6?ZatolN)r	   r>   r   )
rk   r/   r0   ZX_sZY_sr6   ZX_scoreZY_scoreZ	X_s_scoreZ	Y_s_scorer   r   r   test_scale_and_stability  s
    r   	Estimatorc                 C   sp   t jd}|dd}|dd}| dd}d}tjt|d ||| W d   n1 sb0    Y  dS )	zICheck the validation of `n_components` upper bounds for `PLS` regressors.r   rW   rX   r^   r   zH`n_components` upper bound is .*. Got 10 instead. Reduce `n_components`.matchN)r   r~   r   r   rh   ru   
ValueErrorr$   )r   rb   r/   r0   rr   err_msgr   r   r   test_n_components_upper_bounds  s    
r   c                  C   sp   t jd} | dd}| dd}tdd}d}tjt|d ||| W d	   n1 sb0    Y  d	S )
zFCheck the validation of `n_components` upper bounds for PLSRegression.r      @   r^      r   zH`n_components` upper bound is 20. Got 30 instead. Reduce `n_components`.r   N)	r   r~   r   r   r   rh   ru   r   r$   )rb   r/   r0   rr   r   r   r   r   %test_n_components_upper_PLSRegression  s    
r   zn_samples, n_features)d   rW   )r      c                 C   s~   t | |d|d\}}t||dd\}}}t||\}}	t|| t||	 d}
t||| |
 d t||	|	 |
 d d S )NrX   r   random_stateT)Znorm_y_weightsMbP?r   )r   r
   r   r   r   max)r   r   global_random_seedr/   r0   u1Zv1r6   u2Zv2rtolr   r   r   test_singular_value_helpers  s    


r   c                 C   s   t ddd| d\}}tdd|||}tdd|||}tdd|||}d}t||| | d t||| | d d S )	Nr   rW   rX   r   r   r   r   r   )r   r   r$   r,   r   r   r   r   )r   r/   r0   Zsvdreg	canonicalr   r   r   r   test_one_component_equivalence   s    r   c                  C   s   t g d} t g d}t| dd|dd\}}t| | t| |  t| g d t||  t|g d d S )N)r   rf   )r   rf   r^   r|   r   )r|   r\   )r|   r   )r   r?   r   r`   r   r   Zravel)uvZ
u_expectedZ
v_expectedr   r   r   test_svd_flip_1d  s    
r   c                 C   s~   t ddd| d\}}tddd}t ( tdt ||| W d   n1 sV0    Y  tt	|j
d	k szJ dS )
z8Test that CCA converges. Non-regression test for #19549.r   r   )r   r   r   r   rW   rV   rg   errorNr   )r   r   warningscatch_warningssimplefilterr   r$   r   ro   r@   r)   )r   r/   yZccar   r   r   test_loadings_converges  s    

*r   c                  C   sv   t jd} | dd}t d}t }d}tjt|d |	|| W d   n1 s\0    Y  t
|jd dS )zAChecks warning when y is constant. Non-regression test for #19831*   r   r^   z#y residual is constant at iterationr   Nr   )r   r~   r   ZrandZzerosr   rh   ri   UserWarningr$   r   rT   )rb   xr   r1   msgr   r   r   test_pls_constant_y.  s    
*r   PLSEstimatorc                 C   sR   t  }|j}|j}| dd||}|jd |jd  }}|jj||fksNJ dS )zCheck the shape of `coef_` attribute.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/12410
    Trt   r   N)r   r!   r"   r$   r#   rq   )r   r.   r/   r0   r1   r   r   r   r   r   test_pls_coef_shape=  s    r   r    TFc           	      C   sx   t  }|j}|j}| d|d||}|j|dd}|jdd}||jdd }t|j| t|||jj	 |j  dS )z/Check the behaviour of the prediction function.T)r+   r    rt   r   r]   N)
r   r!   r"   r$   rw   meanr   Z
intercept_rq   r   )	r   r    r.   r/   r0   r1   ZY_predr5   rB   r   r   r   test_pls_predictionN  s    r   Klassc                    sd   t dd\}}|  ||}| }| j  tj fddt|jj	d D t
d}t|| dS )z9Check `get_feature_names_out` cross_decomposition module.Try   c                    s   g | ]}  | qS r   r   )rl   iZclass_name_lowerr   r   
<listcomp>j      z.test_pls_feature_names_out.<locals>.<listcomp>r   )ZdtypeN)r   r$   get_feature_names_out__name__lowerr   r?   ranger%   r#   objectr   )r   r/   r0   rr   Z	names_outZexpected_names_outr   r   r   test_pls_feature_names_out`  s    
r   c                 C   st   t d}tddd\}}|  jdd||}|||\}}t|tjsPJ t||j	s`J t
|j|  dS )z1Check `set_output` in cross_decomposition module.ZpandasT)rz   Zas_frame)r,   N)rh   Zimportorskipr   Z
set_outputr$   r,   
isinstancer   ZndarrayZ	DataFramer   columnsr   )r   pdr/   r0   rr   rB   Zy_transr   r   r   test_pls_set_outputp  s    
r   c               	   C   s   t ddgddgddgddgddgd	d
gg} t g d}| }t | |}|| }|j|jkslJ t | |}td|fd|fg}|| || }|j|jksJ t	|| dS )zrCheck that when fitting with 1d `y`, prediction should also be 1d.

    Non-regression test for Issue #26549.
    r   rf   r\   r^   	      rX         $   )rf   r      r   r   r   lrplsrN)
r   r?   r+   r   r$   rw   r#   r   r   r   )r/   r   expectedr   Zy_predr   Zvrr   r   r   test_pls_regression_fit_1d_y}  s    .
r   c                  C   sd   t jd} | jdd}| jddd}||j }tddd	||}t|j	| t|
|| d
S )zCheck that when using `scale=True`, the coefficients are using the std. dev. from
    both `X` and `Y`.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/27964
    r   )r^   rX   rZ   rW   )r   rX   )r    r[   rX   T)r   r    N)r   r~   r   uniformr_   r   r   r$   r   rq   rw   )rb   Zcoefr/   r0   r1   r   r   r    test_pls_regression_scaling_coef  s    
r   c              	   C   s  t  }|j}|j}|j}d}tjt|d  |  j||d W d    n1 sP0    Y  d}tjt|dN tjt|d  |  ||| W d    n1 s0    Y  W d    n1 s0    Y  d}tjt|d |  | W d    n1 s0    Y  d S )NE`Y` is deprecated in 1.5 and will be removed in 1.7. Use `y` instead.r   r/   r0   ?Cannot use both `y` and `Y`. Use only `y` as `Y` is deprecated.zy is required.)	r   r!   r"   rh   ri   FutureWarningr$   ru   r   )r   r.   r/   r0   r   r   err_msg1Zerr_msg2r   r   r   -test_pls_fit_warning_on_deprecated_Y_argument  s     .Lr   c              	   C   s   t  }|j}|j}|j}|  ||}d}tjt|d |j||d W d    n1 s\0    Y  d}tjt|dL tjt	|d |||| W d    n1 s0    Y  W d    n1 s0    Y  d S )Nr   r   r   r   )
r   r!   r"   r$   rh   ri   r   r,   ru   r   )r   r.   r/   r0   r   r   r   r   r   r   r   3test_pls_transform_warning_on_deprecated_Y_argument  s    ,r   c           	   	   C   s   t  }|j}|j}|  ||}|||\}}d}tjt|d |j||d W d    n1 sf0    Y  d}tjt|dN tj	t
|d  |j|||d W d    n1 s0    Y  W d    n1 s0    Y  d S )Nr   r   r   r   )r/   r   r0   )r   r!   r"   r$   r,   rh   ri   r   r-   ru   r   )	r   r.   r/   r   r   ZX_transformedZy_transformedr   r   r   r   r   ;test_pls_inverse_transform_warning_on_deprecated_Y_argument  s    ,r   )=r   numpyr   rh   Znumpy.testingr   r   r   Zsklearn.cross_decompositionr   r   r   r   Z sklearn.cross_decomposition._plsr	   r
   r   r   Zsklearn.datasetsr   r   Zsklearn.ensembler   Zsklearn.exceptionsr   Zsklearn.linear_modelr   Zsklearn.utilsr   Zsklearn.utils.extmathr   r   r7   rK   rP   rU   re   rj   markZparametrizerp   rs   rx   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   <module>   sl   (?2>h



& 






