a
    h                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlmZ d dl	m
Z
mZ d dlmZ d dlmZmZmZ d dlmZ d dlmZmZ d dlmZmZmZ d d	lmZ d d
lmZ d dlm Z m!Z! d dl"m#Z# d dl$m%Z%m&Z& e Z'g dZ(d\Z)Z*e+e)e*Z,dddZ-ej./de(ej./de0de'j1j2d dd Z3ej./dg dej./dg dej./de&e% ej./dddgej./dg d d!d" Z4ej./de&e% d#d$ Z5ej./dd%d&gej./de&e% d'd( Z6ej./de&e% d)d* Z7d+d, Z8ej./d-d.d/gej./d0e(d1d2 Z9ej./d3e:e;e<e(d&d4h ej./d5d6d7gej./d8d/d.gej./d9d/d.gd:d; Z=ej.j/d<ej>?d @d=d>ejd=d>d?d d@d  ej>?d @dAd=gg dBdCej./de(dDdE ZAej./ddd%gdFdG ZBej./de(dHdI ZCej./de(dJdK ZDej./de(dLdM ZEej./dg dNej./d9d/d.gdOdP ZFej./dQeGg dRg dSgeGg dRg dSgjHgej./dTg dUdVdW ZIej./dXd&e+e'j1j2fde+e'j1j2d fd%e+e'j1j2fgej./dQe'j1e'j1jHgdYdZ ZJej./dd4d&gd[d\ ZKej./ddd%gd]d^ ZLd_d` ZMdadb ZNdcdd ZOdedf ZPej./dge'j1dhdife'j1djdfej>?d Qdkdldmdifgdndo ZRej./de(dpdq ZSdrds ZTej./de(dtdu ZUej./ddd%gdvdw ZVej./dd&d%gdxdy ZWej./dzg d{d|d} ZXej./de(d~d ZYej./de(dd ZZdd Z[dd Z\dd Z]dd Z^dd Z_dd Z`dd Zadd Zbdd Zcdd Zddd Zeej./d-d.d/gdd Zfdd Zgej./de ej.j/de#eggedCej.j/dedid&dedid&d.dedd&d.dediddedidd.dedid%dd dgedCdd Zhej./de ej.j/deggedCej.j/dedd&dgedCdd Zidd ZjdS )    N)assert_array_equal)config_contextdatasets)clone)	load_irismake_classificationmake_low_rank_matrix)PCA)_assess_dimension_infer_dimension)_atol_for_type_convert_to_numpy)yield_namespace_device_dtype_combinationsdevice)_get_check_estimator_ids)_array_api_for_testsassert_allclose) check_array_api_input_and_values)CSC_CONTAINERSCSR_CONTAINERS)fullcovariance_eigharpack
randomizedauto)  i,  Hz>-q=c                 C   s   t | j|j||d t | j|j||d t | j|j||d t | j|j||d t | j|j||d | j|jkstJ | j|jksJ | j|jksJ d S )Nrtolatol)	r   components_explained_variance_singular_values_Zmean_noise_variance_n_components_Z
n_samples_Zn_features_in_)pca1pca2r    r!    r)   `/var/www/html/assistant/venv/lib/python3.9/site-packages/sklearn/decomposition/tests/test_pca.py_check_fitted_pca_close%   s    r+   
svd_solvern_components   c                 C   s   t j}t|| d}|||}|jd |ks4J ||}t|| ||}t|| | }|	 }tt
||t
|jd dd d S )Nr-   r,   r.   r   r!   )irisdatar	   fit	transformshapefit_transformr   get_covarianceget_precisionnpdoteye)r,   r-   XpcaZX_rZX_r2cov	precisionr)   r)   r*   test_pca3   s    



r@   density){Gz?皙?g333333?)r.      
   sparse_containerr   r   scale)r.   rE   d   c                 C   s   d}d}t j| }|tjjtt||d}	||	jd | }
|	|
}	t	||| d}|
|	 |	 }t	||| d}|
| t|||d |tjjtt||d}| }t|||||d t|||||d dS )z?Check that the results are the same for sparse and dense input.r   绽|=random_staterA   r.   r-   r,   rK   r0   N)r9   randomdefault_rngspsparseSPARSE_MSPARSE_Nr5   multiplyr	   r3   Ztoarrayr+   r   r4   )global_random_seedr,   rF   r-   rA   rG   r!   Ztransform_atolrK   r<   Zscale_vectorr=   ZXdZpcadX2ZX2dr)   r)   r*   test_pca_sparseI   sL    	


rV   c                 C   s   t j| }|tjjtt|dd}|tjjtt|dd}tdd| d}tdd| d}|| |	|}t
|| t||| t||| t|||| d S )NrB   rJ   rE   r   rL   )r9   rM   rN   rO   rP   rQ   rR   r	   r3   r6   r+   r   r4   )rT   rF   rK   r<   rU   Zpca_fitZpca_fit_transformtransformed_Xr)   r)   r*   test_pca_sparse_fit_transform   s6    	


rX   r   r   c                 C   sv   t j| }|tjjtt|d}td|d}d| d}tj	t
|d || W d    n1 sh0    Y  d S )NrK      r/   zWPCA only support sparse inputs with the "arpack" and "covariance_eigh" solvers, while "z" was passedmatch)r9   rM   RandomStaterO   rP   rQ   rR   r	   pytestraises	TypeErrorr3   )rT   r,   rF   rK   r<   r=   Zerror_msg_patternr)   r)   r*   test_sparse_pca_solver_error   s    ra   c                 C   s\   t j| }|tjjtt|d}tddd|}tddd|}t	|j
|j
dd dS )	zHCheck that "auto" and "arpack" solvers are equivalent for sparse inputs.rY   rE   r   r/   r   {Gzt?r    N)r9   rM   r]   rO   rP   rQ   rR   r	   r3   r   r$   )rT   rF   rK   r<   
pca_arpackpca_autor)   r)   r*   7test_sparse_pca_auto_arpack_singluar_values_consistency   s    rf   c                  C   sn   d} | d }t jjdd| |fd}t| d}t & tdt || W d    n1 s`0    Y  d S )NrE   rD   r.   sizer-   error)	r9   rM   uniformr	   warningscatch_warningssimplefilterRuntimeWarningr3   )r-   
n_featuresr<   r=   r)   r)   r*   test_no_empty_slice_warning   s    

