This commit is contained in:
abc 2023-10-02 18:07:20 +01:00
parent da50e84dfc
commit 23646b6680
1 changed files with 86 additions and 83 deletions

View File

@ -1,104 +1,106 @@
import json import json
import time
import random import random
import string import string
import time
from typing import Any
import requests import requests
from flask import Flask, request
from flask_cors import CORS from typing import Any
from flask import Flask, request
from flask_cors import CORS
from transformers import AutoTokenizer from transformers import AutoTokenizer
from g4f import ChatCompletion from g4f import ChatCompletion
# WSGI application object shared by every route in this module.
app = Flask(__name__)

# Enable Cross-Origin Resource Sharing on the application.
CORS(app)
def _completion_chunk(completion_id, created, model, delta, finish_reason):
    """Build one ``chat.completion.chunk`` payload for the event stream."""
    return {
        'id': f'chatcmpl-{completion_id}',
        'object': 'chat.completion.chunk',
        'created': created,
        'model': model,
        'choices': [
            {
                'index': 0,
                'delta': delta,
                'finish_reason': finish_reason,
            }
        ],
    }


@app.route('/chat/completions', methods=['POST'])
def chat_completions():
    """Handle ``POST /chat/completions``.

    Reads ``model`` (default ``'gpt-3.5-turbo'``), ``stream`` (default
    ``False``) and ``messages`` from the JSON request body and forwards them
    to ``ChatCompletion.create``.

    Returns:
        * a ``chat.completion`` dict when ``stream`` is falsy;
        * a ``text/event-stream`` response of ``chat.completion.chunk``
          events, ending with a ``finish_reason == 'stop'`` chunk, otherwise;
        * an error dict with status 400 when the body is not a JSON object or
          carries no ``messages``.
    """
    # Parse the body once instead of once per field.
    payload = request.get_json()
    if not isinstance(payload, dict):
        return {
            'error': {
                'message': 'Request body must be a JSON object.',
                'type': 'invalid_request_error',
            }
        }, 400

    model = payload.get('model', 'gpt-3.5-turbo')
    stream = payload.get('stream', False)
    messages = payload.get('messages')
    if not messages:
        # Fail early with a client error rather than passing None downstream.
        return {
            'error': {
                'message': "'messages' is required and must not be empty.",
                'type': 'invalid_request_error',
            }
        }, 400

    response = ChatCompletion.create(
        model=model, stream=stream, messages=messages)

    completion_id = ''.join(
        random.choices(string.ascii_letters + string.digits, k=28))
    completion_timestamp = int(time.time())

    if not stream:
        return {
            'id': f'chatcmpl-{completion_id}',
            'object': 'chat.completion',
            'created': completion_timestamp,
            'model': model,
            'choices': [
                {
                    'index': 0,
                    'message': {
                        'role': 'assistant',
                        'content': response,
                    },
                    'finish_reason': 'stop',
                }
            ],
            # Token accounting is not computed here.
            'usage': {
                'prompt_tokens': None,
                'completion_tokens': None,
                'total_tokens': None,
            },
        }

    def streaming():
        """Yield one server-sent event per chunk, then a final stop event."""
        for chunk in response:
            completion_data = _completion_chunk(
                completion_id, completion_timestamp, model,
                {'content': chunk}, None)
            content = json.dumps(completion_data, separators=(',', ':'))
            yield f'data: {content}\n\n'
            # Pause kept from the original; presumably paces the stream.
            time.sleep(0.1)

        end_completion_data = _completion_chunk(
            completion_id, completion_timestamp, model, {}, 'stop')
        content = json.dumps(end_completion_data, separators=(',', ':'))
        yield f'data: {content}\n\n'

    return app.response_class(streaming(), mimetype='text/event-stream')
# Get the embedding from huggingface
def get_embedding(input_text, token):
    """Return a single embedding vector for ``input_text``.

    The text is tokenized with the ``all-mpnet-base-v2`` tokenizer, split
    into chunks of at most 500 tokens, and each chunk is sent to the Hugging
    Face inference API. The per-chunk embeddings are then averaged
    component-wise.

    Args:
        input_text: Text to embed.
        token: Hugging Face API token, sent as a ``Bearer`` credential.

    Returns:
        The averaged embedding as a list, or an empty list when the text
        produces no tokens.

    Raises:
        requests.HTTPError: If the inference API answers with an error status.
    """
    embedding_model = 'sentence-transformers/all-mpnet-base-v2'
    max_token_length = 500

    # Load the tokenizer for the 'all-mpnet-base-v2' model
    tokenizer = AutoTokenizer.from_pretrained(embedding_model)

    # Tokenize the text and split the tokens into chunks of 500 tokens each
    tokens = tokenizer.tokenize(input_text)
    token_chunks = [tokens[i:i + max_token_length]
                    for i in range(0, len(tokens), max_token_length)]

    # The endpoint and credentials are the same for every chunk.
    api_url = f'https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}'
    headers = {'Authorization': f'Bearer {token}'}

    embeddings = []

    # NOTE(review): the loop header is not visible in the reviewed diff (it
    # falls between two hunks); reconstructed from the uses of `chunk` and
    # `token_chunks` - confirm against the full file.
    for chunk in token_chunks:
        chunk_text = tokenizer.convert_tokens_to_string(chunk)
        chunk_text = chunk_text.replace('\n', ' ')

        # Make a POST request to get the chunk's embedding
        response = requests.post(api_url, headers=headers, json={
            'inputs': chunk_text, 'options': {'wait_for_model': True}})
        # Stop on an API error instead of averaging an error payload.
        response.raise_for_status()

        # assumes the JSON body is a flat list of numbers - TODO confirm
        embeddings.append(response.json())

    if not embeddings:
        return []

    # TODO: plain averaging of the chunk embeddings is not very effective;
    # a better way of combining them is welcome.
    num_embeddings = len(embeddings)
    return [sum(x) / num_embeddings for x in zip(*embeddings)]
@app.route('/embeddings', methods=['POST'])
def embeddings():
    """Handle ``POST /embeddings``.

    Reads ``input`` from the JSON request body and the Hugging Face token
    from the ``Authorization`` header, then returns one embedding for the
    whole input via ``get_embedding``.

    ``input`` may be a single string or a list; list items are converted to
    strings and joined with spaces.

    Returns:
        * an embedding-list dict on success;
        * an error dict with status 400 when ``input`` is missing or has an
          unsupported type;
        * an error dict with status 401 when the ``Authorization`` header is
          missing.
    """
    payload = request.get_json()
    input_value = payload.get('input') if isinstance(payload, dict) else None

    if isinstance(input_value, str):
        # Joining a bare string would split it into single characters.
        input_text = input_value
    elif isinstance(input_value, (list, tuple)):
        input_text = ' '.join(map(str, input_value))
    else:
        return {
            'error': {
                'message': "'input' must be a string or a list.",
                'type': 'invalid_request_error',
            }
        }, 400

    authorization = request.headers.get('Authorization')
    if not authorization:
        return {
            'error': {
                'message': 'Missing Authorization header.',
                'type': 'invalid_request_error',
            }
        }, 401
    token = authorization.replace('Bearer ', '')

    embedding = get_embedding(input_text, token)

    return {
        'data': [
            {
                'embedding': embedding,
                'index': 0,
                'object': 'embedding'
            }
        ],
        # Reported model name is a fixed string, not the model actually used.
        'model': 'text-embedding-ada-002',
        'object': 'list',
        # Token accounting is not computed here.
        'usage': {
            'prompt_tokens': None,
            'total_tokens': None
        }
    }
def main():
    """Start the Flask development server on all interfaces, port 1337.

    Debug mode is no longer forced on: with ``host='0.0.0.0'`` it would
    expose the interactive debugger, which can execute arbitrary code, to
    every machine that can reach the port. Set the ``FLASK_DEBUG``
    environment variable to enable it for local development.
    """
    app.run(host='0.0.0.0', port=1337)


if __name__ == '__main__':
    main()