Add check_running_loop requirement

Add create_async function in ChatCompletion
Use SelectorEventLoop on windows
This commit is contained in:
Heiner Lohaus 2023-09-20 14:52:50 +02:00
parent 82bd6f9180
commit 55577031d5
3 changed files with 87 additions and 39 deletions

View File

@ -25,6 +25,7 @@ class BaseProvider(ABC):
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
@property @property
def params(cls): def params(cls):
@ -46,6 +47,8 @@ class AsyncProvider(BaseProvider):
stream: bool = False, stream: bool = False,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
check_running_loop()
yield asyncio.run(cls.create_async(model, messages, **kwargs)) yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod @staticmethod
@ -67,10 +70,17 @@ class AsyncGeneratorProvider(AsyncProvider):
stream: bool = True, stream: bool = True,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
loop = asyncio.new_event_loop() check_running_loop()
# Force use selector event loop on windows
loop = asyncio.SelectorEventLoop()
try: try:
asyncio.set_event_loop(loop) generator = cls.create_async_generator(
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) model,
messages,
stream=stream,
**kwargs
)
gen = generator.__aiter__() gen = generator.__aiter__()
while True: while True:
try: try:
@ -78,10 +88,8 @@ class AsyncGeneratorProvider(AsyncProvider):
except StopAsyncIteration: except StopAsyncIteration:
break break
finally: finally:
asyncio.set_event_loop(None)
loop.close() loop.close()
@classmethod @classmethod
async def create_async( async def create_async(
cls, cls,
@ -100,6 +108,11 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> AsyncGenerator: ) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
# Don't create a new loop in a running loop
def check_running_loop():
if asyncio.events._get_running_loop() is not None:
raise RuntimeError(
'Use "create_async" instead of "create" function in a async loop.')
_cookies = {} _cookies = {}

View File

@ -1,20 +1,12 @@
from __future__ import annotations from __future__ import annotations
from g4f import models from g4f import models
from .Provider import BaseProvider from .Provider import BaseProvider, AsyncProvider
from .typing import Any, CreateResult, Union from .typing import Any, CreateResult, Union
import random import random
logging = False logging = False
class ChatCompletion: def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseProvider], stream: bool):
@staticmethod
def create(
model : Union[models.Model, str],
messages : list[dict[str, str]],
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
auth : Union[str, None] = None, **kwargs: Any) -> Union[CreateResult, str]:
if isinstance(model, str): if isinstance(model, str):
if model in models.ModelUtils.convert: if model in models.ModelUtils.convert:
model = models.ModelUtils.convert[model] model = models.ModelUtils.convert[model]
@ -31,18 +23,11 @@ class ChatCompletion:
provider = model.best_provider provider = model.best_provider
if not provider: if not provider:
raise Exception(f'No provider found') raise Exception(f'No provider found for model: {model}')
if not provider.working: if not provider.working:
raise Exception(f'{provider.__name__} is not working') raise Exception(f'{provider.__name__} is not working')
if provider.needs_auth and not auth:
raise Exception(
f'ValueError: {provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)')
if provider.needs_auth:
kwargs['auth'] = auth
if not provider.supports_stream and stream: if not provider.supports_stream and stream:
raise Exception( raise Exception(
f'ValueError: {provider.__name__} does not support "stream" argument') f'ValueError: {provider.__name__} does not support "stream" argument')
@ -50,5 +35,42 @@ class ChatCompletion:
if logging: if logging:
print(f'Using {provider.__name__} provider') print(f'Using {provider.__name__} provider')
return model, provider
class ChatCompletion:
@staticmethod
def create(
model : Union[models.Model, str],
messages : list[dict[str, str]],
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
auth : Union[str, None] = None,
**kwargs
) -> Union[CreateResult, str]:
model, provider = get_model_and_provider(model, provider, stream)
if provider.needs_auth and not auth:
raise Exception(
f'ValueError: {provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)')
if provider.needs_auth:
kwargs['auth'] = auth
result = provider.create_completion(model.name, messages, stream, **kwargs) result = provider.create_completion(model.name, messages, stream, **kwargs)
return result if stream else ''.join(result) return result if stream else ''.join(result)
@staticmethod
async def create_async(
model : Union[models.Model, str],
messages : list[dict[str, str]],
provider : Union[type[BaseProvider], None] = None,
**kwargs
) -> str:
model, provider = get_model_and_provider(model, provider, False)
if not issubclass(provider, AsyncProvider):
raise Exception(f"Provider: {provider.__name__} doesn't support create_async")
return await provider.create_async(model.name, messages, **kwargs)

View File

@ -3,10 +3,23 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent)) sys.path.append(str(Path(__file__).parent.parent))
import g4f import g4f, asyncio
response = g4f.ChatCompletion.create( print("create:", end=" ", flush=True)
for response in g4f.ChatCompletion.create(
model=g4f.models.gpt_35_turbo, model=g4f.models.gpt_35_turbo,
messages=[{"role": "user", "content": "hello, are you GPT 4?"}] provider=g4f.Provider.GptGo,
messages=[{"role": "user", "content": "hello!"}],
):
print(response, end="", flush=True)
print()
async def run_async():
response = await g4f.ChatCompletion.create_async(
model=g4f.models.gpt_35_turbo,
provider=g4f.Provider.GptGo,
messages=[{"role": "user", "content": "hello!"}],
) )
print(response) print("create_async:", response)
asyncio.run(run_async())