rr   copyTFsolverc                 C   s  t jd}d}d}d}d}t |||t t t dd||||}|d d d df  d9  < |j||fks~J |jdd	 d
ksJ |	 }t
|d|| ddd}	|	|	 }
|
j||fksJ |	|}t|
|dd t|
jdddt | t|
jdd	t |dd |	 }t
|d|| d|	 }	|	|}|j||fks`J |jdd	 tjdddksJ d S )Nr   rH   P   rZ   2   g      $@      ?   axisgfffffE@T   )r-   whitenrs   r,   rK   iterated_powergMb@?rc   r.   Zddofrz   r   r0   F)r-   r|   rs   r,   gfffffR@rC   )rel)r9   rM   r]   r:   randnZdiagZlinspacer5   stdrs   r	   r6   r4   r   onesmeanzerosr3   r^   approx)rt   rs   rng	n_samplesrq   r-   rankr<   ZX_r=   Z
X_whitenedZX_whitened2ZX_unwhitenedr)   r)   r*   test_whitening   sH    
"	

r   other_svd_solverr   
data_shapetallZwiderank_deficientr|   c                 C   s  |dkrd\}}nd\}}d}|rbt j|}	t||d }
|	j|| |
fd|	j|
|fd }nt|| |d|d}t||}
|j|d	d
}|d | ||d   }}|t jkrtddd}d}ntddd}d}i }| dkrd}ddi}n | dkr
t 	||d }nd }t
|d|d}t
f || ||d|}||}t | sRJ |j|ksbJ ||}t | sJ |j|ksJ |jdk sJ t|j|jfi | t|j|jfi | |j}t | sJ |j}t | sJ |j|k}| dks J t|| || fi | t|d d |f |d d |f fi | ||}t | sJ |j|ksJ ||}t | sJ |j|ksJ t|d d |f |d d |f fi | ||}t | sJ |j|ksJ ||}t | s6J |j|ksFJ |jjd |jjd krt||fi | t||fi | np|jjd |
k r|j |ksJ t||fi | n6t||d d |f ||d d |f fi | d S )Nr   )rH   rZ   )rZ   rH   rE   rD   rh         ?)r   rq   Ztail_strengthrK   Frs   gQ?h㈵>)r!   r    rI   r   r   r}   rv   r   r.   r   r-   r,   r|   )r-   r,   r|   rK   r   )r9   rM   rN   minZstandard_normalr   astypefloat32dictminimumr	   r6   isfinitealldtyper#   r   explained_variance_ratio_r"   sumr4   inverse_transformr5   )r   r   r   r|   rT   Zglobal_dtyper   rq   Zn_samples_testr   r   r<   ZX_trainZX_testZtolsZvariance_thresholdZextra_other_kwargsr-   pca_full	pca_otherZX_trans_full_trainZX_trans_other_trainZreference_componentsZother_componentsZstableZX_trans_full_testZX_trans_other_testZX_recons_full_testZX_recons_other_testr)   r)   r*   test_pca_solver_equivalence  s    










*

r   r<   rH   ru   N   )n_informativerK   rE   )zrandom-tallzcorrelated-tallzrandom-wide)Zidsc                 C   sr   t d|dd}|| }t|jtj|ddd tjtj| ddd }t	|dd	d d }t|j|d
d d S )NrD   r   rL   r.   r~   F)ZrowvarT)reverserb   rc   )
r	   r6   r   r#   r9   varlinalgZeigr>   sorted)r<   r,   r=   ZX_pcaZexpected_resultr)   r)   r*   %test_pca_explained_variance_empirical  s    
r   c                 C   sf   t jd}d\}}|||}tdd|d}td| |d}|| || t|j|jdd d S )Nr   rH   ru   rD   r   rL   rb   rc   )r9   rM   r]   r   r	   r3   r   r$   )r,   r   r   rq   r<   r   r   r)   r)   r*   $test_pca_singular_values_consistency  s    

r   c                 C   s   t jd}d\}}|||}td| |d}||}tt |jd t j	
|dd  t|jt t j|d dd d\}}|||}td| |d}||}|t t j|d dd }|d d df  d	9  < |d d d
f  d9  < t ||j}|| t|jg d d S )Nr   r   rD   rL   Zfrory   )rH   n   rx   A`"	@r.   X9v@)r   r   rw   )r9   rM   r]   r   r	   r6   r   r   r$   r   Znormsqrtr:   r"   r3   )r,   r   r   rq   r<   r=   X_transZX_hatr)   r)   r*   test_pca_singular_values  s&    
 

