a
    h|                     @   s&  d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlZd dlm	Z	m
Z
 d dlmZ d dlmZmZ d dlmZ d dlmZ d dlZd dlmZmZ d dlmZmZmZ d d	lmZ d d
lmZm Z  d dl!m"Z"m#Z# e$ dkrd dl%Z%e" rd dl&m'Z' d dl(m)Z) d dl*m+Z+ e# rDd dl,Z,d dlm-Z-mZm.Z.mZ e/ej0ej1 Z2e/ej0ej3 ej1 d Z4ddiddiddiddiddiddiddidZ5dZ6dd7e58  dZ9G dd  d Z:eG d!d" d"Z;e
d#d$d%Z<G d&d' d'eZ=e>d(kr"e; Z?d)e?_@d*e?_@e=e?ZAeAB  dS )+    N)ArgumentParser	Namespace)AsyncIterator)	dataclassfield)Thread)Optional)AsyncInferenceClientChatCompletionStreamOutput)AutoTokenizerGenerationConfigPreTrainedTokenizer)BaseTransformersCLICommand)ServeArgumentsServeCommand)is_rich_availableis_torch_availableWindows)Console)Live)Markdown)AutoModelForCausalLMr   BitsAndBytesConfigr   z .!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~textz5There is a Llama in my lawn, how can I get rid of it?zyWrite a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].z4How many helicopters can a human eat in one sitting?z4Count to 10 but skip every number ending with an 'e'zWhy aren't birds real?z2Why is it important to eat socks after meditating?z$Which number is larger, 9.9 or 9.11?)llamacode
helicopternumbersZbirdssocksZnumbers2a  

**TRANSFORMERS CHAT INTERFACE**

Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
- **!help**: shows all available commands (set generation settings, save chat, etc.)
- **!status**: shows the current status of the model and generation settings
- **!clear**: clears the current conversation and starts a new one
- **!exit**: closes the interface
am  

**TRANSFORMERS CHAT INTERFACE HELP**

Full command list:
- **!help**: shows this help message
- **!clear**: clears the current conversation and starts a new one
- **!status**: shows the current status of the model and generation settings
- **!example {NAME}**: loads example named `{NAME}` from the config and uses it as the user input.
Available example names: `z`, `a%  `
- **!set {ARG_1}={VALUE_1} {ARG_2}={VALUE_2}** ...: changes the system prompt or generation settings (multiple
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
- **!save {SAVE_NAME} (optional)**: saves the current chat and settings to file by default to
`./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- **!exit**: closes the interface
c                   @   s   e Zd Zdee ee dddZee eee	f dddZ
edd	d
Zdd ZedddZeedddZdedddZeeedddZdS )RichInterfaceN
model_name	user_namec                 C   s8   t  | _|d u rd| _n|| _|d u r.d| _n|| _d S )N	assistantuser)r   _consoler!   r"   )selfr!   r"    r'   V/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/commands/chat.py__init__s   s    zRichInterface.__init__)streamreturnc           	         s   | j d| j d t| j dd}d}|I d H 2 z3 d H W }|jd jj}|sTq4tdd|}||7 }g }|	 D ].}|
| |d	r|
d
 qv|
d qvtd| dd}|j|dd q46 W d    n1 s0    Y  | j   |S )Nz[bold blue]<z>:   )consolerefresh_per_second r   z<(/*)(\w*)>z\<\1\2\>z```
z  
zgithub-dark)Z
code_themeT)refresh)r%   printr!   r   choicesdeltacontentresub
splitlinesappend
startswithr   joinstripupdate)	r&   r*   liver   tokenoutputslineslinemarkdownr'   r'   r(   stream_output~   s&    

0
zRichInterface.stream_outputr+   c                 C   s$   | j d| j d}| j   |S )z!Gets user input from the console.[bold red]<z>:
)r%   inputr"   r2   )r&   rG   r'   r'   r(   rG      s    
zRichInterface.inputc                 C   s   | j   dS )zClears the console.N)r%   clearr&   r'   r'   r(   rH      s    zRichInterface.clear)r   c                 C   s(   | j d| j d|  | j   dS )z%Prints a user message to the console.rF   z>:[/ bold red]
N)r%   r2   r"   )r&   r   r'   r'   r(   print_user_message   s    z RichInterface.print_user_messager   colorc                 C   s&   | j d| d|  | j   dS )z,Prints text in a given color to the console.z[bold ]Nr%   r2   )r&   r   rL   r'   r'   r(   print_color   s    zRichInterface.print_colorFminimalc                 C   s&   | j t|rtnt | j   dS )z'Prints the help message to the console.N)r%   r2   r   HELP_STRING_MINIMALHELP_STRING)r&   rQ   r'   r'   r(   
print_help   s    zRichInterface.print_helpr!   generation_configmodel_kwargsc                 C   sJ   | j d| d |r*| j d|  | j d|  | j   dS )zFPrints the status of the model and generation settings to the console.z[bold blue]Model: r0   z[bold blue]Model kwargs: z[bold blue]NrN   )r&   r!   rV   rW   r'   r'   r(   print_status   s
    zRichInterface.print_status)NN)F)__name__
