Merge pull request #1414 from hlohaus/lia

Patch event loop on win, Check event loop closed
This commit is contained in:
H Lohaus 2024-01-01 02:10:53 +01:00 committed by GitHub
commit e64a003323
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 28 deletions

View File

@ -13,6 +13,13 @@ if sys.version_info < (3, 10):
else:
from types import NoneType
# Change event loop policy on windows for curl_cffi
if sys.platform == 'win32':
if isinstance(
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class BaseProvider(ABC):
url: str
working: bool = False

View File

@ -7,9 +7,9 @@ import random
import string
import secrets
import os
from os import path
from asyncio import AbstractEventLoop
from platformdirs import user_config_dir
from os import path
from asyncio import AbstractEventLoop
from platformdirs import user_config_dir
from browser_cookie3 import (
chrome,
chromium,
@ -25,37 +25,33 @@ from browser_cookie3 import (
from ..typing import Dict, Messages
from .. import debug
# Change event loop policy on windows
if sys.platform == 'win32':
if isinstance(
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# Local Cookie Storage
_cookies: Dict[str, Dict[str, str]] = {}
# If event loop is already running, handle nested event loops
# If loop closed or not set, create new event loop.
# If event loop is already running, handle nested event loops.
# If "nest_asyncio" is installed, patch the event loop.
def get_event_loop() -> AbstractEventLoop:
try:
asyncio.get_running_loop()
loop = asyncio.get_event_loop()
loop._check_closed()
except RuntimeError:
try:
return asyncio.get_event_loop()
except RuntimeError:
asyncio.set_event_loop(asyncio.new_event_loop())
return asyncio.get_event_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
event_loop = asyncio.get_event_loop()
if not hasattr(event_loop.__class__, "_nest_patched"):
# Is running event loop
asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
import nest_asyncio
nest_asyncio.apply(event_loop)
return event_loop
nest_asyncio.apply(loop)
except RuntimeError:
# No running event loop
pass
except ImportError:
raise RuntimeError(
'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
)
return loop
def init_cookies():
urls = [

View File

@ -7,9 +7,9 @@ from async_property import async_cached_property
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from ..base_provider import AsyncGeneratorProvider
from ..helper import get_event_loop, format_prompt
from ..helper import get_event_loop, format_prompt, get_cookies
from ...webdriver import get_browser
from ...typing import AsyncResult, Messages
from ...requests import StreamSession
@ -27,7 +27,7 @@ class OpenaiChat(AsyncGeneratorProvider):
needs_auth = True
supports_gpt_35_turbo = True
supports_gpt_4 = True
_access_token: str = None
_cookies: dict = {}
@classmethod
async def create(
@ -72,6 +72,7 @@ class OpenaiChat(AsyncGeneratorProvider):
proxy: str = None,
timeout: int = 120,
access_token: str = None,
cookies: dict = None,
auto_continue: bool = False,
history_disabled: bool = True,
action: str = "next",
@ -86,13 +87,18 @@ class OpenaiChat(AsyncGeneratorProvider):
raise ValueError(f"Model are not supported: {model}")
if not parent_id:
parent_id = str(uuid.uuid4())
if not cookies:
cookies = cls._cookies
if not access_token:
access_token = cls._access_token
if not cookies:
cls._cookies = cookies = get_cookies("chat.openai.com")
if "access_token" in cookies:
access_token = cookies["access_token"]
if not access_token:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n"
access_token = cls._access_token = await cls.browse_access_token(proxy)
cls._cookies["access_token"] = access_token = await cls.browse_access_token(proxy)
headers = {
"Accept": "text/event-stream",
"Authorization": f"Bearer {access_token}",
@ -101,7 +107,8 @@ class OpenaiChat(AsyncGeneratorProvider):
proxies={"https": proxy},
impersonate="chrome110",
headers=headers,
timeout=timeout
timeout=timeout,
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
) as session:
end_turn = EndTurn()
while not end_turn.is_end:
@ -170,7 +177,12 @@ class OpenaiChat(AsyncGeneratorProvider):
WebDriverWait(driver, 1200).until(
EC.presence_of_element_located((By.ID, "prompt-textarea"))
)
javascript = "return (await (await fetch('/api/auth/session')).json())['accessToken']"
javascript = """
access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
return access_token;
"""
return driver.execute_script(javascript)
finally:
driver.quit()