r   c                 C   s   t jd}d\}}|||d }|d d  t g d7  < d|d| t g d }td| d||}|t |d 	  }t
t |d d d	d
d d S )Nr   rH   rx   rC   rE   rx         r.   rD   r/   rw   rb   rc   )r9   rM   r]   r   arrayr	   r3   r4   r   r   r   abs)r,   r   npr<   XtZYtr)   r)   r*   test_pca_check_projection  s    r   c                 C   s^   ddgddgg}t d| dd}||}|js6J dt| ddd t| d	d
d d S )Nrw   g        r.   r   rL   )rD   r.   r   r0   gQ?rb   rc   )r	   r6   r5   r   r   r   )r,   r<   r=   r   r)   r)   r*   test_pca_check_projection_list  s    
r   )r   r   r   c           	      C   s~   t jd}d\}}|||}|d d df  d9  < |g d7 }td| |d|}||}||}t||dd	 d S )
Nr   )rv   rx   r.   r   )r   r   rx   rD   r   h㈵>rc   )	r9   rM   r]   r   r	   r3   r4   r   r   )	r,   r|   r   r   r   r<   r=   YZ	Y_inverser)   r)   r*   test_pca_inverse  s    

r   r2   )r   r.   r   )r.   r   r   z!svd_solver, n_components, err_msg))r   r   2must be between 1 and min\(n_samples, n_features\))r   r   r   )r   rD   zmust be strictly less than min)r   rx   zZn_components=3 must be between 0 and min\(n_samples, n_features\)=2 with svd_solver='full'c                 C   s   d}t || d}tjt|d || W d    n1 s>0    Y  | dkr|}d||}tjt|d" t || d| W d    n1 s0    Y  d S )NrD   r,   r[   r   zgn_components={}L? must be strictly less than min\(n_samples, n_features\)={}L? with svd_solver='arpack')r	   r^   r_   
ValueErrorr3   format)r,   r2   r-   err_msgZ
smallest_dZ
pca_fittedr)   r)   r*   test_pca_validation	  s    (r   zsolver, n_components_c                 C   s&   t |d}||  |j|ks"J d S )Nr   )r	   r3   r&   )r2   rt   r&   r=   r)   r)   r*   test_n_components_none2  s    


