a
    hi                     @   s  U d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlZd dlm	Z	 d dl
Zd dlmZ d dl
mZ d dlmZ d dlmZmZmZmZmZmZ ddlmZmZ g dZdd	d
Zh dZdd Zd:eeje ejdddZ i Z!e"edf e#d< G dd dZ$e j%G dd dZ&eedddZ'eddddZ(eddddZ)ee*eef dd d!Z+edd"d#d$Z,ej-e.d%d&d'Z/eejdd(d)d*Z0ej1edd+d,d-Z2d.d/ Z3d0d1 Z4d;d3d4Z5d5d6 Z6d<d8d9Z7dS )=    N)infer_schema)get_ctx)BaseTyBaseTypeFunctionSchemaListTypeOperatorName
SchemaKind   )autograd_kernel_indirectionconstruct_autograd_kernel)	custom_opCustomOpr   CPUCUDA)cpucuda>   torchatZprimsZprimZatenZpytorchc                   C   s   t dt d S )Nzwtorch._custom_op is deprecated and will be removed in PyTorch 2.6, please use the equivalent torch.library API instead.)warningswarnDeprecationWarning r   r   Q/var/www/html/assistant/venv/lib/python3.9/site-packages/torch/_custom_op/impl.pywarn_deprecated5   s    r   )qualnamemanual_schemareturnc                    s   t    fdd}|S )L
    This API is deprecated, please use torch.library.custom_op instead
    c           	   	      s,  t | stdt|  t\}}t| | j|krXtd d| d| j d d u rlt| ddn }| | }t	|}t
|  d urt||  t|d}|| t||j}t|||||d	d
}| j|_| j|_| j|_t||jdtt| tj|ttt| |S )NzDcustom_op(...)(func): Expected `func` to be a Python function, got: zcustom_op(qualname='z-', ...)(func): expected `func` to have name 'z' but got 'zX'. Please either change the name of `func` or the qualname that is passed to `custom_op`r   )Zmutates_argsFRAGMENTT_private_accessAutograd)inspect
isfunction
ValueErrortypeparse_qualnamevalidate_namespace__name__r   r   parsevalidate_schema validate_function_matches_schemalibraryLibrarydefinefind_ophandle_or_thrownamer   
__module____doc__impl_opnamer   weakrefproxyr   _C#_dispatch_set_report_error_callback	functoolspartialreport_error_callback)	funcnsr1   schema
schema_strfunction_schemalibophandleresultr   r   r   r   innerE   sR    

	


zcustom_op.<locals>.inner)r   )r   r   rF   r   rE   r   r   =   s    1r   r   global_registryc                       s   e Zd ZdZdd fdd
Zdd Zd,d	d
Zdd Zdd Zdd Z	dd Z
dd Zd-ejeeje f ejdddZdd ZejdddZd.ejdddZdd  Zd!d" Zd#d$ Zd%d& Zd/d'd(Zd0d*d+Z  ZS )1r   r   Fr    c                   sn   t    t  |std| d| }|| _|| _|| _|| _|| _|| _	d | _
i | _d| _| t| j	< d S )Nz|The CustomOp constructor is private and we do not guarantee BC for it. Please use custom_op(...) to create a CustomOp object::F)super__init__r   RuntimeError_schema_cpp_ns_lib	_ophandler5   	_qualnamer)   _impls'_registered_autograd_kernel_indirectionrG   )selfrB   cpp_nsr?   operator_namerC   r!   r1   	__class__r   r   rJ      s"    
zCustomOp.__init__c                 C   s0   | j r
J | j| jtt| d d| _ d S )Nr"   T)rR   rN   r4   r5   r   r6   r7   rS   r   r   r   %_register_autograd_kernel_indirection   s
    
z.CustomOp._register_autograd_kernel_indirection   c              
   C   s   |  |rJ| j| }|d us J |j}td| d| j d| d| d	tt|}|j	 d|j
 }t||| j|< d S )NzAttempting to register a z impl for operator z that already has a z  impl registered from Python at z. This is not supported.:)	_has_implrQ   locationrK   rP   r#   getframeinfosys	_getframefilenamelinenoFuncAndLocation)rS   kindr=   