__module____qualname__r   strr)   r   r
   tupleintrD   rG   rH   rJ   rO   boolrT   r   dictrX   r'   r'   r'   r(   r   r   s   .r   c                   @   s  e Zd ZU dZedddidZee ed< edddidZ	ee ed< eddd	idZ
ee ed
< edddidZeed< edddidZee ed< edddidZeed< edddidZee ed< edddidZeed< edddidZeed< eddg dddZee ed< eddg dddZee ed < eddd!idZeed"< eddd#idZee ed$< eddd%idZeed&< eddd'idZeed(< ed)d*d+d)gddZeed,< eddd-idZeed.< ed/dd0idZeed1< ed2dd3idZeed4< d5d6 ZdS )7ChatArgumentsz
    Arguments for the chat CLI.

    See the metadata arg for each argument's description -- the medatata will be printed with
    `transformers chat --help`
    Nhelpz_Name of the pre-trained model. The positional argument will take precedence if both are passed.)defaultmetadatamodel_name_or_pathzKUsername to display in chat interface. Defaults to the current user's name.r$   zSystem prompt.system_promptz./chat_history/zFolder to save chat history.save_folderz"Path to a yaml file with examples.examples_pathFz7Whether to show runtime warnings in the chat interface.verbosezPath to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config.rV   mainzLSpecific model version to use (can be a branch name, tag name or commit id).model_revisionautozDevice to use for inference.devicezA`torch_dtype` is deprecated! Please use `dtype` argument instead.)rl   Zbfloat16Zfloat16Zfloat32)rb   r3   torch_dtypezOverride the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype will be automatically derived from the model's weights.dtypez2Whether to trust remote code when loading a model.trust_remote_codezWhich attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`.attn_implementationzIWhether to use 8 bit precision for the base model - works only with LoRA.load_in_8bitzIWhether to use 4 bit precision for the base model - works only with LoRA.load_in_4bitZnf4zQuantization type.Zfp4bnb_4bit_quant_typez#Whether to use nested quantization.use_bnb_nested_quant	localhostz%Interface the server will listen to..host@  zPort the server will listen to.portc                 C   s    | j dur| jdkr| j | _dS )z(Only used for BC `torch_dtype` argument.Nrl   )rn   ro   rI   r'   r'   r(   __post_init__!  s    zChatArguments.__post_init__)rY   rZ   r[   __doc__r   re   r   r\   __annotations__r$   rf   rg   rh   ri   r_   rV   rk   rm   rn   ro   rp   rq   rr   rs   rt   ru   rw   ry   r^   rz   r'   r'   r'   r(   ra      st   
ra   argsc                 C   s   t | S )z;
    Factory function used to chat with a local model.
    )ChatCommandr}   r'   r'   r(   chat_command_factory(  s    r   c                   @   sL  e Zd ZeedddZdd ZeedddZed(e	e
e ed
ddZed)e
e ee dddZee edddZe	eeeef dddZeeee
e e
e eeee f dddZee	e
d dddZe	edef ddd Zee	eeeeeef f eeee eee eef d!d"d#Zd$d% Zd&d' Zd	S )*r   )parserc                 C   sT   t f}| jd|d}|d}|jdtddd |jdtdd	d
d |jtd dS )z
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        chat)dataclass_typeszPositional argumentsmodel_name_or_path_or_addressNz7Name of the pre-trained model or address to connect to.)typerc   rb   generate_flagsa  Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, and lists of integers, more advanced parameterization should be set through --generation-config. Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options*)r   rc   rb   nargs)func)ra   
add_parseradd_argument_groupadd_argumentr\   set_defaultsr   )r   r   Zchat_parsergroupr'   r'   r(   register_subcommand0  s"    
zChatCommand.register_subcommandc                 C   s   |j d ur|j }|ds.|ds.|drzd| _|jdksH|jdkrPtd|j dd\|_|_|jd u rtd	nd
| _|j |_t st	 s| jrt
dn$t st
dnt	 s| jrt
d|| _d S )Nhttphttpsrv   Frx   uu   Looks like you’ve set both a server address and a custom host/port. Please pick just one way to specify the server.:   z\When connecting to a server, please specify a model name with the --model_name_or_path flag.TzYou need to install rich to use the chat interface. Additionally, you have not specified a remote endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`)zHYou need to install rich to use the chat interface. (`pip install rich`)zYou have not specified a remote endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`))r   r:   spawn_backendrw   ry   
ValueErrorrsplitre   r   r   ImportErrorr~   )r&   r~   namer'   r'   r(   r)   Q  s4    


zChatCommand.__init__rE   c                   C   s(   t  dkrt S tt jS dS )z)Returns the username of the current user.r   N)platformsystemosgetloginpwdgetpwuidgetuidpw_namer'   r'   r'   r(   get_usernamex  s    zChatCommand.get_usernameN)r~   filenamer+   c                 C   s   i }t ||d< | |d< |j}|du rPtd}|j d| d}tj||}tjtj	|dd t
|d	 }tj||d
d W d   n1 s0    Y  tj|S )z!Saves the chat history to a file.settingsZchat_historyNz%Y-%m-%d_%H-%M-%Sz/chat_.jsonT)exist_okwr,   )indent)varsrg   timestrftimer   r   pathr;   makedirsdirnameopenjsondumpabspath)r   r~   r   Zoutput_dictfolderZtime_strfr'   r'   r(   	save_chat  s    
.zChatCommand.save_chat)rf   r+   c                 C   s   | du rg }nd| dg}|S )zClears the chat history.Nr   Zroler5   r'   )rf   r   r'   r'   r(   clear_chat_history  s    zChatCommand.clear_chat_history)r   r+   c                    s  t |dkri S dd |D }dd | D }dd | D }ttddd  fd	d| D }d
dd | D }d| d }|dd}|dd}|dd}|dd}|dd}|dd}zt|}W n tjy   t	dY n0 |S )zUParses the generate flags from the user input into a dictionary of `generate` kwargs.r   c                 S   s.   i | ]&}d | dd  d  | dd qS )"=r   r   )split).0flagr'   r'   r(   
<dictcomp>      z4ChatCommand.parse_generate_flags.<locals>.<dictcomp>c                 S   s*   i | ]"\}}||  d v r"|  n|qS ))truefalse)lowerr   kvr'   r'   r(   r     s   c                 S   s"   i | ]\}}||d krdn|qS )Nonenullr'   r   r'   r'   r(   r     r   )sr+   c                 S   s(   |  dr| dd  } | ddd S )N-r   .r/   )r:   replaceisdigit)r   r'   r'   r(   	is_number  s    
z3ChatCommand.parse_generate_flags.<locals>.is_numberc                    s*   i | ]"\}}| |s"d | d n|qS )r   r'   r   r   r'   r(   r     r   z, c                 S   s   g | ]\}}| d | qS )z: r'   r   r'   r'   r(   
<listcomp>  r   z4ChatCommand.parse_generate_flags.<locals>.<listcomp>{}z"null"r   z"true"r   z"false"r   z"[[z]"rM   r   r   zFailed to convert `generate_flags` into a valid JSON object.
`generate_flags` = {generate_flags}
Converted JSON string = {generate_flags_string})
lenitemsr\   r_   r;   r   r   loadsJSONDecodeErrorr   )r&   r   Zgenerate_flags_as_dictZgenerate_flags_stringZprocessed_generate_flagsr'   r   r(   parse_generate_flags  s0    
z ChatCommand.parse_generate_flags)r~   model_generation_configr+   c                 C   s   |j durLd|j v r>tj|j }tj|j }t||}qlt|j }n t|}|j	f i ddd | 
|j}|j	f i |}||fS )zj
        Returns a GenerationConfig object holding the generation parameters for the CLI command.
        Nr   T   )Z	do_sampleZmax_new_tokens)rV   r   r   r   basenamer   from_pretrainedcopydeepcopyr=   r   r   )r&   r~   r   r   r   rV   Zparsed_generate_flagsrW   r'   r'   r(   get_generation_parameterization  s    


z+ChatCommand.get_generation_parameterization)	tokenizerrV   
eos_tokenseos_token_idsr+   c                 C   s|   |j du r|j}n|j }g }|dur:|| |d |dur\|dd |dD  t|dkrt||j ||fS )z:Retrieves the pad token ID and all possible EOS token IDs.N,c                 S   s   g | ]}t |qS r'   )r^   )r   Ztoken_idr'   r'   r(   r     r   z0ChatCommand.parse_eos_tokens.<locals>.<listcomp>r   )pad_token_idZeos_token_idextendZconvert_tokens_to_idsr   r   r9   )r   rV   r   r   r   Zall_eos_token_idsr'   r'   r(   parse_eos_tokens  s    
zChatCommand.parse_eos_tokensr   )
model_argsr+   c                 C   s<   | j r"td| j| j| j| jd}n| jr4tdd}nd }|S )NT)rs   Zbnb_4bit_compute_dtypert   Zbnb_4bit_use_double_quantZbnb_4bit_quant_storage)rr   )rs   r   ro   rt   ru   rr   )r   quantization_configr'   r'   r(   get_quantization_config  s    z#ChatCommand.get_quantization_configr   )r~   r+   c                 C   s   t j|j|j|jd}|jdv r&|jn
tt|j}| |}|j|j	|d|d}t
j|jfd|ji|}t|dd d u r||j}||fS )N)revisionrp   )rl   Nrl   )r   rq   ro   Z
device_mapr   rp   Zhf_device_map)r   r   Zmodel_name_or_path_positionalrk   rp   ro   getattrtorchr   rq   r   torm   )r&   r~   r   ro   r   rW   modelr'   r'   r(   load_model_and_tokenizer   s.    
z$ChatCommand.load_model_and_tokenizer)
user_inputr~   	interfaceexamplesrV   rW   r   r+   c                 C   s  d}|dkr$|  |j}|  n|dkr8|  n|drt| dk r| }	t|	dkrp|	d }
nd}
| |||
}
|jd|
 d	d
d nT|dr|dd 	 }| }|D ](}d|vr|jd| ddd  qq| 
|}|jf i |}|jf i | n|drt| dkr| d }||v r|  g }||| d  |d|| d d n(d| dt|  d}|j|dd n@|dkr|j|j||d n"d}|jd| ddd |  ||||fS )z
        Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
        generation config (e.g. set a new flag).
        Tz!clearz!helpz!save   r   NzChat saved in !greenrK   z!setr,   r   z(Invalid flag format, missing `=` after `z;`. Please use the format `arg_1=value_1 arg_2=value_2 ...`.red!exampler   r$   r   zExample z* not found in list of available examples: r   z!statusrU   F'z/' is not a valid command. Showing help message.)r   rf   rH   rT   r:   r   r   r   rO   r<   r   r=   rJ   r9   listkeysrX   re   )r&   r   r~   r   r   rV   rW   r   valid_commandZsplit_inputr   Znew_generate_flagsr   Zparsed_new_generate_flagsZnew_model_kwargsZexample_nameZexample_errorr'   r'   r(   handle_non_exit_user_commands;  s\    




z)ChatCommand.handle_non_exit_user_commandsc                 C   s   t |   d S )N)asynciorun
_inner_runrI   r'   r'   r(   r     s    zChatCommand.runc                    s  | j rnt| jj| jj| jj| jj| jj| jj| jj	| jj
| jj| jjdd}t|}t|jd}d|_|  | jjd | jj }| jjdkrdn| jj}t| d| jj }| j}|jd u rt}n4t|j}	t|	}W d    n1 s0    Y  |jd u r|  }
n|j}
t|j}| ||\}}t|j|
d	}|   | !|j"}|j#dd
 zz|$ }|%dr|dkrW W |& I d H  qn | j'|||||||d\}}}}|r|%dsW W |& I d H  q`n|(d|d |j)|d|* |dd}|+|I d H }|(d|d W n* t,y\   Y W |& I d H  qY n0 W |& I d H  n|& I d H  0 q`d S )Nerror)rm   ro   rp   rq   rr   rs   rt   ru   rw   ry   Z	log_level)targetT@rv   zhttp://localhostr   r    rP   r   z!exit)r   r~   r   r   rV   rW   r   r   r$   r   )rV   r   )r*   Z
extra_bodyr#   )-r   r   r~   rm   ro   rp   rq   rr   rs   rt   ru   rw   ry   r   r   r   daemonstartre   rk   r	   rh   DEFAULT_EXAMPLESr   yamlZ	safe_loadr$   r   r   r   r   r   rH   r   rf   rT   rG   r:   closer   r9   Zchat_completionZto_json_stringrD   KeyboardInterrupt)r&   Z
serve_argsZserve_commandthreadr   rw   clientr~   r   r   r$   r   rV   rW   r   r   r   r   r*   Zmodel_outputr'   r'   r(   r     s    
(

!
	zChatCommand._inner_run)N)N)rY   rZ   r[   staticmethodr   r   r)   r\   r   ra   r   r   r   r`   r   r   r   r]   r   r   r^   r   r   r   r   r   r   r   r   r'   r'   r'   r(   r   /  sD    '
7
Tr   __main__z meta-llama/Llama-3.2-3b-Instructzhttp://localhost:8000)Cr   r   r   r   r   r6   stringr   argparser   r   collections.abcr   dataclassesr   r   	threadingr   typingr   r  Zhuggingface_hubr	   r
   Ztransformersr   r   r   Ztransformers.commandsr   Ztransformers.commands.servingr   r   Ztransformers.utilsr   r   r   r   Zrich.consoler   Z	rich.liver   Zrich.markdownr   r   r   r   setascii_letters
whitespaceZALLOWED_KEY_CHARSdigitsZALLOWED_VALUE_CHARSr  rR   r;   r   rS   r   ra   r   r   rY   r~   r   r   r   r'   r'   r'   r(   <module>   sv   	\Y   @