r   c                 C   sH   t jd}d\}}|||}td| d}|| |jdksDJ d S )Nr   iX  rE   mler/   r.   )r9   rM   r]   r   r	   r3   r&   )r,   r   r   rq   r<   r=   r)   r)   r*   test_n_components_mleA  s    
r   c                 C   sr   t jd}d\}}|||}td| d}d| }tjt|d |	| W d    n1 sd0    Y  d S )Nr   r   r   r/   z:n_components='mle' cannot be a string with svd_solver='{}'r[   )
r9   rM   r]   r   r	   r   r^   r_   r   r3   )r,   r   r   rq   r<   r=   r   r)   r)   r*   test_n_components_mle_errorL  s    r   c                  C   st   t jd} d\}}| ||d }|d d  t g d7  < tddd|}|jdksbJ |jd	kspJ d S )
Nr   rH   r   rC   rE   rx   r   r   r.   rD   r   r   r/   r.   )	r9   rM   r]   r   r   r	   r3   r-   r&   )r   r   r   r<   r=   r)   r)   r*   test_pca_dim[  s    r   c                     s   d\ } t jd}| | d | dt g d  t g d }t| dd}|| |jt  fd	d
td| D }|d |	 d   ksJ d S )Nr   r   r   rC   r.   r   )r.   r   r{   r      r   r/   c                    s   g | ]}t | qS r)   )r
   ).0kr   spectr)   r*   
<listcomp>s      z$test_infer_dim_1.<locals>.<listcomp>rB   )
r9   rM   r]   r   r   r	   r3   r#   rangemax)r   r   r<   r=   llr)   r   r*   test_infer_dim_1f  s    
 r   c                  C   s   d\} }t jd}|| |d }|d d  t g d7  < |dd  t g d7  < t|dd	}|| |j}t|| d
ksJ d S )Nr   r   rC   rE   r      r   r   r{   rD   rg   r   r/   r.   	r9   rM   r]   r   r   r	   r3   r#   r   r   r   r   r<   r=   r   r)   r)   r*   test_infer_dim_2w  s    
r   c                  C   s   d\} }t jd}|| |d }|d d  t g d7  < |dd  t g d7  < |dd	  d
t g d 7  < t|dd}|| |j}t|| d
ksJ d S )Nr   r   rC   rE   r   r   r   rZ   (   rD   )rg   r.   rg   r.   rg   r   r/   r   r   r)   r)   r*   test_infer_dim_3  s    "
r   z'X, n_components, n_components_validatedgffffff?rD   rB   r   r   r   c                 C   s<   t |dd}||  |jt|ks*J |j|ks8J d S )Nr   r/   )r	   r3   r-   r^   r   r&   )r<   r-   Zn_components_validatedr=   r)   r)   r*   $test_infer_dim_by_explained_variance  s    	
r   c           	      C   s   d\}}t jd}|||d t g d }td| d}|| ||}dt dt j	 t 
d d	  | }t|| dd
d ||||d t g d }||ksJ tdd| d}|| ||}||ksJ d S )N)r   rx   r   rC   r   rD   r/   g      r.   g|Gz?g?rc   g?T)r-   r|   r,   )r9   rM   r]   r   r   r	   r3   scorelogpiexpr   )	r,   r   r   r   r<   r=   Zll1hZll2r)   r)   r*   test_pca_score  s    

