Merge pull request #1762 from hlohaus/goo

Fix attr conversation_id not found
H Lohaus 2024-03-28 17:17:59 +01:00 committed by GitHub
commit 64e07b7fbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 32 additions and 22 deletions

View File

@@ -389,19 +389,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                     print(f"{e.__class__.__name__}: {e}")

         model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
-        fields = Conversation() if conversation is None else copy(conversation)
+        fields = Conversation(conversation_id, parent_id) if conversation is None else copy(conversation)
         fields.finish_reason = None
         while fields.finish_reason is None:
-            conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id
-            parent_id = parent_id if fields.message_id is None else fields.message_id
             websocket_request_id = str(uuid.uuid4())
             data = {
                 "action": action,
                 "conversation_mode": {"kind": "primary_assistant"},
                 "force_paragen": False,
                 "force_rate_limit": False,
-                "conversation_id": conversation_id,
-                "parent_message_id": parent_id,
+                "conversation_id": fields.conversation_id,
+                "parent_message_id": fields.message_id,
                 "model": model,
                 "history_and_training_disabled": history_disabled and not auto_continue and not return_conversation,
                 "websocket_request_id": websocket_request_id
@@ -425,6 +423,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 await raise_for_status(response)
                 async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields):
                     if return_conversation:
+                        history_disabled = False
                         return_conversation = False
                         yield fields
                     yield chunk
@@ -432,7 +431,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                     break
                 action = "continue"
                 await asyncio.sleep(5)
-        if history_disabled and auto_continue and not return_conversation:
+        if history_disabled and auto_continue:
             await cls.delete_conversation(session, cls._headers, fields.conversation_id)

     @staticmethod
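Taken together, these hunks make the Conversation object (fields) the single source of truth for conversation state: the ids passed by the caller now seed the object once, and the request payload reads fields.conversation_id and fields.message_id instead of local variables, which is what previously raised the "attr conversation_id not found" error when the attribute was never set. A minimal sketch of the state object, assuming the constructor mirrors the attributes the diff reads (the real class lives in the provider module):

    # Sketch only: field names inferred from the attributes used in the hunks above.
    class Conversation:
        """Carries OpenAI chat state across auto-continued requests."""
        def __init__(self, conversation_id: str = None, message_id: str = None):
            self.conversation_id = conversation_id  # assigned by the server on the first turn
            self.message_id = message_id            # becomes parent_message_id of the next request
            self.finish_reason = None               # the request loop runs until this is set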

View File

@@ -41,7 +41,7 @@ from g4f.providers.base_provider import ProviderModelMixin
 from g4f.Provider.bing.create_images import patch_provider
 from g4f.providers.conversation import BaseConversation

-conversations: dict[str, BaseConversation] = {}
+conversations: dict[dict[str, BaseConversation]] = {}

 class Api():
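One caveat worth flagging: dict[dict[str, BaseConversation]] is not a valid two-parameter dict annotation (it still runs, since annotations are not enforced at runtime, but a type checker will reject it). The nested cache the later hunks actually build, provider name mapped to conversation id mapped to conversation, would conventionally be spelled as in this sketch (not part of the commit):

    from g4f.providers.conversation import BaseConversation

    # Outer key: provider name; inner key: conversation id.
    conversations: dict[str, dict[str, BaseConversation]] = {}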
@@ -106,7 +106,8 @@ class Api():
             kwargs["image"] = open(self.image, "rb")
         for message in self._create_response_stream(
             self._prepare_conversation_kwargs(options, kwargs),
-            options.get("conversation_id")
+            options.get("conversation_id"),
+            options.get('provider')
         ):
             if not window.evaluate_js(f"if (!this.abort) this.add_message_chunk({json.dumps(message)}); !this.abort && !this.error;"):
                 break
@@ -193,8 +194,8 @@ class Api():
             messages[-1]["content"] = get_search_message(messages[-1]["content"])

         conversation_id = json_data.get("conversation_id")
-        if conversation_id and conversation_id in conversations:
-            kwargs["conversation"] = conversations[conversation_id]
+        if conversation_id and provider in conversations and conversation_id in conversations[provider]:
+            kwargs["conversation"] = conversations[provider][conversation_id]

         model = json_data.get('model')
         model = model if model else models.default
@@ -211,7 +212,7 @@ class Api():
             **kwargs
         }

-    def _create_response_stream(self, kwargs, conversation_id: str) -> Iterator:
+    def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator:
         """
         Creates and returns a streaming response for the conversation.
@@ -231,7 +232,9 @@ class Api():
                     first = False
                     yield self._format_json("provider", get_last_provider(True))
                 if isinstance(chunk, BaseConversation):
-                    conversations[conversation_id] = chunk
+                    if provider not in conversations:
+                        conversations[provider] = {}
+                    conversations[provider][conversation_id] = chunk
                     yield self._format_json("conversation", conversation_id)
                 elif isinstance(chunk, Exception):
                     logging.exception(chunk)
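The effect of the nested cache is that a conversation_id is now scoped per provider, so two providers can hold different state under the same id. A small illustration (provider names and ids are hypothetical), using setdefault as a compact equivalent of the membership check in the hunk:

    from typing import Any

    conversations: dict[str, dict[str, Any]] = {}

    def store_conversation(provider: str, conversation_id: str, chunk: Any) -> None:
        # Same behavior as the hunk above: create the per-provider bucket on demand.
        conversations.setdefault(provider, {})[conversation_id] = chunk

    store_conversation("OpenaiChat", "abc-123", "state A")  # hypothetical id
    store_conversation("Bing", "abc-123", "state B")        # same id, no collision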

View File

@@ -85,7 +85,7 @@ class Backend_Api(Api):
         kwargs = self._prepare_conversation_kwargs(json_data, kwargs)

         return self.app.response_class(
-            self._create_response_stream(kwargs, json_data.get("conversation_id")),
+            self._create_response_stream(kwargs, json_data.get("conversation_id"), json_data.get("provider")),
             mimetype='text/event-stream'
         )
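For context, wrapping a generator in response_class with mimetype='text/event-stream' is the stock Flask pattern for streaming a response chunk by chunk, which is why the new provider argument only needs to be threaded through here. A self-contained sketch of the idea (route and payload are illustrative, not taken from the commit):

    from flask import Flask

    app = Flask(__name__)

    @app.route("/stream", methods=["POST"])
    def stream_endpoint():
        def generate():
            # Each yielded string is flushed to the client as it is produced.
            yield '{"type": "content", "content": "Hello"}\n'
            yield '{"type": "content", "content": " world"}\n'
        return app.response_class(generate(), mimetype="text/event-stream")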

View File

@@ -20,4 +20,4 @@ aiohttp_socks
 gpt4all
 pywebview
 plyer
-pycryptodome
+cryptography

View File

@@ -20,22 +20,27 @@ EXTRA_REQUIRE = {
         "curl_cffi>=0.6.2",
         "certifi",
         "browser_cookie3", # get_cookies
-        "PyExecJS", # GptForLove
+        "PyExecJS", # GptForLove, Vercel
         "duckduckgo-search>=5.0" ,# internet.search
         "beautifulsoup4", # internet.search and bing.create_images
-        "brotli", # openai
-        #"undetected-chromedriver>=3.5.5", # webdriver
-        #"setuptools", # webdriver
+        "brotli", # openai, bing
+        # webdriver
+        #"undetected-chromedriver>=3.5.5",
+        #"setuptools",
+        #"selenium-wire"
+        # webview
         "pywebview",
         "platformdirs",
         "plyer",
+        "cryptography",
+        ####
         "aiohttp_socks", # proxy
         "pillow", # image
         "cairosvg", # svg image
         "werkzeug", "flask", # gui
-        "loguru", "fastapi",
+        "loguru", "fastapi", # api
         "uvicorn", "nest_asyncio", # api
-        #"selenium-wire"
+        "pycryptodome" # openai
     ],
     "image": [
         "pillow",
@@ -51,9 +56,12 @@ EXTRA_REQUIRE = {
     "webview": [
         "webview",
         "platformdirs",
-        "plyer"
+        "plyer",
+        "cryptography"
+    ],
+    "openai": [
+        "pycryptodome"
     ],
-    "openai": [],
     "api": [
         "loguru", "fastapi",
         "uvicorn", "nest_asyncio"