Summary
(This article has been translated from Japanese)
This article shows how to talk to ChatRWKV through a chat UI displayed in a web browser.
ChatRWKV is an open-source program that aims to work like OpenAI's ChatGPT, and it runs quite comfortably on a local PC.
I wrapped it as a chat engine (server) and called it from the chat UI library chatux. Here is how it works.
Now you can chat with ChatRWKV like this!
The full source code is here:
https://github.com/riversun/chatux-server-rwkv
Target
- The code in this article assumes a GPU (CUDA) environment.
- As described later, smaller trained models also work well on older GPUs.
  - For example, rwkv-4-pile-3b (https://huggingface.co/BlinkDL/rwkv-4-pile-3b) with fp16i8 worked on a GeForce RTX 2060 (6GB memory).
Environment
We tested in the following environment:
- Python: 3.9
- GPU CUDA Version: 12
- GPU memory: 24GB
- OS: Ubuntu Desktop 22.04 / Windows 11
Main part
STEP0: Turning ChatRWKV into a server
Since chatux will be used as the front-end chat UI, ChatRWKV needs to run as a server.
To keep things as simple as possible, we modify chat.py in the v2 directory of the ChatRWKV repository below and turn it into a server.
https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
STEP1: Install related packages
Install the related packages (*):
pip install rwkv fastapi uvicorn
(*) This assumes that PyTorch for CUDA is already working. If you get runtime errors, install any missing packages as needed.
STEP2: Edit chat.py to make it a chat server
Let's modify the chat.py mentioned above into a server.
There are three points in making it a chat server.
1. Convert it to a Web API
- Use the fastapi package to turn it into a Web API server.
- Use the uvicorn package to start the Web API server built with fastapi.
- Accept input on the path /chat_api with the @app.get("/chat_api") annotation.
- Pass the user's input text to the ChatRWKV RNN model with reply = handle_message(text).replace('\n', '<br>'). Newlines in the reply are converted to <br> so that they show up as line breaks in the chat UI.
- Return the result from the RNN model packed in outJson. That's all :)
The code looks like this:
@app.get("/chat_api")
async def chat(text: str = ""):
    reply = handle_message(text).replace('\n', '<br>')
    print(f'input:{text} reply:{reply}')
    outJson = {
        "output": [
            {
                "type": "text",
                "value": reply
            }
        ]
    }
    return outJson

app.mount("/", StaticFiles(directory="html", html=True), name="html")
2. Handling the RNN model's generation results
- The command-line chat.py prints the generated response to the console token by token. The web chat version (chatux, as it currently stands) does not support sequential output, so for now the server responds only after the model has finished generating the whole sentence.
- The part returned as the response is send_msg = pipeline.decode(model_tokens[begin:]).strip().
- Returning the reply from the RNN model piece by piece would make the results easier to read, so an architecture that supports this is homework for the chatux side.
3. Processing + commands
- Commands such as +gen, +i and +qa are passed through, but the model's response is returned unprocessed, so the output may look odd when you use them. Free generation also produces long output, so you may have to wait.
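Point 2 deserves a closer look: why decode model_tokens[begin:] in one go? Like other byte-level BPE tokenizers, the 20B tokenizer can split one multi-byte UTF-8 character (Japanese text, for example) across two tokens. Decoding a single token then yields the replacement character U+FFFD, and the '\ufffd' checks in the server code guard against exactly this. The sketch below imitates the effect with hand-made byte chunks standing in for tokens. The chunk boundaries are an assumption for illustration; the sketch does not use the real tokenizer.

```python
# Hypothetical "tokens": byte chunks that cut multi-byte UTF-8 characters apart.
text = "日本語 OK"
data = text.encode("utf-8")  # each of 日, 本, 語 is 3 bytes in UTF-8
chunks = [data[0:2], data[2:4], data[4:9], data[9:]]


def decode(chunk_list):
    """Join byte chunks and decode, replacing invalid sequences with U+FFFD."""
    return b"".join(chunk_list).decode("utf-8", errors="replace")


# Decoding each chunk on its own (like printing every token immediately) garbles text.
per_chunk = "".join(decode([c]) for c in chunks)

# Decoding the accumulated slice at the end (like model_tokens[begin:]) is clean.
all_at_once = decode(chunks)

# Incremental variant mirroring the server's out_last logic:
# emit the pending slice only once it decodes without U+FFFD.
emitted, out_last = [], 0
for i in range(len(chunks)):
    pending = decode(chunks[out_last:i + 1])
    if "\ufffd" not in pending:
        emitted.append(pending)
        out_last = i + 1

print(per_chunk)
print(all_at_once)
print("".join(emitted))
```

Both the final decode and the out_last variant produce clean text. Only the naive per-token decode shows replacement characters.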
Below is the full source code of the chat server (chatux-server-rwkv.py):
########################################################################################################
# This program is based on The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
#
# The original file of this source code can be found here.
# https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
#
# Based on the source code above, I have added the following functionality required for a Web API server for chatbots
# - web server functionality
# - use the query received in a GET request as input to the RNN, and return the RNN's generated result as the HTTP response
#
########################################################################################################
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn

import os, copy, types, gc, sys

current_path = os.path.dirname(os.path.abspath(__file__))
import numpy as np

args = types.SimpleNamespace()

# specify chat server
HOST = 'localhost'
PORT = 8001
URL = f'http://{HOST}:{PORT}'

# specify RWKV strategy,model(weight data)
STRATEGY = 'cuda fp16i8'
MODEL_NAME = 'RWKV-4-Pile-14B-20230313-ctx8192-test1050.pth'

# specify params for weight data (values for the 14B model, see STEP3)
args.n_layer = 40
args.n_embd = 5120
args.ctx_len = 8192

CHAT_LANG = "English"
PROMPT_FILE = f'{current_path}/init_prompt/English-1.py'

# optional first command-line argument: GPU index to use
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except IndexError:
    pass
np.set_printoptions(precision=4, suppress=True, linewidth=200)

print('\n\nChatRWKV v2 https://github.com/BlinkDL/ChatRWKV')

import torch

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

args.strategy = STRATEGY
os.environ["RWKV_JIT_ON"] = '1'  # '1' or '0', please use torch 1.13+ and benchmark speed
os.environ["RWKV_CUDA_ON"] = '0'  # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries

args.MODEL_NAME = f'{current_path}/data/{MODEL_NAME}'

CHAT_LEN_SHORT = 40
CHAT_LEN_LONG = 150
FREE_GEN_LEN = 200
GEN_TEMP = 1.0  # sometimes it's a good idea to increase temp. try it
GEN_TOP_P = 0.8
GEN_alpha_presence = 0.2  # Presence Penalty
GEN_alpha_frequency = 0.2  # Frequency Penalty
AVOID_REPEAT = ',:?!'

CHUNK_LEN = 256  # split input into chunks to save VRAM (shorter -> slower)
PILE_v2_MODEL = False  # ONLY FOR MY OWN TESTING. STILL TRAINING PILE_v2_MODELs

all_state = {}

print(f'\n{CHAT_LANG} - {args.strategy} - {PROMPT_FILE}')
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

# the prompt file is Python code that defines user, bot, interface and init_prompt
with open(PROMPT_FILE, 'rb') as file:
    user = None
    bot = None
    interface = None
    init_prompt = None
    exec(compile(file.read(), PROMPT_FILE, 'exec'))
init_prompt = init_prompt.strip().split('\n')
for c in range(len(init_prompt)):
    init_prompt[c] = init_prompt[c].strip().strip('\u3000').strip('\r')
init_prompt = '\n' + ('\n'.join(init_prompt)).strip() + '\n\n'

print(f'Loading model - {args.MODEL_NAME}')
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
if not PILE_v2_MODEL:
    pipeline = PIPELINE(model, f"{current_path}/20B_tokenizer.json")
    END_OF_TEXT = 0
    END_OF_LINE = 187
else:
    pipeline = PIPELINE(model, "cl100k_base")
    END_OF_TEXT = 100257
    END_OF_LINE = 198

model_tokens = []
model_state = None

AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
    dd = pipeline.encode(i)
    assert len(dd) == 1
    AVOID_REPEAT_TOKENS += dd


def run_rnn(tokens, newline_adj=0):
    """Feed tokens to the model in CHUNK_LEN pieces and return the logits for the next token."""
    global model_tokens, model_state

    tokens = [int(x) for x in tokens]
    model_tokens += tokens

    while len(tokens) > 0:
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)
        tokens = tokens[CHUNK_LEN:]

    out[END_OF_LINE] += newline_adj  # adjust the probability of '\n'

    if model_tokens[-1] in AVOID_REPEAT_TOKENS:
        out[model_tokens[-1]] = -999999999
    return out


def save_all_stat(srv, name, last_out):
    """Snapshot the logits, RNN state and token history under the key f'{name}_{srv}'."""
    n = f'{name}_{srv}'
    all_state[n] = {}
    all_state[n]['out'] = last_out
    all_state[n]['rnn'] = copy.deepcopy(model_state)
    all_state[n]['token'] = copy.deepcopy(model_tokens)


def load_all_stat(srv, name):
    """Restore a snapshot saved by save_all_stat (raises KeyError if it does not exist)."""
    global model_tokens, model_state
    n = f'{name}_{srv}'
    model_state = copy.deepcopy(all_state[n]['rnn'])
    model_tokens = copy.deepcopy(all_state[n]['token'])
    return all_state[n]['out']


# run the initial prompt once and keep the resulting state
out = run_rnn(pipeline.encode(init_prompt))
save_all_stat('', 'chat_init', out)
gc.collect()
torch.cuda.empty_cache()

srv_list = ['dummy_server']
for s in srv_list:
    save_all_stat(s, 'chat', out)


def reply_msg(msg):
    print(f"bot's response: {bot}{interface} {msg}\n")


def handle_message(message):
    global model_tokens, model_state

    srv = 'dummy_server'

    msg = message.replace('\\n', '\n').strip()

    # optional inline sampling parameters, e.g. "-temp=1.2 -top_p=0.5"
    x_temp = GEN_TEMP
    x_top_p = GEN_TOP_P
    if ("-temp=" in msg):
        x_temp = float(msg.split("-temp=")[1].split(" ")[0])
        msg = msg.replace("-temp=" + f'{x_temp:g}', "")
    if ("-top_p=" in msg):
        x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
        msg = msg.replace("-top_p=" + f'{x_top_p:g}', "")
        # print(f"top_p: {x_top_p}")
    if x_temp <= 0.2:
        x_temp = 0.2
    if x_temp >= 5:
        x_temp = 5
    if x_top_p <= 0:
        x_top_p = 0

    if msg == '+reset':
        out = load_all_stat('', 'chat_init')
        save_all_stat(srv, 'chat', out)
        reply_msg("Chat reset.")
        return "Chat Reset"

    elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' \
            or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':

        if msg[:5].lower() == '+gen ':
            new = '\n' + msg[5:].strip()
            model_state = None
            model_tokens = []
            out = run_rnn(pipeline.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg[:3].lower() == '+i ':
            new = f'''
Below is an instruction that describes a task. Write a response that appropriately completes the request.

# Instruction:
{msg[3:].strip()}

# Response:
'''
            model_state = None
            model_tokens = []
            out = run_rnn(pipeline.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg[:4].lower() == '+qq ':
            new = '\nQ: ' + msg[4:].strip() + '\nA:'
            model_state = None
            model_tokens = []
            out = run_rnn(pipeline.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg[:4].lower() == '+qa ':
            out = load_all_stat('', 'chat_init')
            real_msg = msg[4:].strip()
            new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
            out = run_rnn(pipeline.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg.lower() == '+++':
            try:
                out = load_all_stat(srv, 'gen_1')
                save_all_stat(srv, 'gen_0', out)
            except KeyError:
                return ''  # nothing to continue yet

        elif msg.lower() == '++':
            try:
                out = load_all_stat(srv, 'gen_0')
            except KeyError:
                return ''  # nothing to retry yet

        begin = len(model_tokens)
        out_last = begin
        occurrence = {}
        for i in range(FREE_GEN_LEN + 100):
            for n in occurrence:
                out[n] -= (GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency)
            token = pipeline.sample_logits(
                out,
                temperature=x_temp,
                top_p=x_top_p,
            )
            if token == END_OF_TEXT:
                break
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            if msg[:4].lower() == '+qa ':  # or msg[:4].lower() == '+qq ':
                out = run_rnn([token], newline_adj=-2)
            else:
                out = run_rnn([token])

            xxx = pipeline.decode(model_tokens[out_last:])
            if '\ufffd' not in xxx:  # avoid utf-8 display issues
                print(xxx, end='', flush=True)
                out_last = begin + i + 1
                if i >= FREE_GEN_LEN:
                    break
        print('\n')
        send_msg = pipeline.decode(model_tokens[begin:]).strip()
        # print(f'### send ###\n[{send_msg}]')
        save_all_stat(srv, 'gen_1', out)
        return send_msg

    else:
        if msg.lower() == '+':
            try:
                out = load_all_stat(srv, 'chat_pre')
            except KeyError:
                return ''  # no previous message to regenerate
        else:
            out = load_all_stat(srv, 'chat')
            new = f"{user}{interface} {msg}\n\n{bot}{interface}"
            # print(f'### add ###\n[{new}]')
            out = run_rnn(pipeline.encode(new), newline_adj=-999999999)
            save_all_stat(srv, 'chat_pre', out)

        begin = len(model_tokens)
        out_last = begin
        occurrence = {}
        for i in range(999):
            if i <= 0:
                newline_adj = -999999999
            elif i <= CHAT_LEN_SHORT:
                newline_adj = (i - CHAT_LEN_SHORT) / 10
            elif i <= CHAT_LEN_LONG:
                newline_adj = 0
            else:
                newline_adj = min(2, (i - CHAT_LEN_LONG) * 0.25)  # MUST END THE GENERATION

            for n in occurrence:
                out[n] -= (GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency)
            token = pipeline.sample_logits(
                out,
                temperature=x_temp,
                top_p=x_top_p,
            )
            # if token == END_OF_TEXT:
            #     break
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            out = run_rnn([token], newline_adj=newline_adj)
            out[END_OF_TEXT] = -999999999  # disable <|endoftext|>

            xxx = pipeline.decode(model_tokens[out_last:])
            if '\ufffd' not in xxx:  # avoid utf-8 display issues
                out_last = begin + i + 1

            send_msg = pipeline.decode(model_tokens[begin:])
            if '\n\n' in send_msg:
                send_msg = send_msg.strip()
                # print(f'send_msg={send_msg}')
                break

        save_all_stat(srv, 'chat', out)
        return send_msg


app = FastAPI()


@app.get("/chat_api")
async def chat(text: str = ""):
    # handle_message() blocks, so requests are effectively processed one at a time
    reply = handle_message(text).replace('\n', '<br>')
    print(f'input:{text} reply:{reply}')
    outJson = {
        "output": [
            {
                "type": "text",
                "value": reply
            }
        ]
    }
    return outJson


# serve the chatux page (html/index.html) at "/"
app.mount("/", StaticFiles(directory="html", html=True), name="html")


def start_server():
    uvicorn.run(app, host=HOST, port=PORT)


def main():
    start_server()


if __name__ == "__main__":
    main()
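Two parts of the generation loop in handle_message are easier to follow in isolation. The first is the presence/frequency penalty, which lowers the logits of tokens that have already been generated. The second is the newline_adj schedule that run_rnn adds to the logit of '\n':

- At step 0 a newline is effectively forbidden.
- Up to CHAT_LEN_SHORT steps it is discouraged.
- It is neutral up to CHAT_LEN_LONG.
- After that it is increasingly favored, so the reply ends.

The sketch below reimplements both with plain Python lists and dicts instead of torch tensors. The constants are copied from the server code, and the function names are hypothetical.

```python
from collections import Counter

# Constants copied from chatux-server-rwkv.py
GEN_alpha_presence = 0.2   # flat penalty once a token has appeared at all
GEN_alpha_frequency = 0.2  # extra penalty per previous occurrence
CHAT_LEN_SHORT = 40
CHAT_LEN_LONG = 150


def apply_penalties(logits, occurrence):
    """Return a penalized copy of logits, mirroring
    out[n] -= (GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency)."""
    penalized = list(logits)
    for token, count in occurrence.items():
        penalized[token] -= GEN_alpha_presence + count * GEN_alpha_frequency
    return penalized


def newline_adjustment(i):
    """Bias added to the newline logit at chat generation step i."""
    if i <= 0:
        return -999999999                          # never start a reply with '\n'
    elif i <= CHAT_LEN_SHORT:
        return (i - CHAT_LEN_SHORT) / 10           # discourage very short replies
    elif i <= CHAT_LEN_LONG:
        return 0                                   # neutral
    else:
        return min(2, (i - CHAT_LEN_LONG) * 0.25)  # push the reply to end


logits = [1.0, 1.0, 1.0, 1.0]
occurrence = Counter([2, 2, 3])  # token 2 generated twice, token 3 once
print(apply_penalties(logits, occurrence))
print([newline_adjustment(i) for i in (0, 10, 40, 100, 154, 200)])
```

A token seen twice loses 0.2 + 2 × 0.2 = 0.6 from its logit, while a token seen once loses 0.4. This is why the alpha values control repetition.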
STEP3: Obtain trained model data (weights)
The code is complete, but we don't have the trained weights yet, so we download them from the ChatRWKV author's page on Hugging Face.
There are several sizes of model data: 3b, 7b, 14b and so on. The number is the parameter count in billions, and the more parameters, the more expressive the model.
However, more parameters also consume more GPU memory, so download the data that fits your execution environment.
Below is a table of typical model data and their GPU memory consumption:
strategy | rwkv-4-pile-14b | rwkv-4-pile-7b | rwkv-4-pile-3b |
fp16 | 28GB | 16GB | 6GB |
fp16i8 | 14GB | 8.6GB | 3GB |
By the way, fp16i8 (which appears to mean quantizing the fp16 weights to int8) reduces the GPU memory used, although accuracy may be slightly lower.
(In the code, specify it as cuda fp16 or cuda fp16i8.)
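Combining the table with the fp16/fp16i8 choice, picking a model becomes a lookup. The 14b figures roughly follow weight count times bytes per weight: 14 billion × 2 bytes is about 28GB for fp16, and about 1 byte per weight for int8. The sketch below takes the numbers from the table above and picks the largest entry that fits, preferring fp16 over fp16i8 for the same model. The 1GB safety margin and the helper itself are assumptions, not part of ChatRWKV, and real memory usage can differ from the table.

```python
# GPU memory figures (GB) from the table above, keyed by (model, STRATEGY).
MEMORY_GB = {
    ("rwkv-4-pile-14b", "cuda fp16"): 28,
    ("rwkv-4-pile-14b", "cuda fp16i8"): 14,
    ("rwkv-4-pile-7b", "cuda fp16"): 16,
    ("rwkv-4-pile-7b", "cuda fp16i8"): 8.6,
    ("rwkv-4-pile-3b", "cuda fp16"): 6,
    ("rwkv-4-pile-3b", "cuda fp16i8"): 3,
}

# Preference order: larger model first, then fp16 before int8-quantized fp16i8.
PREFERENCE = [
    ("rwkv-4-pile-14b", "cuda fp16"),
    ("rwkv-4-pile-14b", "cuda fp16i8"),
    ("rwkv-4-pile-7b", "cuda fp16"),
    ("rwkv-4-pile-7b", "cuda fp16i8"),
    ("rwkv-4-pile-3b", "cuda fp16"),
    ("rwkv-4-pile-3b", "cuda fp16i8"),
]


def pick_model(gpu_gb, headroom_gb=1.0):
    """Hypothetical helper: return the first (model, STRATEGY) whose table
    figure plus an assumed safety margin fits in gpu_gb, else None."""
    for key in PREFERENCE:
        if MEMORY_GB[key] + headroom_gb <= gpu_gb:
            return key
    return None


for gb in (24, 16, 6, 2):
    print(gb, pick_model(gb))
```

With the margin, a 6GB card lands on rwkv-4-pile-3b with fp16i8. This matches the RTX 2060 result mentioned at the top of the article.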
Download the weight data (*.pth) file from one of the following pages.
Get BlinkDL/rwkv-4-pile-14b
- Download
  - https://huggingface.co/BlinkDL/rwkv-4-pile-14b
  - The file RWKV-4-Pile-14B-2023xxxx-ctx8192-testxxx.pth is the best general model.
- GPU memory consumption
  - About 28GB (fp16): out of memory even on a 24GB-class GPU
  - About 14GB (fp16i8): accuracy is said to be slightly degraded, but it works on a T4
- Parameters
args.n_layer = 40
args.n_embd = 5120
args.ctx_len = 8192
Get BlinkDL/rwkv-4-pile-7b
- Download
  - https://huggingface.co/BlinkDL/rwkv-4-pile-7b
  - RWKV-4-Pile-7B-20230109-ctx4096.pth
- GPU memory consumption
  - About 16GB (fp16)
  - About 8.6GB (fp16i8)
- Parameters
args.n_layer = 32
args.n_embd = 4096
args.ctx_len = 4096
Get BlinkDL/rwkv-4-pile-3b
- Download
  - https://huggingface.co/BlinkDL/rwkv-4-pile-3b
  - RWKV-4-Pile-3B-20221110-ctx4096.pth
- GPU memory consumption
  - About 6GB (fp16)
  - About 3GB (fp16i8)
- Parameters
args.n_layer = 32
args.n_embd = 2560
args.ctx_len = 4096
Place the downloaded file in the [project]/data folder.
If you are unsure about the layout, refer to the folder structure at the bottom of this page:
https://github.com/riversun/chatux-server-rwkv
STEP4: Set the trained model data (weights) in the code
Open the chatux-server-rwkv.py file you just created.
- In the # specify RWKV strategy,model(weight data) area, set STRATEGY and MODEL_NAME as shown below. MODEL_NAME is just the file name.
- Around # specify params for weight data, enter the parameters for the trained model data. (You can copy and paste the ones shown in STEP3.)
# specify RWKV strategy,model(weight data)
STRATEGY = 'cuda fp16i8'
MODEL_NAME = 'RWKV-4-Pile-7B-20230109-ctx4096.pth'
# specify params for weight data
args.n_layer = 32
args.n_embd = 4096
args.ctx_len = 4096
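The three parameter lines must match the model file, so one option is to keep the STEP3 values in a dictionary and derive them from MODEL_NAME. This is a hypothetical convenience, not part of the original script. It assumes the file names follow the RWKV-4-Pile-<size>B-... pattern shown in STEP3.

```python
import re
import types

# Parameter sets listed in STEP3, keyed by model size.
MODEL_PARAMS = {
    "14B": {"n_layer": 40, "n_embd": 5120, "ctx_len": 8192},
    "7B": {"n_layer": 32, "n_embd": 4096, "ctx_len": 4096},
    "3B": {"n_layer": 32, "n_embd": 2560, "ctx_len": 4096},
}


def params_for(model_name):
    """Hypothetical helper: infer the parameter set from a file name such as
    'RWKV-4-Pile-7B-20230109-ctx4096.pth'."""
    match = re.match(r"RWKV-4-Pile-(\d+B)-", model_name)
    if match is None or match.group(1) not in MODEL_PARAMS:
        raise ValueError(f"unknown model file: {model_name}")
    return MODEL_PARAMS[match.group(1)]


MODEL_NAME = 'RWKV-4-Pile-7B-20230109-ctx4096.pth'
args = types.SimpleNamespace()      # same container type the server uses
vars(args).update(params_for(MODEL_NAME))
print(args)
```

Raising an error for unknown names means a typo in MODEL_NAME fails loudly rather than silently keeping stale parameters.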
STEP5: Place other files
5-1 Create the chat UI (HTML)
Prepare index.html as follows and save it as [project]/html/index.html.
The page loads chatux from chatux.min.js.
On a PC, the chat UI appears in a small window at the right edge of the screen, and on a smartphone it fills the whole screen. Either way, you can start chatting.
index.html
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1">
  <title>chatRWKV</title>
</head>
<body>
<script src="https://riversun.github.io/chatux/chatux.min.js"></script>
<script>
  const chatux = new ChatUx();

  // initializing param for chatux
  const initParam =
    {
      renderMode: 'auto',
      api: {
        // ChatRWKV chat server (chatux-server-rwkv.py)
        endpoint: '/chat_api',
        method: 'GET',
        dataType: 'json'
      },
      bot: {
        botPhoto: 'https://riversun.github.io/chatbot/bot_icon_operator.png',
        humanPhoto: null,
        widget: {
          sendLabel: 'SEND',
          placeHolder: 'Say something'
        }
      },
      window: {
        title: 'chatRWKV',
        infoUrl: 'https://github.com/riversun/chatux'
      }
    };
  chatux.init(initParam);
  chatux.start(true);
</script>
</body>
</html>
The following server settings are specified to match the chat server we just created:
api: {
  // ChatRWKV chat server (chatux-server-rwkv.py)
  endpoint: '/chat_api',
  method: 'GET',
  dataType: 'json'
},
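To check the HTML side before the GPU server is ready, any server that answers GET /chat_api?text=... with the same outJson shape will do. The sketch below is a hypothetical echo server built on Python's standard library (http.server) instead of FastAPI. It echoes the input back and does not call ChatRWKV. Like app.mount("/") in the real server, it serves static files from the html folder for every other path. Here it binds to an OS-assigned port so the demo can run anywhere; for real use, pass port=8001.

```python
import json
import threading
import urllib.parse
import urllib.request
from functools import partial
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer


class EchoChatHandler(SimpleHTTPRequestHandler):
    """Serves static files plus a fake /chat_api with the same JSON shape."""

    def do_GET(self):
        url = urllib.parse.urlparse(self.path)
        if url.path != "/chat_api":
            return super().do_GET()  # static files such as html/index.html
        text = urllib.parse.parse_qs(url.query).get("text", [""])[0]
        reply = f"echo: {text}".replace("\n", "<br>")
        body = json.dumps({"output": [{"type": "text", "value": reply}]}).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        pass  # keep the console quiet


def start_mock(host="127.0.0.1", port=0, directory="html"):
    """Start the mock in a background thread; port=0 lets the OS choose."""
    handler = partial(EchoChatHandler, directory=directory)
    server = ThreadingHTTPServer((host, port), handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


server = start_mock()
port = server.server_address[1]
query = urllib.parse.urlencode({"text": "hello"})
with urllib.request.urlopen(f"http://127.0.0.1:{port}/chat_api?{query}") as resp:
    data = json.load(resp)
print(data)
server.shutdown()
server.server_close()
```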
5-2 Place the initial prompt
Create a folder [project]/init_prompt and place the initial prompt (the prompt that sets up the initial conversation context) there.
The initial prompt is specified in the source code as PROMPT_FILE = f'{current_path}/init_prompt/English-1.py'.
Here, we use https://github.com/BlinkDL/ChatRWKV/blob/main/v2/prompt/default/English-1.py.
If you want the bot to converse in Japanese, you can use a Japanese initial prompt instead:
https://github.com/riversun/chatux-server-rwkv/blob/main/init_prompt/Japanese-2.py
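The server loads the prompt file by executing it with exec() and then reads four names it expects the file to define: user, bot, interface and init_prompt. The sketch below shows a tiny hypothetical prompt file (not the contents of English-1.py) and the same normalization the server applies. Each line is stripped, then the text is wrapped in newlines. It executes into an explicit namespace dict rather than module globals, which is a small deviation from the server code. Because the file runs as Python code, only use prompt files you trust.

```python
# A hypothetical, minimal prompt file. A real one defines the same four names
# but with a much longer example conversation.
PROMPT_SOURCE = '''
interface = ":"
user = "User"
bot = "Bot"
init_prompt = f"""
The following is a conversation between {user} and {bot}.

{user}{interface} Hello!

{bot}{interface} Hi! How can I help you?
"""
'''


def load_prompt(source, filename="<prompt>"):
    """Execute a prompt file and normalize init_prompt like the chat server."""
    ns = {"user": None, "bot": None, "interface": None, "init_prompt": None}
    exec(compile(source, filename, "exec"), ns)  # runs arbitrary code: trusted files only
    lines = ns["init_prompt"].strip().split("\n")
    lines = [line.strip().strip("\u3000").strip("\r") for line in lines]
    ns["init_prompt"] = "\n" + "\n".join(lines).strip() + "\n\n"
    return ns["user"], ns["bot"], ns["interface"], ns["init_prompt"]


user, bot, interface, init_prompt = load_prompt(PROMPT_SOURCE)
print(repr(init_prompt))
# Each chat turn is then appended in the same format the server uses:
print(repr(f"{user}{interface} What is RWKV?\n\n{bot}{interface}"))
```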
5-3 Place 20B_tokenizer.json
Place https://github.com/BlinkDL/ChatRWKV/blob/main/v2/20B_tokenizer.json as [project]/20B_tokenizer.json.
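Before starting the server, it can help to check that every file from STEP3 to STEP5 is in place, since a missing file only shows up as an error at startup. The helper below is hypothetical and not part of the repository. It assumes the 7B file name from STEP4 and the default English-1.py prompt, so adjust REQUIRED to your setup. The demo runs against a temporary folder; for real use, pass your [project] folder to missing_files.

```python
import os
import tempfile

# Files the server expects, relative to the project folder (STEP3 to STEP5).
# MODEL_NAME must match the value set in chatux-server-rwkv.py.
MODEL_NAME = 'RWKV-4-Pile-7B-20230109-ctx4096.pth'
REQUIRED = [
    os.path.join("data", MODEL_NAME),
    "20B_tokenizer.json",
    os.path.join("init_prompt", "English-1.py"),
    os.path.join("html", "index.html"),
]


def missing_files(project_dir, required=REQUIRED):
    """Return the required paths that do not exist under project_dir."""
    return [p for p in required if not os.path.isfile(os.path.join(project_dir, p))]


# Demo: an empty folder misses everything; after creating the files, nothing is missing.
with tempfile.TemporaryDirectory() as project:
    before = missing_files(project)
    for rel in REQUIRED:
        path = os.path.join(project, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "w").close()
    after = missing_files(project)

print(before)
print(after)
```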
STEP6: Start the chat server
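Start the chat server with the following command, run from the [project] folder (the html folder is looked up relative to the current directory):
python chatux-server-rwkv.py
(Optionally, pass a GPU index as the first argument, for example python chatux-server-rwkv.py 0. The script sets it as CUDA_VISIBLE_DEVICES.)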
STEP7: Execution
Once the server is started, open http://localhost:8001 in your browser.
Now you can talk to ChatRWKV from the web chat UI (chatux)!
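Besides the browser, the endpoint can also be called from a script, which is handy for quick tests. The sketch below uses only the standard library (urllib). It assumes the server is running with the default HOST and PORT from the code, and it reverses the '\n' to '<br>' conversion done by the chat() endpoint. The helper names are hypothetical.

```python
import json
import urllib.parse
import urllib.request

BASE_URL = "http://localhost:8001"  # HOST and PORT from chatux-server-rwkv.py


def chat_url(text, base_url=BASE_URL):
    """Build the GET URL that chatux sends: /chat_api?text=..."""
    return f"{base_url}/chat_api?" + urllib.parse.urlencode({"text": text})


def extract_reply(payload):
    """Take the reply text out of outJson and turn <br> back into newlines."""
    return payload["output"][0]["value"].replace("<br>", "\n")


def ask(text, base_url=BASE_URL, timeout=300):
    """Send one message to a running server and return the reply text.
    The long timeout allows for slow generation."""
    with urllib.request.urlopen(chat_url(text, base_url), timeout=timeout) as resp:
        return extract_reply(json.load(resp))


print(chat_url("Who was the lead actor on Titanic?"))
print(extract_reply({"output": [{"type": "text", "value": "Line 1<br>Line 2"}]}))
# With the server running: print(ask("The tallest mountain in the world"))
```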
Demo video
https://www.youtube.com/embed/t3vuNmIYXBo
ChatRWKV does a good job.
As you can see in the video, I asked it the following questions and it responded nicely:
"The tallest mountain in the world"
"Who was the lead actor in Titanic?"
"Other films made by the director of Titanic"
Summary
- I explained how to make ChatRWKV conversational through a web chat UI.
  - FastAPI is used to expose it as a Web API.
  - The tokens that the RNN model generates one after another are collected into a single message before responding.
- To do
  - chatux should support sequential (streaming) output, as ChatGPT and ChatRWKV do.
  - Better responses for free generation.
  - Refactor the engine into a class.
  - Persist the in-memory conversation context (the state passed to the RNN).
The tested sample code is in the repository linked at the beginning of this article.
Thank you for reading to the end!