stacklevelZfunc_and_locationr]   framer   r   r   _register_impl   s    

zCustomOp._register_implc                 C   s
   | j | S NrQ   rS   rd   r   r   r   	_get_impl   s    zCustomOp._get_implc                 C   s
   || j v S rh   ri   rj   r   r   r   r\      s    zCustomOp._has_implc                 C   s6   | ` ttj| j}t|| jr*t|| j t| j	= d S rh   )
rN   getattrr   opsrM   hasattrr5   delattrrG   rP   )rS   opnamespacer   r   r   _destroy   s
    zCustomOp._destroyc                 C   s   d| j  dS )Nz<CustomOp(op="z")>)rP   rX   r   r   r   __repr__   s    zCustomOp.__repr__c                 O   s   t j| jg|R i |}|S rh   )r8   Z_dispatch_call_boxedrO   )rS   argskwargsrD   r   r   r   __call__   s    zCustomOp.__call__)device_typesr   c                    s6   t trgD ]}t| q fdd}|S )T
        This API is deprecated, please use torch.library.custom_op instead
        c                    sJ   t D ]<}| j||  d t| }tjj||  q| S )Nre   )set_check_doesnt_have_library_implrg   SUPPORTED_DEVICE_TYPE_TO_KEYr-   r4   rN   r5   )fdevice_typeZdispatch_key_stacklevelrv   rS   r   r   rF      s    
zCustomOp.impl.<locals>.inner)
isinstancestrvalidate_device_type)rS   rv   r   r}   rF   r   r~   r   r4      s    

zCustomOp.implc                 C   s@   |  |rd S t| }t| j|r<td| d| j dd S )Nzimpl(..., device_types=z): the operator zs already has an implementation for this device type via a pre-existing torch.library or TORCH_LIBRARY registration.)r\   r{   r8   Z._dispatch_has_computed_kernel_for_dispatch_keyrP   rK   )rS   r}   keyr   r   r   rz      s    
z(CustomOp._check_doesnt_have_library_impl)r   c                    s    fdd}|S )z2Register an implementation for a factory function.c                    s&     d|  t j jd|  | S )NfactoryZBackendSelect)rg   r-   r4   rN   r5   r|   rX   r   r   rF      s    z$CustomOp.impl_factory.<locals>.innerr   )rS   rF   r   rX   r   impl_factory   s    zCustomOp.impl_factoryc                    s    fdd}|S )rw   c                    sZ      jd d djjt  fdd}jj	|d  S )Nabstractrx   c                     sN   fdd}t jj|  | i |W  d    S 1 s@0    Y  d S )Nc                      s   t d d  d S )Nz<Attempted to call get_ctx() for the meta implementation for a  .You have presumably called get_ctx() because the operator has a data-dependent output shape; if so, there is no such meta implementation and this error is the correct behavior. Otherwise, please remove the call to get_ctx() in the implementation registered with impl_abstract at rK   r   )r]   r   r   r   error_on_ctx  s    zOCustomOp.impl_abstract.<locals>.inner.<locals>.f_with_ctx.<locals>.error_on_ctx)r   _libraryZ	fake_implZset_ctx_getter)rs   rt   r   r|   r]   r   r   r   
f_with_ctx  s    z9CustomOp.impl_abstract.<locals>.inner.<locals>.f_with_ctxMeta)
$_check_doesnt_have_library_meta_implrg   rk   r]   rP   r:   wrapsrN   r4   r5   )r|   r   r   rS   r   r   rF     s    z%CustomOp.impl_abstract.<locals>.innerr   rS   r   rF   r   r   r   impl_abstract  s    zCustomOp.impl_abstractc                    s    fdd}j    tjkr*|d  j} js>|d t|dksNJ tdd |D }|rl|d ttj	d	ttj
d
ttjdttjdttjdtttjd di} jD ]0}|j|v rq|dt|  d|j d qd S )Nc                    s    t d|  dj d  d S )NzCCannot use torch._custom_ops APIs to register backward formula for z. Got operator z with schema: )rK   rP   )detailr?   rS   r   r   error*  s    z4CustomOp._check_can_register_backward.<locals>.errorznon-functional operatorzoperator with no returnsr   c                 s   s"   | ]}|j d uo|j j V  qd S rh   )
annotationZis_write).0rr   r   r   	<genexpr>:  s   z8CustomOp._check_can_register_backward.<locals>.<genexpr>zoperator that returns viewsintSymIntboolfloatTensorzList[Tensor]zoperator with return not in z (got ))rL   rd   r	   Z
functionalreturnslenanyr   r   r   r   r   r   r   r   r&   listvalues)rS   r   ZretsZis_non_mutating_viewZallowed_return_typesretr   r   r   _check_can_register_backward)  s4    






z%CustomOp._check_can_register_backwardc                 C   s^   | j r
d S t| jdr*td| j ddD ]*}t| j|r.td| j d| dq.d S )NCompositeImplicitAutogradz3impl_backward/impl_save_for_backward: the operator a3   already has an implementation for this device type via a pre-existing registration to DispatchKey::CompositeImplicitAutograd.CompositeImplicitAutograd operators do not need an autograd formula; instead, the operator will decompose into its constituents and those can have autograd formulas defined on them.)r"   ZAutogradCPUZAutogradCUDAz; already has an Autograd kernel registered to DispatchKey::z vi a pre-existing torch.library or TORCH_LIBRARY registration. Please either remove those registrations or don't use the torch._custom_ops APIs)rR   r8   %_dispatch_has_kernel_for_dispatch_keyrP   rK   )rS   r   r   r   r   (_check_doesnt_have_library_autograd_implP  s"    z1CustomOp._check_doesnt_have_library_autograd_implc                 C   sr   |  drd S t| jdr.t| jds.d S t| jdrNtd| j dt| jdrntd| j dd S )Nr   ZCompositeExplicitAutogradr   r   z!impl_abstract(...): the operator a-   already has an implementation for this device type via a pre-existing registration to DispatchKey::CompositeImplicitAutograd.CompositeImplicitAutograd operators do not need an abstract impl; instead, the operator will decompose into its constituents and those can have abstract impls defined on them.z already has an DispatchKey::Meta implementation via a pre-existing torch.library or TORCH_LIBRARY registration. Please either remove that registration or don't call impl_abstract.)r\   r8   r   rP   rK   rX   r   r   r   r   l  s$    
	z-CustomOp._check_doesnt_have_library_meta_implc              	   C   sX   |  dsJ |  dsJ t| j| j| t| j| dj| dj}| d| d S )Nbackwardsave_for_backwardautograd)	r\   r   rL   _output_differentiabilityget_oprP   rk   r=   rg   )rS   Zkernelr   r   r   _register_autograd_kernel  s    

z"CustomOp._register_autograd_kernelc                    s    fdd}|S )zyRegister a function that tells us what to save for backward.

        Please see impl_backward for more details.
        c                    sD        js  jd|  d dr@  d S )Nr   rx   r   )r   r   rR   rY   rg   r\   r   r   r   r   r   rF     s    
z.CustomOp.impl_save_for_backward.<locals>.innerr   r   r   r   r   impl_save_for_backward  s    	zCustomOp.impl_save_for_backwardNc                    sl   durXfdd}t ts$|  D ]}t |ts(|  q(tjjtkrX|   fdd}|S )rw   Nc                      s   t d  d S )Nzimpl_backward(output_differentiability): expected output_differentiability to be a list of bools with length equal to the number of outputs of this CustomOp got: r   r   )output_differentiabilityr   r   yell  s
    z$CustomOp.impl_backward.<locals>.yellc                    sJ        js  jd|  d _drF  d S )Nr   rx   r   )r   r   rR   rY   rg   r   r\   r   r   r   r   rS   r   r   rF     s    
z%CustomOp.impl_backward.<locals>.inner)r   r   r   r   rL   r   )rS   r   r   r   diffrF   r   r   r   impl_backward  s    


zCustomOp.impl_backward)rZ   )rZ   )rZ   )rZ   )NrZ   )r)   r2   __qualname__r3   rJ   rY   rg   rk   r\   rq   rr   ru   typingUnionr   IterableCallabler4   rz   r   r   r   r   r   r   r   r   __classcell__r   r   rV   r   r      s0   


 
#',
c                   @   s    e Zd ZU ejed< eed< dS )rc   r=   r]   N)r)   r2   r   r   r   __annotations__r   r   r   r   r   rc     s   

rc   )rT   rU   c                 C   s0   |j d u rdn|j }t|  dt|j |S )N rH   )overload_namer8   Z_dispatch_find_schema_or_throwr   r1   )rT   rU   r   r   r   r   r0     s
    r0   )r>   r   c                 C   s:   d| v rt d|  d| tv r6t d|  d|  dd S )N.zcustom_op(..., ns="zC"): expected ns to not contain any . (and be a valid variable name)zcustom_op(..., ns='z'): 'z9' is a reserved namespace, please choose something else. )r%   RESERVED_NS)r>   r   r   r   r(     s    
r(   )r?   r   c                 C   s:   t jj| std|  | jjd ur6td|  d S )Nzcustom_op only supports functional operators (ops that do not mutate any inputs, do not return views of the inputs, and has at least one return). Got the following non-functional schema: zUcustom_op does not support arguments named 'self'. Please rename your argument. Got: )r   r   utilsZis_functional_schemar%   	argumentsZself_arg)r?   r   r   r   r+     s    r+   )r   r   c                 C   sR   |  dd}t|dkr(td|  dd|d v rBtd|  |d |d fS )	NrH   r
   rZ   z$Expected there to be a namespace in z;, i.e. The operator name should look something like ns::foor   zThe torch.custom_ops APIs do not handle overloads, i.e. operator names with '.' in them. Please name your operator something like ns::foo. Got: r   )splitr   r%   )r   namesr   r   r   r'     s    
r'   )r}   r   c                 C   s&   | t vr"td|  dt   dd S )NzCustomOp.impl(device_types=[z(, ...]): we only support device_type in r   )r{   r%   keys)r}   r   r   r   r     s    r   )paramr   c                 C   s   | j tjjtjjfv S rh   )rd   r#   	ParameterPOSITIONAL_OR_KEYWORDKEYWORD_ONLY)r   r   r   r   supported_param  s    r   )r?   r=   r   c                    s   t |tdd j D s0td tdd j D sVjt jj	urdtd dd j D }dd j D }fd	d
 fdd fdd}||j
j ||j
j d S )Nc                 s   s   | ]\}}t |V  qd S rh   )r   r   _pr   r   r   r   +      z3validate_function_matches_schema.<locals>.<genexpr>zcustom_op(..., manual_schema)(func): positional-only args, varargs, and kwargs are not supported. Please rewrite `func` to not have them. Got `func` with signature: c                 s   s    | ]\}}|j tjjuV  qd S rh   )r   r#   r   emptyr   r   r   r   r   3  s   zcustom_op(..., manual_schema)(func): When passing in a manual schema, we expect `func` to have no type annotations to avoid ambiguity. Got `func` with signature: c                 S   s&   g | ]\}}|j tjjkr||fqS r   )rd   r#   r   r   r   r1   r   r   r   r   
<listcomp>?  s   z4validate_function_matches_schema.<locals>.<listcomp>c                 S   s&   g | ]\}}|j tjjkr||fqS r   )rd   r#   r   r   r   r   r   r   r   D  s   c                      s   t d d  d S )Nzcustom_op(..., manual_schema)(func): When passing in a manual schema, we expect `func`'s signature to match `manual_schema` (aside from type annotations). func's signature: , manual_schema: r%   r   r?   sigr   r   r   J  s    z/validate_function_matches_schema.<locals>.errorc                      s   t d d  d S )Nzycustom_op(..., manual_schema)(func): neither func nor manual_schema should have default arguments. Got func's signature: r   r   r   r   r   r   error_default_argsR  s    z<validate_function_matches_schema.<locals>.error_default_argsc                    s`   t | t |kr   t| |D ]:\\}}}||jkr<   |jtjjusT|jd ur   q d S rh   )r   zipr1   defaultr#   r   r   )Zsig_argsZschema_argsr1   r   arg)r   r   r   r   compareZ  s    
z1validate_function_matches_schema.<locals>.compare)r#   	signatureall
parametersitemsr%   r   return_annotation	Signaturer   r   Zflat_positionalZflat_kwarg_only)r?   r=   
positionalZ	kwargonlyr   r   )r   r   r?   r   r   r,   &  s:    
	r,   )r   r   r   c              	   C   st   |dkrt |  d|dkr,t |  d|dv r\| }t |  d| d| d| d	t |  d
