Update async.py

This commit is contained in:
H Lohaus 2024-04-06 21:01:27 +02:00 committed by GitHub
parent b4399866ee
commit 6e3f350f52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 18 additions and 30 deletions

View File

@ -6,49 +6,38 @@ import time
import random import random
import string import string
from .types import BaseProvider, ProviderType, FinishReason
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .typing import Union, Iterator, Messages, ImageType from ..typing import Union, Iterator, Messages, ImageType, AsyncIerator
from .providers.types import BaseProvider, ProviderType, FinishReason from ..image import ImageResponse as ImageProviderResponse
from .image import ImageResponse as ImageProviderResponse from ..errors import NoImageResponseError, RateLimitError, MissingAuthError
from .errors import NoImageResponseError, RateLimitError, MissingAuthError from .. import get_model_and_provider, get_last_provider
from . import get_model_and_provider, get_last_provider from .helper import read_json
from .Provider.BingCreateImages import BingCreateImages from .Provider.BingCreateImages import BingCreateImages
from .Provider.needs_auth import Gemini, OpenaiChat from .Provider.needs_auth import Gemini, OpenaiChat
from .Provider.You import You from ..Provider.You import You
from .helper import read_json
def iter_response( async def iter_response(
response: iter[str], response: AsyncIerator[str],
stream: bool, stream: bool,
response_format: dict = None, response_format: dict = None,
max_tokens: int = None, max_tokens: int = None,
stop: list = None stop: list = None
) -> IterResponse: ) -> AsyncIterResponse:
content = "" content = ""
finish_reason = None finish_reason = None
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
for idx, chunk in enumerate(response): count: int = 0
async for idx, chunk in response:
if isinstance(chunk, FinishReason): if isinstance(chunk, FinishReason):
finish_reason = chunk.reason finish_reason = chunk.reason
break break
content += str(chunk) content += str(chunk)
if max_tokens is not None and idx + 1 >= max_tokens: count += 1
if max_tokens is not None and count >= max_tokens:
finish_reason = "length" finish_reason = "length"
first = -1 first, content, chunk = find_stop(stop, content, chunk)
word = None
if stop is not None:
for word in list(stop):
first = content.find(word)
if first != -1:
content = content[:first]
break
if stream and first != -1:
first = chunk.find(word)
if first != -1:
chunk = chunk[:first]
else:
first = 0
if first != -1: if first != -1:
finish_reason = "stop" finish_reason = "stop"
if stream: if stream:
@ -64,16 +53,15 @@ def iter_response(
content = read_json(content) content = read_json(content)
yield ChatCompletion(content, finish_reason, completion_id, int(time.time())) yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
def iter_append_model_and_provider(response: IterResponse) -> IterResponse: async def iter_append_model_and_provider(response: AsyncIterResponse) -> IterResponse:
last_provider = None last_provider = None
for chunk in response: async for chunk in response:
last_provider = get_last_provider(True) if last_provider is None else last_provider last_provider = get_last_provider(True) if last_provider is None else last_provider
chunk.model = last_provider.get("model") chunk.model = last_provider.get("model")
chunk.provider = last_provider.get("name") chunk.provider = last_provider.get("name")
yield chunk yield chunk
class Client(): class Client():
def __init__( def __init__(
self, self,
api_key: str = None, api_key: str = None,
@ -222,4 +210,4 @@ class Images():
result = ImagesResponse([Image(image)for image in result]) result = ImagesResponse([Image(image)for image in result])
if result is None: if result is None:
raise NoImageResponseError() raise NoImageResponseError()
return result return result