&$

r   c                  C   s   d\} }t jd}|| ||| dt g d  t g d }|| ||| dt g d  t g d }t |}t|D ](}t|dd}|| |	|||< q|
 dksJ d S )N)   rx   r   r.   r   )r.   r   r{   r   r/   )r9   rM   r]   r   r   r   r   r	   r3   r   Zargmax)r   r   r   ZXlr   r   r   r=   r)   r)   r*   test_pca_score3  s    44

r   c                 C   sF   t jdd\}}td| dd}|| t|j|j dksBJ d S )NTZ
return_X_yrZ   r   rL   )r   load_digitsr	   r3   r9   r   r#   r%   )r,   r<   _r=   r)   r)   r*   test_pca_sanity_noise_variance  s    
r   c                 C   s^   t jdd\}}tdddd}td| dd}|| || t||||dd d S )	NTr   rZ   r   r   rL   r   rc   )r   r   r	   r3   r   r   )r,   r<   r   r   r   r)   r)   r*   "test_pca_score_consistency_solvers  s    

r   c                 C   s   d\}}t jd}|||d t g d }t|| d}|| |jdksVJ || ||j	 |jdkszJ ||j	 d S )Nr   r   rC   r   r/   )
r9   rM   r]   r   r   r	   r3   r%   r   T)r,   r   r   r   r<   r=   r)   r)   r*   'test_pca_zero_noise_variance_edge_cases  s    

r   z4n_samples, n_features, n_components, expected_solver))rE   rv   r   r   )r   rv   rv   r   )r     i  r   )r   r   rE   r   )r   r   r   r   c                 C   sf   t jdj| |fd}t|dd}t||dd}|| |j|ksJJ || t|j|j d S )Nr   rh   )r-   rK   rL   )	r9   rM   r]   rl   r	   r3   Z_fit_svd_solverr   r"   )r   rq   r-   Zexpected_solverr2   re   Zpca_testr)   r)   r*   test_pca_svd_solver_auto  s    

