Add check_running_loop requirement

Add create_async function in ChatCompletion
Use SelectorEventLoop on windows
Heiner Lohaus 2023-09-20 14:52:50 +02:00
parent 82bd6f9180
commit 55577031d5
3 changed files with 87 additions and 39 deletions
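Background on the check_running_loop requirement: the synchronous create path drives the async providers through asyncio.run(), and asyncio.run() refuses to start when an event loop is already running in the current thread. A minimal, self-contained sketch of that failure mode (the handler coroutine is illustrative only, not part of the commit):

import asyncio

async def handler():
    coro = asyncio.sleep(0)
    try:
        # asyncio.run() starts a new event loop; calling it while another loop
        # is already running raises RuntimeError. The synchronous create()
        # relies on asyncio.run(), so check_running_loop() now fails fast and
        # points async callers to ChatCompletion.create_async instead.
        asyncio.run(coro)
    except RuntimeError as error:
        print(error)  # "asyncio.run() cannot be called from a running event loop"
    finally:
        coro.close()  # avoid the "coroutine was never awaited" warning

asyncio.run(handler())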

View File

@@ -25,6 +25,7 @@ class BaseProvider(ABC):
         raise NotImplementedError()

     @classmethod
     @property
     def params(cls):
@ -46,6 +47,8 @@ class AsyncProvider(BaseProvider):
stream: bool = False, stream: bool = False,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
check_running_loop()
yield asyncio.run(cls.create_async(model, messages, **kwargs)) yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod @staticmethod
@@ -67,10 +70,17 @@ class AsyncGeneratorProvider(AsyncProvider):
         stream: bool = True,
         **kwargs
     ) -> CreateResult:
-        loop = asyncio.new_event_loop()
+        check_running_loop()
+        # Force use selector event loop on windows
+        loop = asyncio.SelectorEventLoop()
         try:
-            asyncio.set_event_loop(loop)
-            generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
+            generator = cls.create_async_generator(
+                model,
+                messages,
+                stream=stream,
+                **kwargs
+            )
             gen = generator.__aiter__()
             while True:
                 try:
@@ -78,10 +88,8 @@ class AsyncGeneratorProvider(AsyncProvider):
                 except StopAsyncIteration:
                     break
         finally:
-            asyncio.set_event_loop(None)
             loop.close()

     @classmethod
     async def create_async(
         cls,
@@ -100,6 +108,11 @@ class AsyncGeneratorProvider(AsyncProvider):
     ) -> AsyncGenerator:
         raise NotImplementedError()

+# Don't create a new loop in a running loop
+def check_running_loop():
+    if asyncio.events._get_running_loop() is not None:
+        raise RuntimeError(
+            'Use "create_async" instead of "create" function in a async loop.')

 _cookies = {}
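For orientation, the division of labor after this change: a concrete provider only implements create_async_generator, while the inherited create_completion above handles the running-loop guard and the event-loop setup (including the Windows SelectorEventLoop workaround). A rough sketch with a hypothetical ExampleProvider that yields canned chunks instead of making a real HTTP request; the import path is assumed:

from typing import AsyncGenerator

# Module path assumed; AsyncGeneratorProvider is defined in the base-provider module diffed above.
from g4f.Provider.base_provider import AsyncGeneratorProvider

class ExampleProvider(AsyncGeneratorProvider):
    # Hypothetical provider, only to illustrate the subclass contract.
    working = True
    supports_stream = True

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> AsyncGenerator:
        # A real provider would stream chunks from an HTTP response here.
        for chunk in ("Hello", " ", "world"):
            yield chunk

Calling ExampleProvider.create_completion(...) then consumes the generator on a fresh SelectorEventLoop, while async code can use the inherited create_async or the generator directly.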

View File

@@ -1,11 +1,42 @@
 from __future__ import annotations
 from g4f import models
-from .Provider import BaseProvider
+from .Provider import BaseProvider, AsyncProvider
 from .typing import Any, CreateResult, Union
 import random

 logging = False

+def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseProvider], stream: bool):
+    if isinstance(model, str):
+        if model in models.ModelUtils.convert:
+            model = models.ModelUtils.convert[model]
+        else:
+            raise Exception(f'The model: {model} does not exist')
+
+    if not provider:
+        if isinstance(model.best_provider, list):
+            if stream:
+                provider = random.choice([p for p in model.best_provider if p.supports_stream])
+            else:
+                provider = random.choice(model.best_provider)
+        else:
+            provider = model.best_provider
+
+    if not provider:
+        raise Exception(f'No provider found for model: {model}')
+
+    if not provider.working:
+        raise Exception(f'{provider.__name__} is not working')
+
+    if not provider.supports_stream and stream:
+        raise Exception(
+            f'ValueError: {provider.__name__} does not support "stream" argument')
+
+    if logging:
+        print(f'Using {provider.__name__} provider')
+
+    return model, provider
+
 class ChatCompletion:
     @staticmethod
     def create(
@@ -13,28 +44,11 @@ class ChatCompletion:
         messages : list[dict[str, str]],
         provider : Union[type[BaseProvider], None] = None,
         stream : bool = False,
-        auth : Union[str, None] = None, **kwargs: Any) -> Union[CreateResult, str]:
-
-        if isinstance(model, str):
-            if model in models.ModelUtils.convert:
-                model = models.ModelUtils.convert[model]
-            else:
-                raise Exception(f'The model: {model} does not exist')
-
-        if not provider:
-            if isinstance(model.best_provider, list):
-                if stream:
-                    provider = random.choice([p for p in model.best_provider if p.supports_stream])
-                else:
-                    provider = random.choice(model.best_provider)
-            else:
-                provider = model.best_provider
-
-        if not provider:
-            raise Exception(f'No provider found')
-
-        if not provider.working:
-            raise Exception(f'{provider.__name__} is not working')
+        auth : Union[str, None] = None,
+        **kwargs
+    ) -> Union[CreateResult, str]:
+        model, provider = get_model_and_provider(model, provider, stream)

         if provider.needs_auth and not auth:
             raise Exception(
@@ -43,12 +57,20 @@ class ChatCompletion:
         if provider.needs_auth:
             kwargs['auth'] = auth

-        if not provider.supports_stream and stream:
-            raise Exception(
-                f'ValueError: {provider.__name__} does not support "stream" argument')
-
-        if logging:
-            print(f'Using {provider.__name__} provider')
-
         result = provider.create_completion(model.name, messages, stream, **kwargs)
         return result if stream else ''.join(result)
+
+    @staticmethod
+    async def create_async(
+        model : Union[models.Model, str],
+        messages : list[dict[str, str]],
+        provider : Union[type[BaseProvider], None] = None,
+        **kwargs
+    ) -> str:
+        model, provider = get_model_and_provider(model, provider, False)
+
+        if not issubclass(provider, AsyncProvider):
+            raise Exception(f"Provider: {provider.__name__} doesn't support create_async")
+
+        return await provider.create_async(model.name, messages, **kwargs)
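With create_async in place, code that is already inside an event loop (where check_running_loop makes the synchronous create raise) can await completions directly and even run several concurrently. A small sketch; it assumes the resolved provider subclasses AsyncProvider, otherwise create_async raises as shown above:

import asyncio
import g4f

async def ask(question: str) -> str:
    # Inside a running loop the synchronous ChatCompletion.create() would hit
    # check_running_loop(); create_async is the entry point for async code.
    return await g4f.ChatCompletion.create_async(
        model=g4f.models.gpt_35_turbo,
        messages=[{"role": "user", "content": question}],
    )

async def main():
    # Several completions can share the same running event loop.
    answers = await asyncio.gather(ask("hello!"), ask("what can you do?"))
    for answer in answers:
        print(answer)

asyncio.run(main())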

View File

@@ -3,10 +3,23 @@ from pathlib import Path

 sys.path.append(str(Path(__file__).parent.parent))

-import g4f
+import g4f, asyncio

-response = g4f.ChatCompletion.create(
+print("create:", end=" ", flush=True)
+for response in g4f.ChatCompletion.create(
     model=g4f.models.gpt_35_turbo,
-    messages=[{"role": "user", "content": "hello, are you GPT 4?"}]
-)
-print(response)
+    provider=g4f.Provider.GptGo,
+    messages=[{"role": "user", "content": "hello!"}],
+):
+    print(response, end="", flush=True)
+print()
+
+async def run_async():
+    response = await g4f.ChatCompletion.create_async(
+        model=g4f.models.gpt_35_turbo,
+        provider=g4f.Provider.GptGo,
+        messages=[{"role": "user", "content": "hello!"}],
+    )
+    print("create_async:", response)
+
+asyncio.run(run_async())