Fix unittests

This commit is contained in:
Heiner Lohaus 2024-01-14 15:35:58 +01:00
parent 32252def15
commit 55e5cf16cb
2 changed files with 41 additions and 13 deletions

View File

@ -16,4 +16,4 @@ jobs:
- name: Install requirements
- run: pip install -r requirements.txt
- name: Run tests
run: python -m etc.unittest.main
- run: python -m etc.unittest.main

View File

@ -49,7 +49,8 @@ class OpenaiChat(AsyncGeneratorProvider):
image: ImageType = None,
**kwargs
) -> Response:
"""Create a new conversation or continue an existing one
"""
Create a new conversation or continue an existing one
Args:
prompt: The user input to start or continue the conversation
@ -96,7 +97,8 @@ class OpenaiChat(AsyncGeneratorProvider):
headers: dict,
image: ImageType
) -> ImageResponse:
"""Upload an image to the service and get the download URL
"""
Upload an image to the service and get the download URL
Args:
session: The StreamSession object to use for requests
@ -149,7 +151,8 @@ class OpenaiChat(AsyncGeneratorProvider):
@classmethod
async def _get_default_model(cls, session: StreamSession, headers: dict):
"""Get the default model name from the service
"""
Get the default model name from the service
Args:
session: The StreamSession object to use for requests
@ -172,7 +175,8 @@ class OpenaiChat(AsyncGeneratorProvider):
@classmethod
def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
"""Create a list of messages for the user input
"""
Create a list of messages for the user input
Args:
prompt: The user input as a string
@ -222,10 +226,20 @@ class OpenaiChat(AsyncGeneratorProvider):
"""
Retrieves the image response based on the message content.
:param session: The StreamSession object.
:param headers: HTTP headers for the request.
:param line: The line of response containing image information.
:return: An ImageResponse object with the image details.
This method processes the message content to extract image information and retrieves the
corresponding image from the backend API. It then returns an ImageResponse object containing
the image URL and the prompt used to generate the image.
Args:
session (StreamSession): The StreamSession object used for making HTTP requests.
headers (dict): HTTP headers to be used for the request.
line (dict): A dictionary representing the line of response that contains image information.
Returns:
ImageResponse: An object containing the image URL and the prompt, or None if no image is found.
Raises:
RuntimeError: If there'san error in downloading the image, including issues with the HTTP request or response.
"""
if "parts" not in line["message"]["content"]:
return
@ -244,6 +258,20 @@ class OpenaiChat(AsyncGeneratorProvider):
@classmethod
async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
"""
Deletes a conversation by setting its visibility to False.
This method sends an HTTP PATCH request to update the visibility of a conversation.
It's used to effectively delete a conversation from being accessed or displayed in the future.
Args:
session (StreamSession): The StreamSession object used for making HTTP requests.
headers (dict): HTTP headers to be used for the request.
conversation_id (str): The unique identifier of the conversation to be deleted.
Raises:
HTTPError: If the HTTP request fails or returns an unsuccessful status code.
"""
async with session.patch(
f"{cls.url}/backend-api/conversation/{conversation_id}",
json={"is_visible": False},
@ -283,7 +311,7 @@ class OpenaiChat(AsyncGeneratorProvider):
history_disabled (bool): Flag to disable history and training.
action (str): Type of action ('next', 'continue', 'variant').
conversation_id (str): ID of the conversation.
parent_id (str): ID of the parent message.
parent_id (str): ID of the parent message.
image (ImageType): Image to include in the conversation.
response_fields (bool): Flag to include response fields in the output.
**kwargs: Additional keyword arguments.
@ -397,7 +425,7 @@ class OpenaiChat(AsyncGeneratorProvider):
await cls._delete_conversation(session, headers, conversation_id)
@classmethod
def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
def _browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]:
"""
Browse to obtain an access token.
@ -410,7 +438,7 @@ class OpenaiChat(AsyncGeneratorProvider):
driver = get_browser(proxy=proxy)
try:
driver.get(f"{cls.url}/")
WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
WebDriverWait(driver, timeout).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
access_token = driver.execute_script(
"let session = await fetch('/api/auth/session');"
"let data = await session.json();"
@ -471,7 +499,7 @@ class ResponseFields:
self.conversation_id = conversation_id
self.message_id = message_id
self._end_turn = end_turn
class Response():
"""
Class to encapsulate a response from the chat service.