r   c                 C   s   t jd}|dd}t d}tdD ],}td| |d}||d ||d d f< q*t|t 	|dd d f d
dd d S )Nr   rE   )r   rD   r   rD   rL   )r9   rM   r]   randr   r   r	   r6   r   ZtileZreshape)r,   r   r<   rW   ir=   r)   r)   r*   test_pca_deterministic_output  s    
r   c                 C   s   t | | t|  d S )N)"check_pca_float_dtype_preservation$check_pca_int_dtype_upcast_to_double)r,   rT   r)   r)   r*   test_pca_dtype_preservation  s    
r   c                 C   s   t j|dd}|jt jdd}|t j}td| |d|}td| |d|}|j	j
t jksjJ |j	j
t jks|J ||j
t jksJ ||j
t jksJ t|j	|j	ddd d S )	Nr   r   Fr   rx   rL   gMbP?r   )r9   rM   r]   r   r   float64r   r	   r3   r"   r   r4   r   )r,   seedr<   Z	X_float64Z	X_float32pca_64pca_32r)   r)   r*   r   !  s    r   c                 C   s   t jdddd}|jt jdd}|jt jdd}td| dd|}td| dd|}|j	j
t jkspJ |j	j
t jksJ ||j
t jksJ ||j
t jksJ t|j	|j	dd	 d S )
Nr   r   )r   r   Fr   rx   rL   g-C6?rc   )r9   rM   r]   randintr   Zint64Zint32r	   r3   r"   r   r   r4   r   )r,   ZX_i64ZX_i32r   r   r)   r)   r*   r   9  s    r   c                  C   sT   t dd\} }t | |}|j d }t|d| |}|j| jd ksPJ d S )NTr   rj   r.   )r   r	   r3   r   Zcumsumr&   r5   )r<   yr'   r-   r(   r)   r)   r*   5test_pca_n_components_mostly_explained_variance_ratioJ  s
    r   c               	   C   sZ   t g d} d}dD ]>}tjtdd t| || W d    q1 sJ0    Y  qd S )Nr.   KH9r   r   rE   )r   r   z"should be in \[1, n_features - 1\]r[   )r9   r   r^   r_   r   r
   )spectrumr   r   r)   r)   r*   test_assess_dimension_bad_rankV  s
    r  c                  C   s`   t g d} t| dddt j ks(J dD ]}t| |dt j ks,J q,t| ddks\J d S )Nr   r.   rE   r   r   )rD   rx   )r9   r   r
   infr   )r   r   r)   r)   r*   test_small_eigenvalues_mle_  s
    r  c                  C   s<   t jddddddd\} }tdd| }|jdks8J d S )Nr   r.      *   )rq   r   Z
n_repeatedZn_redundantZn_clusters_per_classrK   r   rj   )r   r   r	   r3   r&   r<   r   r=   r)   r)   r*   test_mle_redundant_datal  s    
r  c                  C   s\   t jdddd\} }tddd}tjtdd	 ||  W d    n1 sN0    Y  d S )
Nr      r  )r   rq   rK   r   r   r/   z?n_components='mle' is only supported if n_samples >= n_featuresr[   )r   r   r	   r^   r_   r   r3   r  r)   r)   r*   test_fit_mle_too_few_samples{  s    r
  c                  C   sr   d\} }t jd| |}t j|d d d df dd|d d df< tddd}|| |j|d ksnJ d S )	N)r   rE   r   rg   ry   r   r   r   r.   )r9   rM   r]   r   r   r	   r3   r&   )r   Zn_dimr<   Zpca_sklr)   r)   r*   test_mle_simple_case  s    *
r  c                  C   s   d\} }t | |f}t jj|dd\}}}t|dd  t |d dd t t|d| dsdJ td|D ]}t||| t j	 ksnJ qnd S )	N)	   r   T)Zfull_matricesr.   r   r0   r  rD   )
r9   r   r   Zsvdr   r   r   r
   r   r  )r   rq   r<   r   sr   r)   r)   r*   test_assess_dimesion_rank_one  s     r  c                  C   s   t jd} d}| d|}tdd|dd|}tddd|}tdd	dd
|}tt |jt |j tt |jt |j dS )zCheck that exposing and setting `n_oversamples` will provide accurate results
    even when `X` as a large number of features.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/20589
    r   rH   r   r.   r   )r-   r,   Zn_oversamplesrK   r   r/   r   rL   N)	r9   rM   r]   r   r	   r3   r   r   r"   )r   rq   r<   Zpca_randomizedr   rd   r)   r)   r*   %test_pca_randomized_svd_n_oversamples  s    r  c                  C   s6   t ddtj} |  }tdd tdD | dS )z Check feature names out for PCA.rD   rj   c                 S   s   g | ]}d | qS )r=   r)   )r   r   r)   r)   r*   r     r   z*test_feature_names_out.<locals>.<listcomp>N)r	   r3   r1   r2   Zget_feature_names_outr   r   )r=   namesr)   r)   r*   test_feature_names_out  s    r  c                 C   sV   t jd}|dd}t |}|j|j }t j|ddd	 }t j
|| dS )z9Check the accuracy of PCA's internal variance calculationr   r   r   r.   r~   N)r9   rM   r]   r   r	   r3   r#   r   r   r   Ztestingr   )rs   r   r<   r=   Zpca_varZtrue_varr)   r)   r*   test_variance_correctness  s    r  c                 C   s  t ||}tj|}|j||d}|| | }| }	|jdkrLdnd}
t	dd t
||}| }|jdksJ |j|jksJ tt||d||
t|d	 | }|jdksJ |j|jksJ tt||d|	|
t|d	 W d    n1 s0    Y  d S )
Nr   r   g-C6*?gH׊>TZarray_api_dispatch)r   r   xpr   )r   r1   r2   r   asarrayr3   r8   r7   r   r   r   r5   r   r   r   )name	estimatorarray_namespacer   
dtype_namer  Ziris_npiris_xpZprecision_npZcovariance_npr    Zestimator_xpZprecision_xpZcovariance_xpr)   r)   r*   check_array_api_get_precision  s6    



r  z#array_namespace, device, dtype_namecheckr  r/   r   rC   ZQR)r-   r,   power_iteration_normalizerrK   c                 C   s   | j j}||| |||d d S )Nr   r  )	__class____name__)r  r  r  r   r  r  r)   r)   r*   test_pca_array_api_compliance  s    r"  r   c                 C   s  | j j}||| |||d t||}tdd\}}|j|dd}t|j}	t| }
|j||d}|j||d}|
	|| |
j
}|
j}t|
}tddh |	|| |j
}t|t|ksJ t||d	}|j}t|t|ksJ t||d	}W d    n1 s0    Y  |j|jks$J |jd
 |jd
 ks>J |j|jksPJ t|jd |jd }t|d | |d | |	d |jd |jd kr|d }||d  }||d  }tt|| |	k sJ tt|| |	k sJ d S )Nr  r  rY   Fr   r   Tr  r  r.   r   r0   rg   )r   r!  r   r   r   r   r   r   r  r3   r"   r#   r   array_devicer   r5   r   r   r   r9   r   )r  r  r  r   r  r  r  r<   r   r!   ZestZX_xpZy_xpZcomponents_npZexplained_variance_npZest_xpZcomponents_xpZcomponents_xp_npZexplained_variance_xpZexplained_variance_xp_npZmin_componentsZreference_varianceZextra_variance_npZextra_variance_xp_npr)   r)   r*   !test_pca_mle_array_api_compliance  sH    

,

r$  c               	   C   s  t d t d} | tj}tdddd}td}t jt	|dD t
d	d
 || W d    n1 sr0    Y  W d    n1 s0    Y  |jddd td}t jt	|dD t
d	d
 || W d    n1 s0    Y  W d    n1 s0    Y  |jddd td}t jt|dF t
d	d
 || W d    n1 sj0    Y  W d    n1 s0    Y  d S )NZarray_api_compatZarray_api_strictrD   r   r   rL   zCPCA with svd_solver='arpack' is not supported for Array API inputs.r[   Tr  r   ZLU)r,   r  z[Array API does not support LU factorization. Set `power_iteration_normalizer='QR'` instead.r   zArray API does not support LU factorization, falling back to QR instead. Set `power_iteration_normalizer='QR'` explicitly to silence this warning.)r^   Zimportorskipr  r1   r2   r	   reescaper_   r   r   r3   Z
set_paramsZwarnsUserWarning)r  r  r=   Zexpected_msgr)   r)   r*   7test_array_api_error_and_warnings_on_unsupported_params]  s0    

FHr(  )r   r   )kr%  rm   numpyr9   r^   ZscipyrO   Znumpy.testingr   Zsklearnr   r   Zsklearn.baser   Zsklearn.datasetsr   r   r   Zsklearn.decompositionr	   Zsklearn.decomposition._pcar
   r   Zsklearn.utils._array_apir   r   r   r   r#  Z-sklearn.utils._test_common.instance_generatorr   Zsklearn.utils._testingr   r   Zsklearn.utils.estimator_checksr   Zsklearn.utils.fixesr   r   r1   ZPCA_SOLVERSrQ   rR   r   ZSPARSE_MAX_COMPONENTSr+   markZparametrizer   r2   r5   r@   rV   rX   ra   rf   rr   r   r   listsetr   rM   r]   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r
  r  r  r  r  r  r  r"  r$  r(  r)   r)   r)   r*   <module>   s`  

8
!
5 	





.












	

"


: