This commit is contained in:
abc 2023-10-12 14:35:18 +01:00
parent 86248b44bc
commit dc502a22de

View File

@ -3,10 +3,10 @@ import random
import string import string
import time import time
import requests # import requests
from flask import Flask, request from flask import Flask, request
from flask_cors import CORS from flask_cors import CORS
from transformers import AutoTokenizer # from transformers import AutoTokenizer
from g4f import ChatCompletion from g4f import ChatCompletion
@ -95,67 +95,67 @@ def chat_completions():
# Get the embedding from huggingface # Get the embedding from huggingface
def get_embedding(input_text, token): # def get_embedding(input_text, token):
huggingface_token = token # huggingface_token = token
embedding_model = "sentence-transformers/all-mpnet-base-v2" # embedding_model = "sentence-transformers/all-mpnet-base-v2"
max_token_length = 500 # max_token_length = 500
# Load the tokenizer for the 'all-mpnet-base-v2' model # # Load the tokenizer for the 'all-mpnet-base-v2' model
tokenizer = AutoTokenizer.from_pretrained(embedding_model) # tokenizer = AutoTokenizer.from_pretrained(embedding_model)
# Tokenize the text and split the tokens into chunks of 500 tokens each # # Tokenize the text and split the tokens into chunks of 500 tokens each
tokens = tokenizer.tokenize(input_text) # tokens = tokenizer.tokenize(input_text)
token_chunks = [ # token_chunks = [
tokens[i : i + max_token_length] # tokens[i : i + max_token_length]
for i in range(0, len(tokens), max_token_length) # for i in range(0, len(tokens), max_token_length)
] # ]
# Initialize an empty list # # Initialize an empty list
embeddings = [] # embeddings = []
# Create embeddings for each chunk # # Create embeddings for each chunk
for chunk in token_chunks: # for chunk in token_chunks:
# Convert the chunk tokens back to text # # Convert the chunk tokens back to text
chunk_text = tokenizer.convert_tokens_to_string(chunk) # chunk_text = tokenizer.convert_tokens_to_string(chunk)
# Use the Hugging Face API to get embeddings for the chunk # # Use the Hugging Face API to get embeddings for the chunk
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}" # api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}"
headers = {"Authorization": f"Bearer {huggingface_token}"} # headers = {"Authorization": f"Bearer {huggingface_token}"}
chunk_text = chunk_text.replace("\n", " ") # chunk_text = chunk_text.replace("\n", " ")
# Make a POST request to get the chunk's embedding # # Make a POST request to get the chunk's embedding
response = requests.post( # response = requests.post(
api_url, # api_url,
headers=headers, # headers=headers,
json={"inputs": chunk_text, "options": {"wait_for_model": True}}, # json={"inputs": chunk_text, "options": {"wait_for_model": True}},
) # )
# Parse the response and extract the embedding # # Parse the response and extract the embedding
chunk_embedding = response.json() # chunk_embedding = response.json()
# Append the embedding to the list # # Append the embedding to the list
embeddings.append(chunk_embedding) # embeddings.append(chunk_embedding)
# averaging all the embeddings # # averaging all the embeddings
# this isn't very effective # # this isn't very effective
# someone a better idea? # # someone a better idea?
num_embeddings = len(embeddings) # num_embeddings = len(embeddings)
average_embedding = [sum(x) / num_embeddings for x in zip(*embeddings)] # average_embedding = [sum(x) / num_embeddings for x in zip(*embeddings)]
embedding = average_embedding # embedding = average_embedding
return embedding # return embedding
@app.route("/embeddings", methods=["POST"]) # @app.route("/embeddings", methods=["POST"])
def embeddings(): # def embeddings():
input_text_list = request.get_json().get("input") # input_text_list = request.get_json().get("input")
input_text = " ".join(map(str, input_text_list)) # input_text = " ".join(map(str, input_text_list))
token = request.headers.get("Authorization").replace("Bearer ", "") # token = request.headers.get("Authorization").replace("Bearer ", "")
embedding = get_embedding(input_text, token) # embedding = get_embedding(input_text, token)
return { # return {
"data": [{"embedding": embedding, "index": 0, "object": "embedding"}], # "data": [{"embedding": embedding, "index": 0, "object": "embedding"}],
"model": "text-embedding-ada-002", # "model": "text-embedding-ada-002",
"object": "list", # "object": "list",
"usage": {"prompt_tokens": None, "total_tokens": None}, # "usage": {"prompt_tokens": None, "total_tokens": None},
} # }
def run_api(): def run_api():