| dd S )NZ	Undefineda  : There were no Tensor inputs to this operator (e.g. you passed an empty list of Tensors). If your operator is a factory function (that is, it takes no Tensors and constructs a new one), then please use CustomOp.impl_factory to register an implementation for itr   z: when running with device='Meta' tensors: there is no abstract impl registered for this CustomOp. Please register one via CustomOp.impl_abstract to get this CustomOp to work with Meta tensors)r   r   z: when running with device='z' tensors: there is no zW impl registered for this CustomOp. Please register one via CustomOp.impl(device_type='z')z%: No implementation for dispatch key z. It is likely that we have not added this functionality yet, please either open an issue or if you're feeling adventurous, use the low-level torch.library API)NotImplementedErrorlower)r   r   Zdevicer   r   r   r<   g  s(    r<   c                 C   s\   | j }tj|d}|  dd }t| j}|dd }t	|}t
||||| ddS )Nr   rH   Tr    )	namespacer   r-   r.   r1   r   r   rL   r   r*   r   )opr>   rB   r1   r@   r?   r   r   r   custom_op_from_existing  s    

r   c                    sf    fdd}t  \}}ttj|s*|  ttj|}t||sF|  t||}t|ds`|  |jS )Nc                      s   t d  dd S )NzCould not find the operator z~. Please make sure you have already registered the operator and (if registered from C++) loaded it via torch.ops.load_library.r   r   r   r   r   error_not_found  s    
zget_op.<locals>.error_not_foundr   )r'   rn   r   rm   rl   r   )r   r   r>   r1   rp   packetr   r   r   r     s    


r   Fc                 C   s8   | t v rt |  S |s$td|  dt| }t|}|S )NzCould not find custom op "z5". Did you register it via the torch._custom_ops API?)rG   rK   r   r   )r   Zalso_check_torch_libraryoverloadrD   r   r   r   _find_custom_op  s    
r   c                 C   sF   | t jjjvrd S t jjj|  }|d u r,d S |ds:d S |djS )Nr   )r   Z
_custom_opr4   rG   r\   rk   r=   )r   r   r   r   r   get_abstract_impl  s    
r   Tc              	   C   s   |  d\}}| | }t|}t| |r<tjjjgng }t	|d}|j
||d t||j}	t|||||	dd}
|
  tj|	ttt|
 t| S )NrH   r   )tagsTr    )r   r   r*   r+   r   r8   Tagneeds_fixed_stride_orderr-   r.   r/   r0   r1   r   rY   r9   r:   r;   r<   r6   r7   r   )r   r?   r   r>   r1   r@   rA   r   rB   rC   rD   r   r   r   _custom_op_with_schema  s    
r   )N)F)T)8dataclassesr:   r#   r_   r   r   r6   r   Ztorch._Cr8   Ztorch._library.infer_schemaZtorch.libraryr-   r   r   Ztorchgen.modelr   r   r   r   r   r	   r   r   r   __all__r{   r   r   r   Optionalr   r   rG   dictr   r   	dataclassrc   r0   r(   r+   tupler'   r   r   r   r   r,   Anyr<   r   r   r   r   r   r   r   r   r   <module>   s\   
 	
	 
B  Y	A
