Improve event loop

This commit is contained in:
Heiner Lohaus 2023-09-18 07:15:43 +02:00
parent e8d7bcd045
commit 3b8dfff974
3 changed files with 49 additions and 47 deletions

View File

@ -51,7 +51,9 @@ class Ylokh(AsyncGeneratorProvider):
if stream:
async for line in response.content:
line = line.decode()
if line.startswith("data: ") and not line.startswith("data: [DONE]"):
if line.startswith("data: "):
if line.startswith("data: [DONE]"):
break
line = json.loads(line[6:-1])
content = line["choices"][0]["delta"].get("content")
if content:
@ -71,6 +73,7 @@ class Ylokh(AsyncGeneratorProvider):
("stream", "bool"),
("proxy", "str"),
("temperature", "float"),
("top_p", "float"),
]
param = ", ".join([": ".join(p) for p in params])
return f"g4f.provider.{cls.__name__} supports: ({param})"

View File

@ -35,30 +35,6 @@ class BaseProvider(ABC):
]
param = ", ".join([": ".join(p) for p in params])
return f"g4f.provider.{cls.__name__} supports: ({param})"
_cookies = {}
def get_cookies(cookie_domain: str) -> dict:
if cookie_domain not in _cookies:
_cookies[cookie_domain] = {}
try:
for cookie in browser_cookie3.load(cookie_domain):
_cookies[cookie_domain][cookie.name] = cookie.value
except:
pass
return _cookies[cookie_domain]
def format_prompt(messages: list[dict[str, str]], add_special_tokens=False):
if add_special_tokens or len(messages) > 1:
formatted = "\n".join(
["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages]
)
return f"{formatted}\nAssistant:"
else:
return messages.pop()["content"]
class AsyncProvider(BaseProvider):
@ -67,8 +43,9 @@ class AsyncProvider(BaseProvider):
cls,
model: str,
messages: list[dict[str, str]],
stream: bool = False, **kwargs: Any) -> CreateResult:
stream: bool = False,
**kwargs
) -> CreateResult:
yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod
@ -90,7 +67,20 @@ class AsyncGeneratorProvider(AsyncProvider):
stream: bool = True,
**kwargs
) -> CreateResult:
yield from run_generator(cls.create_async_generator(model, messages, stream=stream, **kwargs))
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__()
while True:
try:
yield loop.run_until_complete(gen.__anext__())
except StopAsyncIteration:
break
finally:
asyncio.set_event_loop(None)
loop.close()
@classmethod
async def create_async(
@ -99,27 +89,36 @@ class AsyncGeneratorProvider(AsyncProvider):
messages: list[dict[str, str]],
**kwargs
) -> str:
chunks = [chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)]
if chunks:
return "".join(chunks)
return "".join([chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)])
@staticmethod
@abstractmethod
def create_async_generator(
model: str,
messages: list[dict[str, str]],
**kwargs
) -> AsyncGenerator:
model: str,
messages: list[dict[str, str]],
**kwargs
) -> AsyncGenerator:
raise NotImplementedError()
def run_generator(generator: AsyncGenerator[Union[Any, str], Any]):
loop = asyncio.new_event_loop()
gen = generator.__aiter__()
_cookies = {}
while True:
def get_cookies(cookie_domain: str) -> dict:
if cookie_domain not in _cookies:
_cookies[cookie_domain] = {}
try:
yield loop.run_until_complete(gen.__anext__())
for cookie in browser_cookie3.load(cookie_domain):
_cookies[cookie_domain][cookie.name] = cookie.value
except:
pass
return _cookies[cookie_domain]
except StopAsyncIteration:
break
def format_prompt(messages: list[dict[str, str]], add_special_tokens=False):
if add_special_tokens or len(messages) > 1:
formatted = "\n".join(
["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages]
)
return f"{formatted}\nAssistant:"
else:
return messages[0]["content"]

View File

@ -17,7 +17,7 @@ _providers = [
g4f.Provider.Bard
]
_instruct = "Hello, tell about you in one sentence."
_instruct = "Hello, are you GPT 4?."
_example = """
OpenaiChat: Hello! How can I assist you today? 2.0 secs
@ -39,14 +39,14 @@ No Stream Total: 10.14 secs
print("Bing: ", end="")
for response in log_time_yield(
g4f.ChatCompletion.create,
model=g4f.models.gpt_35_turbo,
model=g4f.models.default,
messages=[{"role": "user", "content": _instruct}],
provider=g4f.Provider.Bing,
#cookies=g4f.get_cookies(".huggingface.co"),
#stream=True,
stream=True,
auth=True
):
print(response, end="")
print(response, end="", flush=True)
print()
print()
@ -75,7 +75,7 @@ def run_stream():
model=None,
messages=[{"role": "user", "content": _instruct}],
):
print(response, end="")
print(response, end="", flush=True)
print()
print("Stream Total:", log_time(run_stream))
print()