Add support for message history and system message in OpenaiChat

Add fetching of the access_token and fix cookie usage in OpenaiChat
Fix saving of the created access_token to cookies in OpenaiChat
Add use_auth_header config option in GeminiPro
Heiner Lohaus 2024-02-26 23:41:06 +01:00
parent 0bfaede7df
commit 862e5ef16d
3 changed files with 78 additions and 47 deletions

View File

@@ -27,6 +27,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
         proxy: str = None,
         api_key: str = None,
         api_base: str = None,
+        use_auth_header: bool = True,
         image: ImageType = None,
         connector: BaseConnector = None,
         **kwargs
@@ -38,7 +39,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
             raise MissingAuthError('Missing "api_key"')
         headers = params = None
-        if api_base:
+        if api_base and use_auth_header:
             headers = {"Authorization": f"Bearer {api_key}"}
         else:
             params = {"key": api_key}
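
With the default `use_auth_header=True` the behavior is unchanged: a custom `api_base` receives the key as an `Authorization: Bearer` header, while the official endpoint keeps using the `key` query parameter. Setting the flag to `False` forces the query-parameter style for custom endpoints too, which some Gemini-compatible proxies expect. A minimal usage sketch (the endpoint URL and key are placeholders, and extra keyword arguments are assumed to be forwarded to the provider as usual):

    import g4f
    from g4f.Provider import GeminiPro

    response = g4f.ChatCompletion.create(
        model="gemini-pro",
        messages=[{"role": "user", "content": "Hello"}],
        provider=GeminiPro,
        api_key="AIza...",                               # placeholder key
        api_base="https://gemini-proxy.example/v1beta",  # placeholder proxy URL
        use_auth_header=False,  # send the key as ?key=... instead of a Bearer header
    )
    print(response)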

View File

@@ -20,9 +20,9 @@ except ImportError:
     pass
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import format_prompt, get_cookies
-from ...webdriver import get_browser, get_driver_cookies
-from ...typing import AsyncResult, Messages, Cookies, ImageType
+from ..helper import get_cookies
+from ...webdriver import get_browser
+from ...typing import AsyncResult, Messages, Cookies, ImageType, Union
 from ...requests import get_args_from_browser
 from ...requests.aiohttp import StreamSession
 from ...image import to_image, to_bytes, ImageResponse, ImageRequest
@@ -37,6 +37,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     needs_auth = True
     supports_gpt_35_turbo = True
     supports_gpt_4 = True
+    supports_message_history = True
     default_model = None
     models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
     model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo"}
@@ -170,6 +171,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         """
         if not cls.default_model:
             async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
+                cls._update_request_args(session)
                 response.raise_for_status()
                 data = await response.json()
                 if "categories" in data:
@@ -179,7 +181,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         return cls.default_model

     @classmethod
-    def create_messages(cls, prompt: str, image_request: ImageRequest = None):
+    def create_messages(cls, messages: Messages, image_request: ImageRequest = None):
         """
         Create a list of messages for the user input
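
`create_messages` now receives the whole `Messages` list instead of a pre-formatted prompt string, and the next hunk converts every entry (including a system message) into a backend message object. A sketch of the structure the reworked method produces (ids are random UUIDs):

    import uuid

    history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]
    payload = [{
        "id": str(uuid.uuid4()),
        "author": {"role": message["role"]},
        "content": {"content_type": "text", "parts": [message["content"]]},
    } for message in history]
    # -> one backend message per history entry, e.g.
    # {"id": "...", "author": {"role": "system"},
    #  "content": {"content_type": "text", "parts": ["You are a helpful assistant."]}}

When an `image_request` is present, only the last entry is rewritten to `multimodal_text` and receives the attachments metadata; a later hunk adjusts the call site so the full history is sent only for a new conversation, while an existing `conversation_id` forwards just `messages[-1]`.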
@@ -190,31 +192,27 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         Returns:
             A list of messages with the user input and the image, if any
         """
+        # Create a message object with the user role and the content
+        messages = [{
+            "id": str(uuid.uuid4()),
+            "author": {"role": message["role"]},
+            "content": {"content_type": "text", "parts": [message["content"]]},
+        } for message in messages]
+
         # Check if there is an image response
-        if not image_request:
-            # Create a content object with the text type and the prompt
-            content = {"content_type": "text", "parts": [prompt]}
-        else:
-            # Create a content object with the multimodal text type and the image and the prompt
-            content = {
+        if image_request:
+            # Change content in last user message
+            messages[-1]["content"] = {
                 "content_type": "multimodal_text",
                 "parts": [{
                     "asset_pointer": f"file-service://{image_request.get('file_id')}",
                     "height": image_request.get("height"),
                     "size_bytes": image_request.get("file_size"),
                     "width": image_request.get("width"),
-                }, prompt]
+                }, messages[-1]["content"]["parts"][0]]
             }
-        # Create a message object with the user role and the content
-        messages = [{
-            "id": str(uuid.uuid4()),
-            "author": {"role": "user"},
-            "content": content,
-        }]
-        # Check if there is an image response
-        if image_request:
             # Add the metadata object with the attachments
-            messages[0]["metadata"] = {
+            messages[-1]["metadata"] = {
                 "attachments": [{
                     "height": image_request.get("height"),
                     "id": image_request.get("file_id"),
@@ -333,30 +331,33 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
             raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package')
         if not parent_id:
             parent_id = str(uuid.uuid4())
-        if cls._args is None and cookies is None:
-            cookies = get_cookies("chat.openai.com", False)
+        # Read api_key from args
         api_key = kwargs["access_token"] if "access_token" in kwargs else api_key
-        if api_key is None and cookies is not None:
-            api_key = cookies["access_token"] if "access_token" in cookies else api_key
         if cls._args is None:
-            cls._args = {
-                "headers": {"Cookie": "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")},
-                "cookies": {} if cookies is None else cookies
-            }
-        if api_key is not None:
-            cls._args["headers"]["Authorization"] = f"Bearer {api_key}"
+            if api_key is None:
+                # Read api_key from cookies
+                cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
+                api_key = cookies["access_token"] if "access_token" in cookies else api_key
+            cls._args = cls._create_request_args(cookies)
         async with StreamSession(
             proxies={"https": proxy},
             impersonate="chrome",
-            timeout=timeout,
-            headers=cls._args["headers"]
+            timeout=timeout
         ) as session:
+            if api_key is None and cookies:
+                # Read api_key from session
+                api_key = await cls.fetch_access_token(session, cls._args["headers"])
             if api_key is not None:
+                cls._args["headers"]["Authorization"] = f"Bearer {api_key}"
                 try:
                     cls.default_model = await cls.get_default_model(session, cls._args["headers"])
                 except Exception as e:
                     if debug.logging:
                         print(f"{e.__class__.__name__}: {e}")
             if cls.default_model is None:
                 login_url = os.environ.get("G4F_LOGIN_URL")
                 if login_url:
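
The access token is now resolved in three steps: an explicit `access_token`/`api_key` argument wins, otherwise it is read from the browser cookies for chat.openai.com, and failing that it is fetched from `/api/auth/session` inside the running session via the new `fetch_access_token` helper. A minimal calling sketch, assuming the usual g4f entry point and that extra keyword arguments reach the provider:

    import asyncio
    import g4f
    from g4f.Provider import OpenaiChat

    async def main():
        answer = await g4f.ChatCompletion.create_async(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "Answer in one sentence."},
                {"role": "user", "content": "What is an access token?"},
            ],
            provider=OpenaiChat,
            # access_token="eyJ...",  # optional; when omitted, cookies or the session are used
        )
        print(answer)

    asyncio.run(main())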
@@ -366,12 +367,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 except MissingRequirementsError:
                     raise MissingAuthError(f'Missing or invalid "access_token". Add a new "api_key" please')
                 cls.default_model = await cls.get_default_model(session, cls._args["headers"])
             try:
-                image_response = None
-                if image:
-                    image_response = await cls.upload_image(session, cls._args["headers"], image, kwargs.get("image_name"))
+                image_response = await cls.upload_image(
+                    session,
+                    cls._args["headers"],
+                    image,
+                    kwargs.get("image_name")
+                ) if image else None
             except Exception as e:
                 yield e
             end_turn = EndTurn()
             model = cls.get_model(model)
             model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model
@@ -389,13 +395,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 "history_and_training_disabled": history_disabled and not auto_continue,
             }
             if action != "continue":
-                prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
-                data["messages"] = cls.create_messages(prompt, image_response)
-            # Update cookies before next request
-            for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
-                cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value
-            cls._args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in cls._args["cookies"].items())
+                messages = messages if not conversation_id else [messages[-1]]
+                data["messages"] = cls.create_messages(messages, image_response)
             async with session.post(
                 f"{cls.url}/backend-api/conversation",
@@ -406,6 +407,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                     **cls._args["headers"]
                 }
             ) as response:
+                cls._update_request_args(session)
                 if not response.ok:
                     raise RuntimeError(f"Response {response.status}: {await response.text()}")
                 last_message: int = 0
@@ -475,13 +477,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 "let session = await fetch('/api/auth/session');"
                 "let data = await session.json();"
                 "let accessToken = data['accessToken'];"
-                "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 4);"
+                "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 4 * 1000);"
                 "document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
                 "return accessToken;"
             )
             args = get_args_from_browser(f"{cls.url}/", driver, do_bypass_cloudflare=False)
             args["headers"]["Authorization"] = f"Bearer {access_token}"
-            args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in args["cookies"].items() if k != "access_token")
+            args["headers"]["Cookie"] = cls._format_cookies(args["cookies"])
             return args
         finally:
             driver.close()
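
The added `* 1000` is the actual fix for saving the created access_token in cookies: JavaScript's `Date.setTime` expects milliseconds, so `60 * 60 * 4` pushed the expiry only about 14 seconds into the future and the freshly stored `access_token` cookie vanished almost immediately. Multiplying by 1000 gives the intended four hours.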
@@ -516,6 +518,34 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                     return decoded_json["token"]
                 raise RuntimeError(f"Response: {decoded_json}")

+    @classmethod
+    async def fetch_access_token(cls, session: StreamSession, headers: dict):
+        async with session.get(
+            f"{cls.url}/api/auth/session",
+            headers=headers
+        ) as response:
+            if response.ok:
+                data = await response.json()
+                if "accessToken" in data:
+                    return data["accessToken"]
+
+    @staticmethod
+    def _format_cookies(cookies: Cookies):
+        return "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")
+
+    @classmethod
+    def _create_request_args(cls, cookies: Union[Cookies, None]):
+        return {
+            "headers": {} if cookies is None else {"Cookie": cls._format_cookies(cookies)},
+            "cookies": {} if cookies is None else cookies
+        }
+
+    @classmethod
+    def _update_request_args(cls, session: StreamSession):
+        for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
+            cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value
+        cls._args["headers"]["Cookie"] = cls._format_cookies(cls._args["cookies"])
+
 class EndTurn:
     """
     Class to represent the end of a conversation turn.
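
Taken together, the new helpers centralize cookie handling: `_create_request_args` seeds `cls._args` from the browser cookies, `_format_cookies` builds the Cookie header while leaving out the synthetic `access_token` entry, and `_update_request_args` runs after each backend response to copy refreshed cookies from the session back into `cls._args`. A small sketch of the round-trip (cookie names and values are made up):

    cookies = {"__Secure-next-auth.session-token": "abc", "access_token": "eyJ..."}

    # What _format_cookies produces: the access_token entry stays in
    # cls._args["cookies"] but is excluded from the header sent upstream.
    header = "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")
    assert header == "__Secure-next-auth.session-token=abc"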

View File

@@ -126,7 +126,7 @@ class Completions():
         stop: Union[list[str], str] = None,
         api_key: str = None,
         **kwargs
-    ) -> Union[ChatCompletion, Generator[ChatCompletionChunk]]:
+    ) -> Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]]:
         model, provider = get_model_and_provider(
             model,
             self.provider if provider is None else provider,
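
The return-annotation change is a typing correction: `typing.Generator` is generic over three parameters (yield, send, and return types), and type checkers reject the single-parameter form, so `Generator[ChatCompletionChunk, None, None]` is the right annotation for a stream that only yields chunks. For illustration:

    from typing import Generator

    # yield type, send type, return type
    ChunkStream = Generator["ChatCompletionChunk", None, None]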