Refactor code with AI

Add doctypes to many functions
Add file upload for text files
Add alternative url to FreeChatgpt
Add webp to allowed image types
This commit is contained in:
Heiner Lohaus 2024-01-14 07:45:41 +01:00
parent ceed364cb1
commit 5756586cde
19 changed files with 1398 additions and 562 deletions

View File

@ -15,12 +15,18 @@ from .bing.upload_image import upload_image
from .bing.create_images import create_images from .bing.create_images import create_images
from .bing.conversation import Conversation, create_conversation, delete_conversation from .bing.conversation import Conversation, create_conversation, delete_conversation
class Tones(): class Tones:
"""
Defines the different tone options for the Bing provider.
"""
creative = "Creative" creative = "Creative"
balanced = "Balanced" balanced = "Balanced"
precise = "Precise" precise = "Precise"
class Bing(AsyncGeneratorProvider): class Bing(AsyncGeneratorProvider):
"""
Bing provider for generating responses using the Bing API.
"""
url = "https://bing.com/chat" url = "https://bing.com/chat"
working = True working = True
supports_message_history = True supports_message_history = True
@ -38,6 +44,19 @@ class Bing(AsyncGeneratorProvider):
web_search: bool = False, web_search: bool = False,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
"""
Creates an asynchronous generator for producing responses from Bing.
:param model: The model to use.
:param messages: Messages to process.
:param proxy: Proxy to use for requests.
:param timeout: Timeout for requests.
:param cookies: Cookies for the session.
:param tone: The tone of the response.
:param image: The image type to be used.
:param web_search: Flag to enable or disable web search.
:return: An asynchronous result object.
"""
if len(messages) < 2: if len(messages) < 2:
prompt = messages[0]["content"] prompt = messages[0]["content"]
context = None context = None
@ -56,65 +75,48 @@ class Bing(AsyncGeneratorProvider):
return stream_generate(prompt, tone, image, context, proxy, cookies, web_search, gpt4_turbo, timeout) return stream_generate(prompt, tone, image, context, proxy, cookies, web_search, gpt4_turbo, timeout)
def create_context(messages: Messages): def create_context(messages: Messages) -> str:
"""
Creates a context string from a list of messages.
:param messages: A list of message dictionaries.
:return: A string representing the context created from the messages.
"""
return "".join( return "".join(
f"[{message['role']}]" + ("(#message)" if message['role']!="system" else "(#additional_instructions)") + f"\n{message['content']}\n\n" f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
for message in messages for message in messages
) )
class Defaults: class Defaults:
"""
Default settings and configurations for the Bing provider.
"""
delimiter = "\x1e" delimiter = "\x1e"
ip_address = f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}" ip_address = f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
# List of allowed message types for Bing responses
allowedMessageTypes = [ allowedMessageTypes = [
"ActionRequest", "ActionRequest", "Chat", "Context", "Progress", "SemanticSerp",
"Chat", "GenerateContentQuery", "SearchQuery", "RenderCardRequest"
"Context",
# "Disengaged", unwanted
"Progress",
# "AdsQuery", unwanted
"SemanticSerp",
"GenerateContentQuery",
"SearchQuery",
# The following message types should not be added so that it does not flood with
# useless messages (such as "Analyzing images" or "Searching the web") while it's retrieving the AI response
# "InternalSearchQuery",
# "InternalSearchResult",
"RenderCardRequest",
# "RenderContentRequest"
] ]
sliceIds = [ sliceIds = [
'abv2', 'abv2', 'srdicton', 'convcssclick', 'stylewv2', 'contctxp2tf',
'srdicton', '802fluxv1pc_a', '806log2sphs0', '727savemem', '277teditgnds0', '207hlthgrds0'
'convcssclick',
'stylewv2',
'contctxp2tf',
'802fluxv1pc_a',
'806log2sphs0',
'727savemem',
'277teditgnds0',
'207hlthgrds0',
] ]
# Default location settings
location = { location = {
"locale": "en-US", "locale": "en-US", "market": "en-US", "region": "US",
"market": "en-US", "locationHints": [{
"region": "US", "country": "United States", "state": "California", "city": "Los Angeles",
"locationHints": [ "timezoneoffset": 8, "countryConfidence": 8,
{ "Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
"country": "United States", "RegionType": 2, "SourceType": 1
"state": "California", }],
"city": "Los Angeles",
"timezoneoffset": 8,
"countryConfidence": 8,
"Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
"RegionType": 2,
"SourceType": 1,
}
],
} }
# Default headers for requests
headers = { headers = {
'accept': '*/*', 'accept': '*/*',
'accept-language': 'en-US,en;q=0.9', 'accept-language': 'en-US,en;q=0.9',
@ -139,23 +141,13 @@ class Defaults:
} }
optionsSets = [ optionsSets = [
'nlu_direct_response_filter', 'nlu_direct_response_filter', 'deepleo', 'disable_emoji_spoken_text',
'deepleo', 'responsible_ai_policy_235', 'enablemm', 'iyxapbing', 'iycapbing',
'disable_emoji_spoken_text', 'gencontentv3', 'fluxsrtrunc', 'fluxtrunc', 'fluxv1', 'rai278',
'responsible_ai_policy_235', 'replaceurl', 'eredirecturl', 'nojbfedge'
'enablemm',
'iyxapbing',
'iycapbing',
'gencontentv3',
'fluxsrtrunc',
'fluxtrunc',
'fluxv1',
'rai278',
'replaceurl',
'eredirecturl',
'nojbfedge'
] ]
# Default cookies
cookies = { cookies = {
'SRCHD' : 'AF=NOFORM', 'SRCHD' : 'AF=NOFORM',
'PPLState' : '1', 'PPLState' : '1',
@ -166,6 +158,12 @@ class Defaults:
} }
def format_message(msg: dict) -> str: def format_message(msg: dict) -> str:
"""
Formats a message dictionary into a JSON string with a delimiter.
:param msg: The message dictionary to format.
:return: A formatted string representation of the message.
"""
return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter
def create_message( def create_message(
@ -177,7 +175,20 @@ def create_message(
web_search: bool = False, web_search: bool = False,
gpt4_turbo: bool = False gpt4_turbo: bool = False
) -> str: ) -> str:
"""
Creates a message for the Bing API with specified parameters.
:param conversation: The current conversation object.
:param prompt: The user's input prompt.
:param tone: The desired tone for the response.
:param context: Additional context for the prompt.
:param image_response: The response if an image is involved.
:param web_search: Flag to enable web search.
:param gpt4_turbo: Flag to enable GPT-4 Turbo.
:return: A formatted string message for the Bing API.
"""
options_sets = Defaults.optionsSets options_sets = Defaults.optionsSets
# Append tone-specific options
if tone == Tones.creative: if tone == Tones.creative:
options_sets.append("h3imaginative") options_sets.append("h3imaginative")
elif tone == Tones.precise: elif tone == Tones.precise:
@ -186,54 +197,49 @@ def create_message(
options_sets.append("galileo") options_sets.append("galileo")
else: else:
options_sets.append("harmonyv3") options_sets.append("harmonyv3")
# Additional configurations based on parameters
if not web_search: if not web_search:
options_sets.append("nosearchall") options_sets.append("nosearchall")
if gpt4_turbo: if gpt4_turbo:
options_sets.append("dlgpt4t") options_sets.append("dlgpt4t")
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
struct = { struct = {
'arguments': [ 'arguments': [{
{ 'source': 'cib', 'optionsSets': options_sets,
'source': 'cib', 'allowedMessageTypes': Defaults.allowedMessageTypes,
'optionsSets': options_sets, 'sliceIds': Defaults.sliceIds,
'allowedMessageTypes': Defaults.allowedMessageTypes, 'traceId': os.urandom(16).hex(), 'isStartOfSession': True,
'sliceIds': Defaults.sliceIds, 'requestId': request_id,
'traceId': os.urandom(16).hex(), 'message': {
'isStartOfSession': True, **Defaults.location,
'author': 'user',
'inputMethod': 'Keyboard',
'text': prompt,
'messageType': 'Chat',
'requestId': request_id, 'requestId': request_id,
'message': {**Defaults.location, **{ 'messageId': request_id
'author': 'user', },
'inputMethod': 'Keyboard', "verbosity": "verbose",
'text': prompt, "scenario": "SERP",
'messageType': 'Chat', "plugins": [{"id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}] if web_search else [],
'requestId': request_id, 'tone': tone,
'messageId': request_id, 'spokenTextMode': 'None',
}}, 'conversationId': conversation.conversationId,
"verbosity": "verbose", 'participant': {'id': conversation.clientId},
"scenario": "SERP", }],
"plugins":[
{"id":"c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}
] if web_search else [],
'tone': tone,
'spokenTextMode': 'None',
'conversationId': conversation.conversationId,
'participant': {
'id': conversation.clientId
},
}
],
'invocationId': '1', 'invocationId': '1',
'target': 'chat', 'target': 'chat',
'type': 4 'type': 4
} }
if image_response.get('imageUrl') and image_response.get('originalImageUrl'):
if image_response and image_response.get('imageUrl') and image_response.get('originalImageUrl'):
struct['arguments'][0]['message']['originalImageUrl'] = image_response.get('originalImageUrl') struct['arguments'][0]['message']['originalImageUrl'] = image_response.get('originalImageUrl')
struct['arguments'][0]['message']['imageUrl'] = image_response.get('imageUrl') struct['arguments'][0]['message']['imageUrl'] = image_response.get('imageUrl')
struct['arguments'][0]['experienceType'] = None struct['arguments'][0]['experienceType'] = None
struct['arguments'][0]['attachedFileInfo'] = {"fileName": None, "fileType": None} struct['arguments'][0]['attachedFileInfo'] = {"fileName": None, "fileType": None}
if context: if context:
struct['arguments'][0]['previousMessages'] = [{ struct['arguments'][0]['previousMessages'] = [{
"author": "user", "author": "user",
@ -242,30 +248,46 @@ def create_message(
"messageType": "Context", "messageType": "Context",
"messageId": "discover-web--page-ping-mriduna-----" "messageId": "discover-web--page-ping-mriduna-----"
}] }]
return format_message(struct) return format_message(struct)
async def stream_generate( async def stream_generate(
prompt: str, prompt: str,
tone: str, tone: str,
image: ImageType = None, image: ImageType = None,
context: str = None, context: str = None,
proxy: str = None, proxy: str = None,
cookies: dict = None, cookies: dict = None,
web_search: bool = False, web_search: bool = False,
gpt4_turbo: bool = False, gpt4_turbo: bool = False,
timeout: int = 900 timeout: int = 900
): ):
"""
Asynchronously streams generated responses from the Bing API.
:param prompt: The user's input prompt.
:param tone: The desired tone for the response.
:param image: The image type involved in the response.
:param context: Additional context for the prompt.
:param proxy: Proxy settings for the request.
:param cookies: Cookies for the session.
:param web_search: Flag to enable web search.
:param gpt4_turbo: Flag to enable GPT-4 Turbo.
:param timeout: Timeout for the request.
:return: An asynchronous generator yielding responses.
"""
headers = Defaults.headers headers = Defaults.headers
if cookies: if cookies:
headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items()) headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
async with ClientSession( async with ClientSession(
timeout=ClientTimeout(total=timeout), timeout=ClientTimeout(total=timeout), headers=headers
headers=headers
) as session: ) as session:
conversation = await create_conversation(session, proxy) conversation = await create_conversation(session, proxy)
image_response = await upload_image(session, image, tone, proxy) if image else None image_response = await upload_image(session, image, tone, proxy) if image else None
if image_response: if image_response:
yield image_response yield image_response
try: try:
async with session.ws_connect( async with session.ws_connect(
'wss://sydney.bing.com/sydney/ChatHub', 'wss://sydney.bing.com/sydney/ChatHub',
@ -289,7 +311,7 @@ async def stream_generate(
if obj is None or not obj: if obj is None or not obj:
continue continue
response = json.loads(obj) response = json.loads(obj)
if response.get('type') == 1 and response['arguments'][0].get('messages'): if response and response.get('type') == 1 and response['arguments'][0].get('messages'):
message = response['arguments'][0]['messages'][0] message = response['arguments'][0]['messages'][0]
image_response = None image_response = None
if (message['contentOrigin'] != 'Apology'): if (message['contentOrigin'] != 'Apology'):

View File

@ -1,16 +1,20 @@
from __future__ import annotations from __future__ import annotations
import json import json, random
from aiohttp import ClientSession from aiohttp import ClientSession
from ..typing import AsyncResult, Messages from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider from .base_provider import AsyncGeneratorProvider
models = { models = {
"claude-v2": "claude-2.0", "claude-v2": "claude-2.0",
"gemini-pro": "google-gemini-pro" "claude-v2.1":"claude-2.1",
"gemini-pro": "google-gemini-pro"
} }
urls = [
"https://free.chatgpt.org.uk",
"https://ai.chatgpt.org.uk"
]
class FreeChatgpt(AsyncGeneratorProvider): class FreeChatgpt(AsyncGeneratorProvider):
url = "https://free.chatgpt.org.uk" url = "https://free.chatgpt.org.uk"
@ -31,6 +35,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
model = models[model] model = models[model]
elif not model: elif not model:
model = "gpt-3.5-turbo" model = "gpt-3.5-turbo"
url = random.choice(urls)
headers = { headers = {
"Accept": "application/json, text/event-stream", "Accept": "application/json, text/event-stream",
"Content-Type":"application/json", "Content-Type":"application/json",
@ -55,7 +60,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
"top_p":1, "top_p":1,
**kwargs **kwargs
} }
async with session.post(f'{cls.url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response: async with session.post(f'{url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
response.raise_for_status() response.raise_for_status()
started = False started = False
async for line in response.content: async for line in response.content:

View File

@ -1,28 +1,29 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
import asyncio import asyncio
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod from abc import abstractmethod
from inspect import signature, Parameter from inspect import signature, Parameter
from .helper import get_event_loop, get_cookies, format_prompt from .helper import get_event_loop, get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages from ..typing import CreateResult, AsyncResult, Messages
from ..base_provider import BaseProvider from ..base_provider import BaseProvider
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
NoneType = type(None) NoneType = type(None)
else: else:
from types import NoneType from types import NoneType
# Change event loop policy on windows for curl_cffi # Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32': if sys.platform == 'win32':
if isinstance( if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class AbstractProvider(BaseProvider): class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
"""
@classmethod @classmethod
async def create_async( async def create_async(
cls, cls,
@ -33,62 +34,50 @@ class AbstractProvider(BaseProvider):
executor: ThreadPoolExecutor = None, executor: ThreadPoolExecutor = None,
**kwargs **kwargs
) -> str: ) -> str:
if not loop: """
loop = get_event_loop() Asynchronously creates a result based on the given model and messages.
"""
loop = loop or get_event_loop()
def create_func() -> str: def create_func() -> str:
return "".join(cls.create_completion( return "".join(cls.create_completion(model, messages, False, **kwargs))
model,
messages,
False,
**kwargs
))
return await asyncio.wait_for( return await asyncio.wait_for(
loop.run_in_executor( loop.run_in_executor(executor, create_func),
executor,
create_func
),
timeout=kwargs.get("timeout", 0) timeout=kwargs.get("timeout", 0)
) )
@classmethod @classmethod
@property @property
def params(cls) -> str: def params(cls) -> str:
if issubclass(cls, AsyncGeneratorProvider): """
sig = signature(cls.create_async_generator) Returns the parameters supported by the provider.
elif issubclass(cls, AsyncProvider): """
sig = signature(cls.create_async) sig = signature(
else: cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
sig = signature(cls.create_completion) cls.create_async if issubclass(cls, AsyncProvider) else
cls.create_completion
)
def get_type_name(annotation: type) -> str: def get_type_name(annotation: type) -> str:
if hasattr(annotation, "__name__"): return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
annotation = annotation.__name__
elif isinstance(annotation, NoneType):
annotation = "None"
return str(annotation)
args = "" args = ""
for name, param in sig.parameters.items(): for name, param in sig.parameters.items():
if name in ("self", "kwargs"): if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
continue continue
if name == "stream" and not cls.supports_stream: args += f"\n {name}"
continue args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
if args: args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else ""
args += ", "
args += "\n " + name
if name != "model" and param.annotation is not Parameter.empty:
args += f": {get_type_name(param.annotation)}"
if param.default == "":
args += ' = ""'
elif param.default is not Parameter.empty:
args += f" = {param.default}"
return f"g4f.Provider.{cls.__name__} supports: ({args}\n)" return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
class AsyncProvider(AbstractProvider): class AsyncProvider(AbstractProvider):
"""
Provides asynchronous functionality for creating completions.
"""
@classmethod @classmethod
def create_completion( def create_completion(
cls, cls,
@ -99,8 +88,10 @@ class AsyncProvider(AbstractProvider):
loop: AbstractEventLoop = None, loop: AbstractEventLoop = None,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
if not loop: """
loop = get_event_loop() Creates a completion result synchronously.
"""
loop = loop or get_event_loop()
coro = cls.create_async(model, messages, **kwargs) coro = cls.create_async(model, messages, **kwargs)
yield loop.run_until_complete(coro) yield loop.run_until_complete(coro)
@ -111,10 +102,16 @@ class AsyncProvider(AbstractProvider):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> str: ) -> str:
"""
Abstract method for creating asynchronous results.
"""
raise NotImplementedError() raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider): class AsyncGeneratorProvider(AsyncProvider):
"""
Provides asynchronous generator functionality for streaming results.
"""
supports_stream = True supports_stream = True
@classmethod @classmethod
@ -127,15 +124,13 @@ class AsyncGeneratorProvider(AsyncProvider):
loop: AbstractEventLoop = None, loop: AbstractEventLoop = None,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
if not loop: """
loop = get_event_loop() Creates a streaming completion result synchronously.
generator = cls.create_async_generator( """
model, loop = loop or get_event_loop()
messages, generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
stream=stream,
**kwargs
)
gen = generator.__aiter__() gen = generator.__aiter__()
while True: while True:
try: try:
yield loop.run_until_complete(gen.__anext__()) yield loop.run_until_complete(gen.__anext__())
@ -149,21 +144,23 @@ class AsyncGeneratorProvider(AsyncProvider):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> str: ) -> str:
"""
Asynchronously creates a result from a generator.
"""
return "".join([ return "".join([
chunk async for chunk in cls.create_async_generator( chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
model, if not isinstance(chunk, Exception)
messages,
stream=False,
**kwargs
) if not isinstance(chunk, Exception)
]) ])
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def create_async_generator( async def create_async_generator(
model: str, model: str,
messages: Messages, messages: Messages,
stream: bool = True, stream: bool = True,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
raise NotImplementedError() """
Abstract method for creating an asynchronous generator.
"""
raise NotImplementedError()

View File

@ -1,13 +1,33 @@
from aiohttp import ClientSession from aiohttp import ClientSession
class Conversation:
class Conversation(): """
Represents a conversation with specific attributes.
"""
def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None: def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None:
"""
Initialize a new conversation instance.
Args:
conversationId (str): Unique identifier for the conversation.
clientId (str): Client identifier.
conversationSignature (str): Signature for the conversation.
"""
self.conversationId = conversationId self.conversationId = conversationId
self.clientId = clientId self.clientId = clientId
self.conversationSignature = conversationSignature self.conversationSignature = conversationSignature
async def create_conversation(session: ClientSession, proxy: str = None) -> Conversation: async def create_conversation(session: ClientSession, proxy: str = None) -> Conversation:
"""
Create a new conversation asynchronously.
Args:
session (ClientSession): An instance of aiohttp's ClientSession.
proxy (str, optional): Proxy URL. Defaults to None.
Returns:
Conversation: An instance representing the created conversation.
"""
url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1199.4' url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1199.4'
async with session.get(url, proxy=proxy) as response: async with session.get(url, proxy=proxy) as response:
try: try:
@ -24,12 +44,32 @@ async def create_conversation(session: ClientSession, proxy: str = None) -> Conv
return Conversation(conversationId, clientId, conversationSignature) return Conversation(conversationId, clientId, conversationSignature)
async def list_conversations(session: ClientSession) -> list: async def list_conversations(session: ClientSession) -> list:
"""
List all conversations asynchronously.
Args:
session (ClientSession): An instance of aiohttp's ClientSession.
Returns:
list: A list of conversations.
"""
url = "https://www.bing.com/turing/conversation/chats" url = "https://www.bing.com/turing/conversation/chats"
async with session.get(url) as response: async with session.get(url) as response:
response = await response.json() response = await response.json()
return response["chats"] return response["chats"]
async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> bool: async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> bool:
"""
Delete a conversation asynchronously.
Args:
session (ClientSession): An instance of aiohttp's ClientSession.
conversation (Conversation): The conversation to delete.
proxy (str, optional): Proxy URL. Defaults to None.
Returns:
bool: True if deletion was successful, False otherwise.
"""
url = "https://sydney.bing.com/sydney/DeleteSingleConversation" url = "https://sydney.bing.com/sydney/DeleteSingleConversation"
json = { json = {
"conversationId": conversation.conversationId, "conversationId": conversation.conversationId,

View File

@ -1,9 +1,16 @@
"""
This module provides functionalities for creating and managing images using Bing's service.
It includes functions for user login, session creation, image creation, and processing.
"""
import asyncio import asyncio
import time, json, os import time
import json
import os
from aiohttp import ClientSession from aiohttp import ClientSession
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from urllib.parse import quote from urllib.parse import quote
from typing import Generator from typing import Generator, List, Dict
from ..create_images import CreateImagesProvider from ..create_images import CreateImagesProvider
from ..helper import get_cookies, get_event_loop from ..helper import get_cookies, get_event_loop
@ -12,23 +19,47 @@ from ...base_provider import ProviderType
from ...image import format_images_markdown from ...image import format_images_markdown
BING_URL = "https://www.bing.com" BING_URL = "https://www.bing.com"
TIMEOUT_LOGIN = 1200
TIMEOUT_IMAGE_CREATION = 300
ERRORS = [
"this prompt is being reviewed",
"this prompt has been blocked",
"we're working hard to offer image creator in more languages",
"we can't create your images right now"
]
BAD_IMAGES = [
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
]
def wait_for_login(driver: WebDriver, timeout: int = 1200) -> None: def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None:
"""
Waits for the user to log in within a given timeout period.
Args:
driver (WebDriver): Webdriver for browser automation.
timeout (int): Maximum waiting time in seconds.
Raises:
RuntimeError: If the login process exceeds the timeout.
"""
driver.get(f"{BING_URL}/") driver.get(f"{BING_URL}/")
value = driver.get_cookie("_U")
if value:
return
start_time = time.time() start_time = time.time()
while True: while not driver.get_cookie("_U"):
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise RuntimeError("Timeout error") raise RuntimeError("Timeout error")
value = driver.get_cookie("_U")
if value:
time.sleep(1)
return
time.sleep(0.5) time.sleep(0.5)
def create_session(cookies: dict) -> ClientSession: def create_session(cookies: Dict[str, str]) -> ClientSession:
"""
Creates a new client session with specified cookies and headers.
Args:
cookies (Dict[str, str]): Cookies to be used for the session.
Returns:
ClientSession: The created client session.
"""
headers = { headers = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"accept-encoding": "gzip, deflate, br", "accept-encoding": "gzip, deflate, br",
@ -47,28 +78,32 @@ def create_session(cookies: dict) -> ClientSession:
"upgrade-insecure-requests": "1", "upgrade-insecure-requests": "1",
} }
if cookies: if cookies:
headers["cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items()) headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
return ClientSession(headers=headers) return ClientSession(headers=headers)
async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = 300) -> list: async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = TIMEOUT_IMAGE_CREATION) -> List[str]:
url_encoded_prompt = quote(prompt) """
Creates images based on a given prompt using Bing's service.
Args:
session (ClientSession): Active client session.
prompt (str): Prompt to generate images.
proxy (str, optional): Proxy configuration.
timeout (int): Timeout for the request.
Returns:
List[str]: A list of URLs to the created images.
Raises:
RuntimeError: If image creation fails or times out.
"""
url_encoded_prompt = quote(prompt)
payload = f"q={url_encoded_prompt}&rt=4&FORM=GENCRE" payload = f"q={url_encoded_prompt}&rt=4&FORM=GENCRE"
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE" url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
async with session.post( async with session.post(url, allow_redirects=False, data=payload, timeout=timeout) as response:
url,
allow_redirects=False,
data=payload,
timeout=timeout,
) as response:
response.raise_for_status() response.raise_for_status()
errors = [
"this prompt is being reviewed",
"this prompt has been blocked",
"we're working hard to offer image creator in more languages",
"we can't create your images right now"
]
text = (await response.text()).lower() text = (await response.text()).lower()
for error in errors: for error in ERRORS:
if error in text: if error in text:
raise RuntimeError(f"Create images failed: {error}") raise RuntimeError(f"Create images failed: {error}")
if response.status != 302: if response.status != 302:
@ -107,54 +142,109 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
raise RuntimeError(error) raise RuntimeError(error)
return read_images(text) return read_images(text)
def read_images(text: str) -> list: def read_images(html_content: str) -> List[str]:
html_soup = BeautifulSoup(text, "html.parser") """
tags = html_soup.find_all("img") Extracts image URLs from the HTML content.
image_links = [img["src"] for img in tags if "mimg" in img["class"]]
images = [link.split("?w=")[0] for link in image_links] Args:
bad_images = [ html_content (str): HTML content containing image URLs.
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg", Returns:
] List[str]: A list of image URLs.
if any(im in bad_images for im in images): """
soup = BeautifulSoup(html_content, "html.parser")
tags = soup.find_all("img", class_="mimg")
images = [img["src"].split("?w=")[0] for img in tags]
if any(im in BAD_IMAGES for im in images):
raise RuntimeError("Bad images found") raise RuntimeError("Bad images found")
if not images: if not images:
raise RuntimeError("No images found") raise RuntimeError("No images found")
return images return images
async def create_images_markdown(cookies: dict, prompt: str, proxy: str = None) -> str: async def create_images_markdown(cookies: Dict[str, str], prompt: str, proxy: str = None) -> str:
session = create_session(cookies) """
try: Creates markdown formatted string with images based on the prompt.
Args:
cookies (Dict[str, str]): Cookies to be used for the session.
prompt (str): Prompt to generate images.
proxy (str, optional): Proxy configuration.
Returns:
str: Markdown formatted string with images.
"""
async with create_session(cookies) as session:
images = await create_images(session, prompt, proxy) images = await create_images(session, prompt, proxy)
return format_images_markdown(images, prompt) return format_images_markdown(images, prompt)
finally:
await session.close()
def get_cookies_from_browser(proxy: str = None) -> dict: def get_cookies_from_browser(proxy: str = None) -> Dict[str, str]:
driver = get_browser(proxy=proxy) """
try: Retrieves cookies from the browser using webdriver.
Args:
proxy (str, optional): Proxy configuration.
Returns:
Dict[str, str]: Retrieved cookies.
"""
with get_browser(proxy=proxy) as driver:
wait_for_login(driver) wait_for_login(driver)
time.sleep(1)
return get_driver_cookies(driver) return get_driver_cookies(driver)
finally:
driver.quit()
def create_completion(prompt: str, cookies: dict = None, proxy: str = None) -> Generator: class CreateImagesBing:
loop = get_event_loop() """A class for creating images using Bing."""
if not cookies:
cookies = get_cookies(".bing.com")
if "_U" not in cookies:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [Bing]({login_url})\n\n"
cookies = get_cookies_from_browser(proxy)
yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
async def create_async(prompt: str, cookies: dict = None, proxy: str = None) -> str: _cookies: Dict[str, str] = {}
if not cookies:
cookies = get_cookies(".bing.com") @classmethod
if "_U" not in cookies: def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str]:
cookies = get_cookies_from_browser(proxy) """
return await create_images_markdown(cookies, prompt, proxy) Generator for creating imagecompletion based on a prompt.
Args:
prompt (str): Prompt to generate images.
cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
proxy (str, optional): Proxy configuration.
Yields:
Generator[str, None, None]: The final output as markdown formatted string with images.
"""
loop = get_event_loop()
cookies = cookies or cls._cookies or get_cookies(".bing.com")
if "_U" not in cookies:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [Bing]({login_url})\n\n"
cls._cookies = cookies = get_cookies_from_browser(proxy)
yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
@classmethod
async def create_async(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> str:
"""
Asynchronously creates a markdown formatted string with images based on the prompt.
Args:
prompt (str): Prompt to generate images.
cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
proxy (str, optional): Proxy configuration.
Returns:
str: Markdown formatted string with images.
"""
cookies = cookies or cls._cookies or get_cookies(".bing.com")
if "_U" not in cookies:
cls._cookies = cookies = get_cookies_from_browser(proxy)
return await create_images_markdown(cookies, prompt, proxy)
def patch_provider(provider: ProviderType) -> CreateImagesProvider: def patch_provider(provider: ProviderType) -> CreateImagesProvider:
return CreateImagesProvider(provider, create_completion, create_async) """
Patches a provider to include image creation capabilities.
Args:
provider (ProviderType): The provider to be patched.
Returns:
CreateImagesProvider: The patched provider with image creation capabilities.
"""
return CreateImagesProvider(provider, CreateImagesBing.create_completion, CreateImagesBing.create_async)

View File

@ -1,64 +1,107 @@
from __future__ import annotations """
Module to handle image uploading and processing for Bing AI integrations.
"""
from __future__ import annotations
import string import string
import random import random
import json import json
import math import math
from ...typing import ImageType
from aiohttp import ClientSession from aiohttp import ClientSession
from PIL import Image
from ...typing import ImageType, Tuple
from ...image import to_image, process_image, to_base64, ImageResponse from ...image import to_image, process_image, to_base64, ImageResponse
image_config = { IMAGE_CONFIG = {
"maxImagePixels": 360000, "maxImagePixels": 360000,
"imageCompressionRate": 0.7, "imageCompressionRate": 0.7,
"enableFaceBlurDebug": 0, "enableFaceBlurDebug": False,
} }
async def upload_image( async def upload_image(
session: ClientSession, session: ClientSession,
image: ImageType, image_data: ImageType,
tone: str, tone: str,
proxy: str = None proxy: str = None
) -> ImageResponse: ) -> ImageResponse:
image = to_image(image) """
width, height = image.size Uploads an image to Bing's AI service and returns the image response.
max_image_pixels = image_config['maxImagePixels']
if max_image_pixels / (width * height) < 1: Args:
new_width = int(width * math.sqrt(max_image_pixels / (width * height))) session (ClientSession): The active session.
new_height = int(height * math.sqrt(max_image_pixels / (width * height))) image_data (bytes): The image data to be uploaded.
else: tone (str): The tone of the conversation.
new_width = width proxy (str, optional): Proxy if any. Defaults to None.
new_height = height
new_img = process_image(image, new_width, new_height) Raises:
new_img_binary_data = to_base64(new_img, image_config['imageCompressionRate']) RuntimeError: If the image upload fails.
data, boundary = build_image_upload_api_payload(new_img_binary_data, tone)
headers = session.headers.copy() Returns:
headers["content-type"] = f'multipart/form-data; boundary={boundary}' ImageResponse: The response from the image upload.
headers["referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx' """
headers["origin"] = 'https://www.bing.com' image = to_image(image_data)
new_width, new_height = calculate_new_dimensions(image)
processed_img = process_image(image, new_width, new_height)
img_binary_data = to_base64(processed_img, IMAGE_CONFIG['imageCompressionRate'])
data, boundary = build_image_upload_payload(img_binary_data, tone)
headers = prepare_headers(session, boundary)
async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response: async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response:
if response.status != 200: if response.status != 200:
raise RuntimeError("Failed to upload image.") raise RuntimeError("Failed to upload image.")
image_info = await response.json() return parse_image_response(await response.json())
if not image_info.get('blobId'):
raise RuntimeError("Failed to parse image info.")
result = {'bcid': image_info.get('blobId', "")}
result['blurredBcid'] = image_info.get('processedBlobId', "")
if result['blurredBcid'] != "":
result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['blurredBcid']
elif result['bcid'] != "":
result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['bcid']
result['originalImageUrl'] = (
"https://www.bing.com/images/blob?bcid="
+ result['blurredBcid']
if image_config["enableFaceBlurDebug"]
else "https://www.bing.com/images/blob?bcid="
+ result['bcid']
)
return ImageResponse(result["imageUrl"], "", result)
def build_image_upload_api_payload(image_bin: str, tone: str): def calculate_new_dimensions(image: Image.Image) -> Tuple[int, int]:
payload = { """
Calculates the new dimensions for the image based on the maximum allowed pixels.
Args:
image (Image): The PIL Image object.
Returns:
Tuple[int, int]: The new width and height for the image.
"""
width, height = image.size
max_image_pixels = IMAGE_CONFIG['maxImagePixels']
if max_image_pixels / (width * height) < 1:
scale_factor = math.sqrt(max_image_pixels / (width * height))
return int(width * scale_factor), int(height * scale_factor)
return width, height
def build_image_upload_payload(image_bin: str, tone: str) -> Tuple[str, str]:
"""
Builds the payload for image uploading.
Args:
image_bin (str): Base64 encoded image binary data.
tone (str): The tone of the conversation.
Returns:
Tuple[str, str]: The data and boundary for the payload.
"""
boundary = "----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
data = f"--{boundary}\r\n" \
f"Content-Disposition: form-data; name=\"knowledgeRequest\"\r\n\r\n" \
f"{json.dumps(build_knowledge_request(tone), ensure_ascii=False)}\r\n" \
f"--{boundary}\r\n" \
f"Content-Disposition: form-data; name=\"imageBase64\"\r\n\r\n" \
f"{image_bin}\r\n" \
f"--{boundary}--\r\n"
return data, boundary
def build_knowledge_request(tone: str) -> dict:
"""
Builds the knowledge request payload.
Args:
tone (str): The tone of the conversation.
Returns:
dict: The knowledge request payload.
"""
return {
'invokedSkills': ["ImageById"], 'invokedSkills': ["ImageById"],
'subscriptionId': "Bing.Chat.Multimodal", 'subscriptionId': "Bing.Chat.Multimodal",
'invokedSkillsRequestData': { 'invokedSkillsRequestData': {
@ -69,21 +112,46 @@ def build_image_upload_api_payload(image_bin: str, tone: str):
'convotone': tone 'convotone': tone
} }
} }
knowledge_request = {
'imageInfo': {}, def prepare_headers(session: ClientSession, boundary: str) -> dict:
'knowledgeRequest': payload """
} Prepares the headers for the image upload request.
boundary="----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
data = ( Args:
f'--{boundary}' session (ClientSession): The active session.
+ '\r\nContent-Disposition: form-data; name="knowledgeRequest"\r\n\r\n' boundary (str): The boundary string for the multipart/form-data.
+ json.dumps(knowledge_request, ensure_ascii=False)
+ "\r\n--" Returns:
+ boundary dict: The headers for the request.
+ '\r\nContent-Disposition: form-data; name="imageBase64"\r\n\r\n' """
+ image_bin headers = session.headers.copy()
+ "\r\n--" headers["Content-Type"] = f'multipart/form-data; boundary={boundary}'
+ boundary headers["Referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
+ "--\r\n" headers["Origin"] = 'https://www.bing.com'
return headers
def parse_image_response(response: dict) -> ImageResponse:
"""
Parses the response from the image upload.
Args:
response (dict): The response dictionary.
Raises:
RuntimeError: If parsing the image info fails.
Returns:
ImageResponse: The parsed image response.
"""
if not response.get('blobId'):
raise RuntimeError("Failed to parse image info.")
result = {'bcid': response.get('blobId', ""), 'blurredBcid': response.get('processedBlobId', "")}
result["imageUrl"] = f"https://www.bing.com/images/blob?bcid={result['blurredBcid'] or result['bcid']}"
result['originalImageUrl'] = (
f"https://www.bing.com/images/blob?bcid={result['blurredBcid']}"
if IMAGE_CONFIG["enableFaceBlurDebug"] else
f"https://www.bing.com/images/blob?bcid={result['bcid']}"
) )
return data, boundary return ImageResponse(result["imageUrl"], "", result)

View File

@ -1,36 +1,31 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import webbrowser
import random
import string
import secrets
import os import os
from os import path import random
import secrets
import string
from asyncio import AbstractEventLoop, BaseEventLoop from asyncio import AbstractEventLoop, BaseEventLoop
from platformdirs import user_config_dir from platformdirs import user_config_dir
from browser_cookie3 import ( from browser_cookie3 import (
chrome, chrome, chromium, opera, opera_gx,
chromium, brave, edge, vivaldi, firefox,
opera, _LinuxPasswordManager, BrowserCookieError
opera_gx,
brave,
edge,
vivaldi,
firefox,
_LinuxPasswordManager
) )
from ..typing import Dict, Messages from ..typing import Dict, Messages
from .. import debug from .. import debug
# Local Cookie Storage # Global variable to store cookies
_cookies: Dict[str, Dict[str, str]] = {} _cookies: Dict[str, Dict[str, str]] = {}
# If loop closed or not set, create new event loop.
# If event loop is already running, handle nested event loops.
# If "nest_asyncio" is installed, patch the event loop.
def get_event_loop() -> AbstractEventLoop: def get_event_loop() -> AbstractEventLoop:
"""
Get the current asyncio event loop. If the loop is closed or not set, create a new event loop.
If a loop is running, handle nested event loops. Patch the loop if 'nest_asyncio' is installed.
Returns:
AbstractEventLoop: The current or new event loop.
"""
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if isinstance(loop, BaseEventLoop): if isinstance(loop, BaseEventLoop):
@ -39,61 +34,50 @@ def get_event_loop() -> AbstractEventLoop:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
# Is running event loop
asyncio.get_running_loop() asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"): if not hasattr(loop.__class__, "_nest_patched"):
import nest_asyncio import nest_asyncio
nest_asyncio.apply(loop) nest_asyncio.apply(loop)
except RuntimeError: except RuntimeError:
# No running event loop
pass pass
except ImportError: except ImportError:
raise RuntimeError( raise RuntimeError(
'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.' 'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.'
) )
return loop return loop
def init_cookies():
urls = [
'https://chat-gpt.org',
'https://www.aitianhu.com',
'https://chatgptfree.ai',
'https://gptchatly.com',
'https://bard.google.com',
'https://huggingface.co/chat',
'https://open-assistant.io/chat'
]
browsers = ['google-chrome', 'chrome', 'firefox', 'safari']
def open_urls_in_browser(browser):
b = webbrowser.get(browser)
for url in urls:
b.open(url, new=0, autoraise=True)
for browser in browsers:
try:
open_urls_in_browser(browser)
break
except webbrowser.Error:
continue
# Check for broken dbus address in docker image
if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null": if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
_LinuxPasswordManager.get_password = lambda a, b: b"secret" _LinuxPasswordManager.get_password = lambda a, b: b"secret"
# Load cookies for a domain from all supported browsers. def get_cookies(domain_name: str = '') -> Dict[str, str]:
# Cache the results in the "_cookies" variable. """
def get_cookies(domain_name=''): Load cookies for a given domain from all supported browsers and cache the results.
Args:
domain_name (str): The domain for which to load cookies.
Returns:
Dict[str, str]: A dictionary of cookie names and values.
"""
if domain_name in _cookies: if domain_name in _cookies:
return _cookies[domain_name] return _cookies[domain_name]
def g4f(domain_name):
user_data_dir = user_config_dir("g4f") cookies = _load_cookies_from_browsers(domain_name)
cookie_file = path.join(user_data_dir, "Default", "Cookies") _cookies[domain_name] = cookies
return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name) return cookies
def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]:
"""
Helper function to load cookies from various browsers.
Args:
domain_name (str): The domain for which to load cookies.
Returns:
Dict[str, str]: A dictionary of cookie names and values.
"""
cookies = {} cookies = {}
for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]: for cookie_fn in [_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
try: try:
cookie_jar = cookie_fn(domain_name=domain_name) cookie_jar = cookie_fn(domain_name=domain_name)
if len(cookie_jar) and debug.logging: if len(cookie_jar) and debug.logging:
@ -101,13 +85,38 @@ def get_cookies(domain_name=''):
for cookie in cookie_jar: for cookie in cookie_jar:
if cookie.name not in cookies: if cookie.name not in cookies:
cookies[cookie.name] = cookie.value cookies[cookie.name] = cookie.value
except: except BrowserCookieError:
pass pass
_cookies[domain_name] = cookies except Exception as e:
return _cookies[domain_name] if debug.logging:
print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
return cookies
def _g4f(domain_name: str) -> list:
"""
Load cookies from the 'g4f' browser (if exists).
Args:
domain_name (str): The domain for which to load cookies.
Returns:
list: List of cookies.
"""
user_data_dir = user_config_dir("g4f")
cookie_file = os.path.join(user_data_dir, "Default", "Cookies")
return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name)
def format_prompt(messages: Messages, add_special_tokens=False) -> str: def format_prompt(messages: Messages, add_special_tokens=False) -> str:
"""
Format a series of messages into a single string, optionally adding special tokens.
Args:
messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
add_special_tokens (bool): Whether to add special formatting tokens.
Returns:
str: A formatted string containing all messages.
"""
if not add_special_tokens and len(messages) <= 1: if not add_special_tokens and len(messages) <= 1:
return messages[0]["content"] return messages[0]["content"]
formatted = "\n".join([ formatted = "\n".join([
@ -116,12 +125,26 @@ def format_prompt(messages: Messages, add_special_tokens=False) -> str:
]) ])
return f"{formatted}\nAssistant:" return f"{formatted}\nAssistant:"
def get_random_string(length: int = 10) -> str: def get_random_string(length: int = 10) -> str:
"""
Generate a random string of specified length, containing lowercase letters and digits.
Args:
length (int, optional): Length of the random string to generate. Defaults to 10.
Returns:
str: A random string of the specified length.
"""
return ''.join( return ''.join(
random.choice(string.ascii_lowercase + string.digits) random.choice(string.ascii_lowercase + string.digits)
for _ in range(length) for _ in range(length)
) )
def get_random_hex() -> str: def get_random_hex() -> str:
"""
Generate a random hexadecimal string of a fixed length.
Returns:
str: A random hexadecimal string of 32 characters (16 bytes).
"""
return secrets.token_hex(16).zfill(32) return secrets.token_hex(16).zfill(32)

View File

@ -1,6 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import uuid
import json
import os
import uuid, json, asyncio, os
from py_arkose_generator.arkose import get_values_for_request from py_arkose_generator.arkose import get_values_for_request
from async_property import async_cached_property from async_property import async_cached_property
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
@ -14,7 +17,8 @@ from ...typing import AsyncResult, Messages
from ...requests import StreamSession from ...requests import StreamSession
from ...image import to_image, to_bytes, ImageType, ImageResponse from ...image import to_image, to_bytes, ImageType, ImageResponse
models = { # Aliases for model names
MODELS = {
"gpt-3.5": "text-davinci-002-render-sha", "gpt-3.5": "text-davinci-002-render-sha",
"gpt-3.5-turbo": "text-davinci-002-render-sha", "gpt-3.5-turbo": "text-davinci-002-render-sha",
"gpt-4": "gpt-4", "gpt-4": "gpt-4",
@ -22,13 +26,15 @@ models = {
} }
class OpenaiChat(AsyncGeneratorProvider): class OpenaiChat(AsyncGeneratorProvider):
url = "https://chat.openai.com" """A class for creating and managing conversations with OpenAI chat service"""
working = True
needs_auth = True url = "https://chat.openai.com"
working = True
needs_auth = True
supports_gpt_35_turbo = True supports_gpt_35_turbo = True
supports_gpt_4 = True supports_gpt_4 = True
_cookies: dict = {} _cookies: dict = {}
_default_model: str = None _default_model: str = None
@classmethod @classmethod
async def create( async def create(
@ -43,6 +49,23 @@ class OpenaiChat(AsyncGeneratorProvider):
image: ImageType = None, image: ImageType = None,
**kwargs **kwargs
) -> Response: ) -> Response:
"""Create a new conversation or continue an existing one
Args:
prompt: The user input to start or continue the conversation
model: The name of the model to use for generating responses
messages: The list of previous messages in the conversation
history_disabled: A flag indicating if the history and training should be disabled
action: The type of action to perform, either "next", "continue", or "variant"
conversation_id: The ID of the existing conversation, if any
parent_id: The ID of the parent message, if any
image: The image to include in the user input, if any
**kwargs: Additional keyword arguments to pass to the generator
Returns:
A Response object that contains the generator, action, messages, and options
"""
# Add the user input to the messages list
if prompt: if prompt:
messages.append({ messages.append({
"role": "user", "role": "user",
@ -67,20 +90,33 @@ class OpenaiChat(AsyncGeneratorProvider):
) )
@classmethod @classmethod
async def upload_image( async def _upload_image(
cls, cls,
session: StreamSession, session: StreamSession,
headers: dict, headers: dict,
image: ImageType image: ImageType
) -> ImageResponse: ) -> ImageResponse:
"""Upload an image to the service and get the download URL
Args:
session: The StreamSession object to use for requests
headers: The headers to include in the requests
image: The image to upload, either a PIL Image object or a bytes object
Returns:
An ImageResponse object that contains the download URL, file name, and other data
"""
# Convert the image to a PIL Image object and get the extension
image = to_image(image) image = to_image(image)
extension = image.format.lower() extension = image.format.lower()
# Convert the image to a bytes object and get the size
data_bytes = to_bytes(image) data_bytes = to_bytes(image)
data = { data = {
"file_name": f"{image.width}x{image.height}.{extension}", "file_name": f"{image.width}x{image.height}.{extension}",
"file_size": len(data_bytes), "file_size": len(data_bytes),
"use_case": "multimodal" "use_case": "multimodal"
} }
# Post the image data to the service and get the image data
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response: async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
response.raise_for_status() response.raise_for_status()
image_data = { image_data = {
@ -91,6 +127,7 @@ class OpenaiChat(AsyncGeneratorProvider):
"height": image.height, "height": image.height,
"width": image.width "width": image.width
} }
# Put the image bytes to the upload URL and check the status
async with session.put( async with session.put(
image_data["upload_url"], image_data["upload_url"],
data=data_bytes, data=data_bytes,
@ -100,6 +137,7 @@ class OpenaiChat(AsyncGeneratorProvider):
} }
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
# Post the file ID to the service and get the download URL
async with session.post( async with session.post(
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded", f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
json={}, json={},
@ -110,24 +148,45 @@ class OpenaiChat(AsyncGeneratorProvider):
return ImageResponse(download_url, image_data["file_name"], image_data) return ImageResponse(download_url, image_data["file_name"], image_data)
@classmethod @classmethod
async def get_default_model(cls, session: StreamSession, headers: dict): async def _get_default_model(cls, session: StreamSession, headers: dict):
"""Get the default model name from the service
Args:
session: The StreamSession object to use for requests
headers: The headers to include in the requests
Returns:
The default model name as a string
"""
# Check the cache for the default model
if cls._default_model: if cls._default_model:
model = cls._default_model return cls._default_model
else: # Get the models data from the service
async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
data = await response.json() data = await response.json()
if "categories" in data: if "categories" in data:
model = data["categories"][-1]["default_model"] cls._default_model = data["categories"][-1]["default_model"]
else: else:
RuntimeError(f"Response: {data}") raise RuntimeError(f"Response: {data}")
cls._default_model = model return cls._default_model
return model
@classmethod @classmethod
def create_messages(cls, prompt: str, image_response: ImageResponse = None): def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
"""Create a list of messages for the user input
Args:
prompt: The user input as a string
image_response: The image response object, if any
Returns:
A list of messages with the user input and the image, if any
"""
# Check if there is an image response
if not image_response: if not image_response:
# Create a content object with the text type and the prompt
content = {"content_type": "text", "parts": [prompt]} content = {"content_type": "text", "parts": [prompt]}
else: else:
# Create a content object with the multimodal text type and the image and the prompt
content = { content = {
"content_type": "multimodal_text", "content_type": "multimodal_text",
"parts": [{ "parts": [{
@ -137,12 +196,15 @@ class OpenaiChat(AsyncGeneratorProvider):
"width": image_response.get("width"), "width": image_response.get("width"),
}, prompt] }, prompt]
} }
# Create a message object with the user role and the content
messages = [{ messages = [{
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
"author": {"role": "user"}, "author": {"role": "user"},
"content": content, "content": content,
}] }]
# Check if there is an image response
if image_response: if image_response:
# Add the metadata object with the attachments
messages[0]["metadata"] = { messages[0]["metadata"] = {
"attachments": [{ "attachments": [{
"height": image_response.get("height"), "height": image_response.get("height"),
@ -156,19 +218,38 @@ class OpenaiChat(AsyncGeneratorProvider):
return messages return messages
@classmethod @classmethod
async def get_image_response(cls, session: StreamSession, headers: dict, line: dict): async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
if "parts" in line["message"]["content"]: """
part = line["message"]["content"]["parts"][0] Retrieves the image response based on the message content.
if "asset_pointer" in part and part["metadata"]:
file_id = part["asset_pointer"].split("file-service://", 1)[1] :param session: The StreamSession object.
prompt = part["metadata"]["dalle"]["prompt"] :param headers: HTTP headers for the request.
async with session.get( :param line: The line of response containing image information.
f"{cls.url}/backend-api/files/{file_id}/download", :return: An ImageResponse object with the image details.
headers=headers """
) as response: if "parts" not in line["message"]["content"]:
response.raise_for_status() return
download_url = (await response.json())["download_url"] first_part = line["message"]["content"]["parts"][0]
return ImageResponse(download_url, prompt) if "asset_pointer" not in first_part or "metadata" not in first_part:
return
file_id = first_part["asset_pointer"].split("file-service://", 1)[1]
prompt = first_part["metadata"]["dalle"]["prompt"]
try:
async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
response.raise_for_status()
download_url = (await response.json())["download_url"]
return ImageResponse(download_url, prompt)
except Exception as e:
raise RuntimeError(f"Error in downloading image: {e}")
@classmethod
async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
async with session.patch(
f"{cls.url}/backend-api/conversation/{conversation_id}",
json={"is_visible": False},
headers=headers
) as response:
response.raise_for_status()
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
@ -188,26 +269,47 @@ class OpenaiChat(AsyncGeneratorProvider):
response_fields: bool = False, response_fields: bool = False,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
if model in models: """
model = models[model] Create an asynchronous generator for the conversation.
Args:
model (str): The model name.
messages (Messages): The list of previous messages.
proxy (str): Proxy to use for requests.
timeout (int): Timeout for requests.
access_token (str): Access token for authentication.
cookies (dict): Cookies to use for authentication.
auto_continue (bool): Flag to automatically continue the conversation.
history_disabled (bool): Flag to disable history and training.
action (str): Type of action ('next', 'continue', 'variant').
conversation_id (str): ID of the conversation.
parent_id (str): ID of the parent message.
image (ImageType): Image to include in the conversation.
response_fields (bool): Flag to include response fields in the output.
**kwargs: Additional keyword arguments.
Yields:
AsyncResult: Asynchronous results from the generator.
Raises:
RuntimeError: If an error occurs during processing.
"""
model = MODELS.get(model, model)
if not parent_id: if not parent_id:
parent_id = str(uuid.uuid4()) parent_id = str(uuid.uuid4())
if not cookies: if not cookies:
cookies = cls._cookies cookies = cls._cookies or get_cookies("chat.openai.com")
if not access_token: if not access_token and "access_token" in cookies:
if not cookies: access_token = cookies["access_token"]
cls._cookies = cookies = get_cookies("chat.openai.com")
if "access_token" in cookies:
access_token = cookies["access_token"]
if not access_token: if not access_token:
login_url = os.environ.get("G4F_LOGIN_URL") login_url = os.environ.get("G4F_LOGIN_URL")
if login_url: if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n" yield f"Please login: [ChatGPT]({login_url})\n\n"
access_token, cookies = cls.browse_access_token(proxy) access_token, cookies = cls._browse_access_token(proxy)
cls._cookies = cookies cls._cookies = cookies
headers = {
"Authorization": f"Bearer {access_token}", headers = {"Authorization": f"Bearer {access_token}"}
}
async with StreamSession( async with StreamSession(
proxies={"https": proxy}, proxies={"https": proxy},
impersonate="chrome110", impersonate="chrome110",
@ -215,11 +317,11 @@ class OpenaiChat(AsyncGeneratorProvider):
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"]) cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
) as session: ) as session:
if not model: if not model:
model = await cls.get_default_model(session, headers) model = await cls._get_default_model(session, headers)
try: try:
image_response = None image_response = None
if image: if image:
image_response = await cls.upload_image(session, headers, image) image_response = await cls._upload_image(session, headers, image)
yield image_response yield image_response
except Exception as e: except Exception as e:
yield e yield e
@ -227,7 +329,7 @@ class OpenaiChat(AsyncGeneratorProvider):
while not end_turn.is_end: while not end_turn.is_end:
data = { data = {
"action": action, "action": action,
"arkose_token": await cls.get_arkose_token(session), "arkose_token": await cls._get_arkose_token(session),
"conversation_id": conversation_id, "conversation_id": conversation_id,
"parent_message_id": parent_id, "parent_message_id": parent_id,
"model": model, "model": model,
@ -235,7 +337,7 @@ class OpenaiChat(AsyncGeneratorProvider):
} }
if action != "continue": if action != "continue":
prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"] prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
data["messages"] = cls.create_messages(prompt, image_response) data["messages"] = cls._create_messages(prompt, image_response)
async with session.post( async with session.post(
f"{cls.url}/backend-api/conversation", f"{cls.url}/backend-api/conversation",
json=data, json=data,
@ -261,62 +363,80 @@ class OpenaiChat(AsyncGeneratorProvider):
if "message_type" not in line["message"]["metadata"]: if "message_type" not in line["message"]["metadata"]:
continue continue
try: try:
image_response = await cls.get_image_response(session, headers, line) image_response = await cls._get_generated_image(session, headers, line)
if image_response: if image_response:
yield image_response yield image_response
except Exception as e: except Exception as e:
yield e yield e
if line["message"]["author"]["role"] != "assistant": if line["message"]["author"]["role"] != "assistant":
continue continue
if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"): if line["message"]["content"]["content_type"] != "text":
conversation_id = line["conversation_id"] continue
parent_id = line["message"]["id"] if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
if response_fields: continue
response_fields = False conversation_id = line["conversation_id"]
yield ResponseFields(conversation_id, parent_id, end_turn) parent_id = line["message"]["id"]
if "parts" in line["message"]["content"]: if response_fields:
new_message = line["message"]["content"]["parts"][0] response_fields = False
if len(new_message) > last_message: yield ResponseFields(conversation_id, parent_id, end_turn)
yield new_message[last_message:] if "parts" in line["message"]["content"]:
last_message = len(new_message) new_message = line["message"]["content"]["parts"][0]
if len(new_message) > last_message:
yield new_message[last_message:]
last_message = len(new_message)
if "finish_details" in line["message"]["metadata"]: if "finish_details" in line["message"]["metadata"]:
if line["message"]["metadata"]["finish_details"]["type"] == "stop": if line["message"]["metadata"]["finish_details"]["type"] == "stop":
end_turn.end() end_turn.end()
break
except Exception as e: except Exception as e:
yield e raise e
if not auto_continue: if not auto_continue:
break break
action = "continue" action = "continue"
await asyncio.sleep(5) await asyncio.sleep(5)
if history_disabled: if history_disabled and auto_continue:
async with session.patch( await cls._delete_conversation(session, headers, conversation_id)
f"{cls.url}/backend-api/conversation/{conversation_id}",
json={"is_visible": False},
headers=headers
) as response:
response.raise_for_status()
@classmethod @classmethod
def browse_access_token(cls, proxy: str = None) -> tuple[str, dict]: def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
"""
Browse to obtain an access token.
Args:
proxy (str): Proxy to use for browsing.
Returns:
tuple[str, dict]: A tuple containing the access token and cookies.
"""
driver = get_browser(proxy=proxy) driver = get_browser(proxy=proxy)
try: try:
driver.get(f"{cls.url}/") driver.get(f"{cls.url}/")
WebDriverWait(driver, 1200).until( WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
EC.presence_of_element_located((By.ID, "prompt-textarea")) access_token = driver.execute_script(
"let session = await fetch('/api/auth/session');"
"let data = await session.json();"
"let accessToken = data['accessToken'];"
"let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7);"
"document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
"return accessToken;"
) )
javascript = """ return access_token, get_driver_cookies(driver)
access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
return access_token;
"""
return driver.execute_script(javascript), get_driver_cookies(driver)
finally: finally:
driver.quit() driver.quit()
@classmethod @classmethod
async def get_arkose_token(cls, session: StreamSession) -> str: async def _get_arkose_token(cls, session: StreamSession) -> str:
"""
Obtain an Arkose token for the session.
Args:
session (StreamSession): The session object.
Returns:
str: The Arkose token.
Raises:
RuntimeError: If unable to retrieve the token.
"""
config = { config = {
"pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC", "pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
"surl": "https://tcr9i.chat.openai.com", "surl": "https://tcr9i.chat.openai.com",
@ -332,26 +452,30 @@ return access_token;
if "token" in decoded_json: if "token" in decoded_json:
return decoded_json["token"] return decoded_json["token"]
raise RuntimeError(f"Response: {decoded_json}") raise RuntimeError(f"Response: {decoded_json}")
class EndTurn(): class EndTurn:
"""
Class to represent the end of a conversation turn.
"""
def __init__(self): def __init__(self):
self.is_end = False self.is_end = False
def end(self): def end(self):
self.is_end = True self.is_end = True
class ResponseFields(): class ResponseFields:
def __init__( """
self, Class to encapsulate response fields.
conversation_id: str, """
message_id: str, def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
end_turn: EndTurn
):
self.conversation_id = conversation_id self.conversation_id = conversation_id
self.message_id = message_id self.message_id = message_id
self._end_turn = end_turn self._end_turn = end_turn
class Response(): class Response():
"""
Class to encapsulate a response from the chat service.
"""
def __init__( def __init__(
self, self,
generator: AsyncResult, generator: AsyncResult,
@ -360,13 +484,13 @@ class Response():
options: dict options: dict
): ):
self._generator = generator self._generator = generator
self.action: str = action self.action = action
self.is_end: bool = False self.is_end = False
self._message = None self._message = None
self._messages = messages self._messages = messages
self._options = options self._options = options
self._fields = None self._fields = None
async def generator(self): async def generator(self):
if self._generator: if self._generator:
self._generator = None self._generator = None
@ -384,19 +508,16 @@ class Response():
def __aiter__(self): def __aiter__(self):
return self.generator() return self.generator()
@async_cached_property @async_cached_property
async def message(self) -> str: async def message(self) -> str:
[_ async for _ in self.generator()] await self.generator()
return self._message return self._message
async def get_fields(self): async def get_fields(self):
[_ async for _ in self.generator()] await self.generator()
return { return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id}
"conversation_id": self._fields.conversation_id,
"parent_id": self._fields.message_id,
}
async def next(self, prompt: str, **kwargs) -> Response: async def next(self, prompt: str, **kwargs) -> Response:
return await OpenaiChat.create( return await OpenaiChat.create(
**self._options, **self._options,
@ -406,7 +527,7 @@ class Response():
**await self.get_fields(), **await self.get_fields(),
**kwargs **kwargs
) )
async def do_continue(self, **kwargs) -> Response: async def do_continue(self, **kwargs) -> Response:
fields = await self.get_fields() fields = await self.get_fields()
if self.is_end: if self.is_end:
@ -418,7 +539,7 @@ class Response():
**fields, **fields,
**kwargs **kwargs
) )
async def variant(self, **kwargs) -> Response: async def variant(self, **kwargs) -> Response:
if self.action != "next": if self.action != "next":
raise RuntimeError("Can't create variant from continue or variant request.") raise RuntimeError("Can't create variant from continue or variant request.")
@ -429,11 +550,9 @@ class Response():
**await self.get_fields(), **await self.get_fields(),
**kwargs **kwargs
) )
@async_cached_property @async_cached_property
async def messages(self): async def messages(self):
messages = self._messages messages = self._messages
messages.append({ messages.append({"role": "assistant", "content": await self.message})
"role": "assistant", "content": await self.message
})
return messages return messages

View File

@ -7,8 +7,17 @@ from ..base_provider import BaseRetryProvider
from .. import debug from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError from ..errors import RetryProviderError, RetryNoProviderError
class RetryProvider(BaseRetryProvider): class RetryProvider(BaseRetryProvider):
"""
A provider class to handle retries for creating completions with different providers.
Attributes:
providers (list): A list of provider instances.
shuffle (bool): A flag indicating whether to shuffle providers before use.
exceptions (dict): A dictionary to store exceptions encountered during retries.
last_provider (BaseProvider): The last provider that was used.
"""
def create_completion( def create_completion(
self, self,
model: str, model: str,
@ -16,10 +25,21 @@ class RetryProvider(BaseRetryProvider):
stream: bool = False, stream: bool = False,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
if stream: """
providers = [provider for provider in self.providers if provider.supports_stream] Create a completion using available providers, with an option to stream the response.
else:
providers = self.providers Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
Yields:
CreateResult: Tokens or results from the completion.
Raises:
Exception: Any exception encountered during the completion process.
"""
providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
if self.shuffle: if self.shuffle:
random.shuffle(providers) random.shuffle(providers)
@ -50,10 +70,23 @@ class RetryProvider(BaseRetryProvider):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> str: ) -> str:
"""
Asynchronously create a completion using available providers.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
Returns:
str: The result of the asynchronous completion.
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
providers = self.providers providers = self.providers
if self.shuffle: if self.shuffle:
random.shuffle(providers) random.shuffle(providers)
self.exceptions = {} self.exceptions = {}
for provider in providers: for provider in providers:
self.last_provider = provider self.last_provider = provider
@ -66,13 +99,20 @@ class RetryProvider(BaseRetryProvider):
self.exceptions[provider.__name__] = e self.exceptions[provider.__name__] = e
if debug.logging: if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}") print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
self.raise_exceptions() self.raise_exceptions()
def raise_exceptions(self) -> None: def raise_exceptions(self) -> None:
"""
Raise a combined exception if any occurred during retries.
Raises:
RetryProviderError: If any provider encountered an exception.
RetryNoProviderError: If no provider is found.
"""
if self.exceptions: if self.exceptions:
raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items() f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
])) ]))
raise RetryNoProviderError("No provider found") raise RetryNoProviderError("No provider found")

View File

@ -15,6 +15,26 @@ def get_model_and_provider(model : Union[Model, str],
ignored : list[str] = None, ignored : list[str] = None,
ignore_working: bool = False, ignore_working: bool = False,
ignore_stream: bool = False) -> tuple[str, ProviderType]: ignore_stream: bool = False) -> tuple[str, ProviderType]:
"""
Retrieves the model and provider based on input parameters.
Args:
model (Union[Model, str]): The model to use, either as an object or a string identifier.
provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
stream (bool): Indicates if the operation should be performed as a stream.
ignored (list[str], optional): List of provider names to be ignored.
ignore_working (bool, optional): If True, ignores the working status of the provider.
ignore_stream (bool, optional): If True, ignores the streaming capability of the provider.
Returns:
tuple[str, ProviderType]: A tuple containing the model name and the provider type.
Raises:
ProviderNotFoundError: If the provider is not found.
ModelNotFoundError: If the model is not found.
ProviderNotWorkingError: If the provider is not working.
StreamNotSupportedError: If streaming is not supported by the provider.
"""
if debug.version_check: if debug.version_check:
debug.version_check = False debug.version_check = False
version.utils.check_version() version.utils.check_version()
@ -70,7 +90,30 @@ class ChatCompletion:
ignore_stream_and_auth: bool = False, ignore_stream_and_auth: bool = False,
patch_provider: callable = None, patch_provider: callable = None,
**kwargs) -> Union[CreateResult, str]: **kwargs) -> Union[CreateResult, str]:
"""
Creates a chat completion using the specified model, provider, and messages.
Args:
model (Union[Model, str]): The model to use, either as an object or a string identifier.
messages (Messages): The messages for which the completion is to be created.
provider (Union[ProviderType, str, None], optional): The provider to use, either as an object, a string identifier, or None.
stream (bool, optional): Indicates if the operation should be performed as a stream.
auth (Union[str, None], optional): Authentication token or credentials, if required.
ignored (list[str], optional): List of provider names to be ignored.
ignore_working (bool, optional): If True, ignores the working status of the provider.
ignore_stream_and_auth (bool, optional): If True, ignores the stream and authentication requirement checks.
patch_provider (callable, optional): Function to modify the provider.
**kwargs: Additional keyword arguments.
Returns:
Union[CreateResult, str]: The result of the chat completion operation.
Raises:
AuthenticationRequiredError: If authentication is required but not provided.
ProviderNotFoundError, ModelNotFoundError: If the specified provider or model is not found.
ProviderNotWorkingError: If the provider is not operational.
StreamNotSupportedError: If streaming is requested but not supported by the provider.
"""
model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth) model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth)
if not ignore_stream_and_auth and provider.needs_auth and not auth: if not ignore_stream_and_auth and provider.needs_auth and not auth:
@ -98,7 +141,24 @@ class ChatCompletion:
ignored : list[str] = None, ignored : list[str] = None,
patch_provider: callable = None, patch_provider: callable = None,
**kwargs) -> Union[AsyncResult, str]: **kwargs) -> Union[AsyncResult, str]:
"""
Asynchronously creates a completion using the specified model and provider.
Args:
model (Union[Model, str]): The model to use, either as an object or a string identifier.
messages (Messages): Messages to be processed.
provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
stream (bool): Indicates if the operation should be performed as a stream.
ignored (list[str], optional): List of provider names to be ignored.
patch_provider (callable, optional): Function to modify the provider.
**kwargs: Additional keyword arguments.
Returns:
Union[AsyncResult, str]: The result of the asynchronous chat completion operation.
Raises:
StreamNotSupportedError: If streaming is requested but not supported by the provider.
"""
model, provider = get_model_and_provider(model, provider, False, ignored) model, provider = get_model_and_provider(model, provider, False, ignored)
if stream: if stream:
@ -118,7 +178,23 @@ class Completion:
provider : Union[ProviderType, None] = None, provider : Union[ProviderType, None] = None,
stream : bool = False, stream : bool = False,
ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]: ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]:
"""
Creates a completion based on the provided model, prompt, and provider.
Args:
model (Union[Model, str]): The model to use, either as an object or a string identifier.
prompt (str): The prompt text for which the completion is to be created.
provider (Union[ProviderType, None], optional): The provider to use, either as an object or None.
stream (bool, optional): Indicates if the operation should be performed as a stream.
ignored (list[str], optional): List of provider names to be ignored.
**kwargs: Additional keyword arguments.
Returns:
Union[CreateResult, str]: The result of the completion operation.
Raises:
ModelNotAllowedError: If the specified model is not allowed for use with this method.
"""
allowed_models = [ allowed_models = [
'code-davinci-002', 'code-davinci-002',
'text-ada-001', 'text-ada-001',
@ -137,6 +213,15 @@ class Completion:
return result if stream else ''.join(result) return result if stream else ''.join(result)
def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str]]: def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str]]:
"""
Retrieves the last used provider.
Args:
as_dict (bool, optional): If True, returns the provider information as a dictionary.
Returns:
Union[ProviderType, dict[str, str]]: The last used provider, either as an object or a dictionary.
"""
last = debug.last_provider last = debug.last_provider
if isinstance(last, BaseRetryProvider): if isinstance(last, BaseRetryProvider):
last = last.last_provider last = last.last_provider

View File

@ -1,7 +1,22 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from .typing import Messages, CreateResult, Union from typing import Union, List, Dict, Type
from .typing import Messages, CreateResult
class BaseProvider(ABC): class BaseProvider(ABC):
"""
Abstract base class for a provider.
Attributes:
url (str): URL of the provider.
working (bool): Indicates if the provider is currently working.
needs_auth (bool): Indicates if the provider needs authentication.
supports_stream (bool): Indicates if the provider supports streaming.
supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
supports_message_history (bool): Indicates if the provider supports message history.
params (str): List parameters for the provider.
"""
url: str = None url: str = None
working: bool = False working: bool = False
needs_auth: bool = False needs_auth: bool = False
@ -20,6 +35,18 @@ class BaseProvider(ABC):
stream: bool, stream: bool,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
"""
Create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
stream (bool): Whether to use streaming.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the creation process.
"""
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
@ -30,25 +57,59 @@ class BaseProvider(ABC):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> str: ) -> str:
"""
Asynchronously create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Returns:
str: The result of the creation process.
"""
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
def get_dict(cls): def get_dict(cls) -> Dict[str, str]:
"""
Get a dictionary representation of the provider.
Returns:
Dict[str, str]: A dictionary with provider's details.
"""
return {'name': cls.__name__, 'url': cls.url} return {'name': cls.__name__, 'url': cls.url}
class BaseRetryProvider(BaseProvider): class BaseRetryProvider(BaseProvider):
"""
Base class for a provider that implements retry logic.
Attributes:
providers (List[Type[BaseProvider]]): List of providers to use for retries.
shuffle (bool): Whether to shuffle the providers list.
exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
last_provider (Type[BaseProvider]): The last provider used.
"""
__name__: str = "RetryProvider" __name__: str = "RetryProvider"
supports_stream: bool = True supports_stream: bool = True
def __init__( def __init__(
self, self,
providers: list[type[BaseProvider]], providers: List[Type[BaseProvider]],
shuffle: bool = True shuffle: bool = True
) -> None: ) -> None:
self.providers: list[type[BaseProvider]] = providers """
self.shuffle: bool = shuffle Initialize the BaseRetryProvider.
self.working: bool = True
self.exceptions: dict[str, Exception] = {} Args:
self.last_provider: type[BaseProvider] = None providers (List[Type[BaseProvider]]): List of providers to use.
shuffle (bool): Whether to shuffle the providers list.
"""
self.providers = providers
self.shuffle = shuffle
self.working = True
self.exceptions: Dict[str, Exception] = {}
self.last_provider: Type[BaseProvider] = None
ProviderType = Union[type[BaseProvider], BaseRetryProvider] ProviderType = Union[Type[BaseProvider], BaseRetryProvider]

View File

@ -404,7 +404,7 @@ body {
display: none; display: none;
} }
#image { #image, #file {
display: none; display: none;
} }
@ -412,13 +412,22 @@ label[for="image"]:has(> input:valid){
color: var(--accent); color: var(--accent);
} }
label[for="image"] { label[for="file"]:has(> input:valid){
color: var(--accent);
}
label[for="image"], label[for="file"] {
cursor: pointer; cursor: pointer;
position: absolute; position: absolute;
top: 10px; top: 10px;
left: 10px; left: 10px;
} }
label[for="file"] {
top: 32px;
left: 10px;
}
.buttons input[type="checkbox"] { .buttons input[type="checkbox"] {
height: 0; height: 0;
width: 0; width: 0;

View File

@ -118,6 +118,10 @@
<input type="file" id="image" name="image" accept="image/png, image/gif, image/jpeg" required/> <input type="file" id="image" name="image" accept="image/png, image/gif, image/jpeg" required/>
<i class="fa-regular fa-image"></i> <i class="fa-regular fa-image"></i>
</label> </label>
<label for="file">
<input type="file" id="file" name="file" accept="text/plain, text/html, text/xml, application/json, text/javascript, .sh, .py, .php, .css, .yaml, .sql, .svg, .log, .csv, .twig, .md" required/>
<i class="fa-solid fa-paperclip"></i>
</label>
<div id="send-button"> <div id="send-button">
<i class="fa-solid fa-paper-plane-top"></i> <i class="fa-solid fa-paper-plane-top"></i>
</div> </div>
@ -125,7 +129,14 @@
</div> </div>
<div class="buttons"> <div class="buttons">
<div class="field"> <div class="field">
<select name="model" id="model"></select> <select name="model" id="model">
<option value="">Model: Default</option>
<option value="gpt-4">gpt-4</option>
<option value="gpt-3.5-turbo">gpt-3.5-turbo</option>
<option value="llama2-70b">llama2-70b</option>
<option value="gemini-pro">gemini-pro</option>
<option value="">----</option>
</select>
</div> </div>
<div class="field"> <div class="field">
<select name="jailbreak" id="jailbreak" style="display: none;"> <select name="jailbreak" id="jailbreak" style="display: none;">
@ -138,7 +149,16 @@
<option value="gpt-evil-1.0">evil 1.0</option> <option value="gpt-evil-1.0">evil 1.0</option>
</select> </select>
<div class="field"> <div class="field">
<select name="provider" id="provider"></select> <select name="provider" id="provider">
<option value="">Provider: Auto</option>
<option value="Bing">Bing</option>
<option value="OpenaiChat">OpenaiChat</option>
<option value="HuggingChat">HuggingChat</option>
<option value="Bard">Bard</option>
<option value="Liaobots">Liaobots</option>
<option value="Phind">Phind</option>
<option value="">----</option>
</select>
</div> </div>
</div> </div>
<div class="field"> <div class="field">

View File

@ -7,7 +7,9 @@ const spinner = box_conversations.querySelector(".spinner");
const stop_generating = document.querySelector(`.stop_generating`); const stop_generating = document.querySelector(`.stop_generating`);
const regenerate = document.querySelector(`.regenerate`); const regenerate = document.querySelector(`.regenerate`);
const send_button = document.querySelector(`#send-button`); const send_button = document.querySelector(`#send-button`);
const imageInput = document.querySelector('#image') ; const imageInput = document.querySelector('#image');
const fileInput = document.querySelector('#file');
let prompt_lock = false; let prompt_lock = false;
hljs.addPlugin(new CopyButtonPlugin()); hljs.addPlugin(new CopyButtonPlugin());
@ -42,6 +44,11 @@ const handle_ask = async () => {
if (message.length > 0) { if (message.length > 0) {
message_input.value = ''; message_input.value = '';
await add_conversation(window.conversation_id, message); await add_conversation(window.conversation_id, message);
if ("text" in fileInput.dataset) {
message += '\n```' + fileInput.dataset.type + '\n';
message += fileInput.dataset.text;
message += '\n```'
}
await add_message(window.conversation_id, "user", message); await add_message(window.conversation_id, "user", message);
window.token = message_id(); window.token = message_id();
message_box.innerHTML += ` message_box.innerHTML += `
@ -55,6 +62,9 @@ const handle_ask = async () => {
</div> </div>
</div> </div>
`; `;
document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
await ask_gpt(); await ask_gpt();
} }
}; };
@ -171,17 +181,30 @@ const ask_gpt = async () => {
content_inner.innerHTML += "<p>An error occured, please try again, if the problem persists, please use a other model or provider.</p>"; content_inner.innerHTML += "<p>An error occured, please try again, if the problem persists, please use a other model or provider.</p>";
} else { } else {
html = markdown_render(text); html = markdown_render(text);
html = html.substring(0, html.lastIndexOf('</p>')) + '<span id="cursor"></span></p>'; let lastElement, lastIndex = null;
for (element of ['</p>', '</code></pre>', '</li>\n</ol>']) {
const index = html.lastIndexOf(element)
if (index > lastIndex) {
lastElement = element;
lastIndex = index;
}
}
if (lastIndex) {
html = html.substring(0, lastIndex) + '<span id="cursor"></span>' + lastElement;
}
content_inner.innerHTML = html; content_inner.innerHTML = html;
document.querySelectorAll('code').forEach((el) => { document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el); hljs.highlightElement(el);
}); });
} }
window.scrollTo(0, 0); window.scrollTo(0, 0);
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" }); if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
}
} }
if (!error && imageInput) imageInput.value = ""; if (!error && imageInput) imageInput.value = "";
if (!error && fileInput) fileInput.value = "";
} catch (e) { } catch (e) {
console.error(e); console.error(e);
@ -305,7 +328,7 @@ const load_conversation = async (conversation_id) => {
`; `;
} }
document.querySelectorAll(`code`).forEach((el) => { document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el); hljs.highlightElement(el);
}); });
@ -400,7 +423,7 @@ const load_conversations = async (limit, offset, loader) => {
`; `;
} }
document.querySelectorAll(`code`).forEach((el) => { document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el); hljs.highlightElement(el);
}); });
}; };
@ -602,14 +625,7 @@ observer.observe(message_input, { attributes: true });
(async () => { (async () => {
response = await fetch('/backend-api/v2/models') response = await fetch('/backend-api/v2/models')
models = await response.json() models = await response.json()
let select = document.getElementById('model'); let select = document.getElementById('model');
select.textContent = '';
let auto = document.createElement('option');
auto.value = '';
auto.text = 'Model: Default';
select.appendChild(auto);
for (model of models) { for (model of models) {
let option = document.createElement('option'); let option = document.createElement('option');
@ -619,14 +635,7 @@ observer.observe(message_input, { attributes: true });
response = await fetch('/backend-api/v2/providers') response = await fetch('/backend-api/v2/providers')
providers = await response.json() providers = await response.json()
select = document.getElementById('provider'); select = document.getElementById('provider');
select.textContent = '';
auto = document.createElement('option');
auto.value = '';
auto.text = 'Provider: Auto';
select.appendChild(auto);
for (provider of providers) { for (provider of providers) {
let option = document.createElement('option'); let option = document.createElement('option');
@ -650,4 +659,27 @@ observer.observe(message_input, { attributes: true });
text += versions["version"]; text += versions["version"];
} }
document.getElementById("version_text").innerHTML = text document.getElementById("version_text").innerHTML = text
})() })()
fileInput.addEventListener('change', async (event) => {
if (fileInput.files.length) {
type = fileInput.files[0].type;
if (type && type.indexOf('/')) {
type = type.split('/').pop().replace('x-', '')
type = type.replace('plain', 'plaintext')
.replace('shellscript', 'sh')
.replace('svg+xml', 'svg')
.replace('vnd.trolltech.linguist', 'ts')
} else {
type = fileInput.files[0].name.split('.').pop()
}
fileInput.dataset.type = type
const reader = new FileReader();
reader.addEventListener('load', (event) => {
fileInput.dataset.text = event.target.result;
});
reader.readAsText(fileInput.files[0]);
} else {
delete fileInput.dataset.text;
}
});

View File

@ -4,9 +4,18 @@ import base64
from .typing import ImageType, Union from .typing import ImageType, Union
from PIL import Image from PIL import Image
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}
def to_image(image: ImageType) -> Image.Image: def to_image(image: ImageType) -> Image.Image:
"""
Converts the input image to a PIL Image object.
Args:
image (Union[str, bytes, Image.Image]): The input image.
Returns:
Image.Image: The converted PIL Image object.
"""
if isinstance(image, str): if isinstance(image, str):
is_data_uri_an_image(image) is_data_uri_an_image(image)
image = extract_data_uri(image) image = extract_data_uri(image)
@ -20,21 +29,48 @@ def to_image(image: ImageType) -> Image.Image:
image = copy image = copy
return image return image
def is_allowed_extension(filename) -> bool: def is_allowed_extension(filename: str) -> bool:
"""
Checks if the given filename has an allowed extension.
Args:
filename (str): The filename to check.
Returns:
bool: True if the extension is allowed, False otherwise.
"""
return '.' in filename and \ return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def is_data_uri_an_image(data_uri: str) -> bool: def is_data_uri_an_image(data_uri: str) -> bool:
"""
Checks if the given data URI represents an image.
Args:
data_uri (str): The data URI to check.
Raises:
ValueError: If the data URI is invalid or the image format is not allowed.
"""
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif) # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
if not re.match(r'data:image/(\w+);base64,', data_uri): if not re.match(r'data:image/(\w+);base64,', data_uri):
raise ValueError("Invalid data URI image.") raise ValueError("Invalid data URI image.")
# Extract the image format from the data URI # Extract the image format from the data URI
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1) image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif) # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
if image_format.lower() not in ALLOWED_EXTENSIONS: if image_format.lower() not in ALLOWED_EXTENSIONS:
raise ValueError("Invalid image format (from mime file type).") raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> bool: def is_accepted_format(binary_data: bytes) -> bool:
"""
Checks if the given binary data represents an image with an accepted format.
Args:
binary_data (bytes): The binary data to check.
Raises:
ValueError: If the image format is not allowed.
"""
if binary_data.startswith(b'\xFF\xD8\xFF'): if binary_data.startswith(b'\xFF\xD8\xFF'):
pass # It's a JPEG image pass # It's a JPEG image
elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'): elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
@ -49,13 +85,31 @@ def is_accepted_format(binary_data: bytes) -> bool:
pass # It's a WebP image pass # It's a WebP image
else: else:
raise ValueError("Invalid image format (from magic code).") raise ValueError("Invalid image format (from magic code).")
def extract_data_uri(data_uri: str) -> bytes: def extract_data_uri(data_uri: str) -> bytes:
"""
Extracts the binary data from the given data URI.
Args:
data_uri (str): The data URI.
Returns:
bytes: The extracted binary data.
"""
data = data_uri.split(",")[1] data = data_uri.split(",")[1]
data = base64.b64decode(data) data = base64.b64decode(data)
return data return data
def get_orientation(image: Image.Image) -> int: def get_orientation(image: Image.Image) -> int:
"""
Gets the orientation of the given image.
Args:
image (Image.Image): The image.
Returns:
int: The orientation value.
"""
exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif() exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
if exif_data is not None: if exif_data is not None:
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
@ -63,6 +117,17 @@ def get_orientation(image: Image.Image) -> int:
return orientation return orientation
def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image: def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
"""
Processes the given image by adjusting its orientation and resizing it.
Args:
img (Image.Image): The image to process.
new_width (int): The new width of the image.
new_height (int): The new height of the image.
Returns:
Image.Image: The processed image.
"""
orientation = get_orientation(img) orientation = get_orientation(img)
if orientation: if orientation:
if orientation > 4: if orientation > 4:
@ -75,13 +140,34 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im
img = img.transpose(Image.ROTATE_90) img = img.transpose(Image.ROTATE_90)
img.thumbnail((new_width, new_height)) img.thumbnail((new_width, new_height))
return img return img
def to_base64(image: Image.Image, compression_rate: float) -> str: def to_base64(image: Image.Image, compression_rate: float) -> str:
"""
Converts the given image to a base64-encoded string.
Args:
image (Image.Image): The image to convert.
compression_rate (float): The compression rate (0.0 to 1.0).
Returns:
str: The base64-encoded image.
"""
output_buffer = BytesIO() output_buffer = BytesIO()
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100)) image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode() return base64.b64encode(output_buffer.getvalue()).decode()
def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str: def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
"""
Formats the given images as a markdown string.
Args:
images: The images to format.
prompt (str): The prompt for the images.
preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
Returns:
str: The formatted markdown string.
"""
if isinstance(images, list): if isinstance(images, list):
images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
images = "\n".join(images) images = "\n".join(images)
@ -92,6 +178,15 @@ def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=20
return f"\n{start_flag}{images}\n{end_flag}\n" return f"\n{start_flag}{images}\n{end_flag}\n"
def to_bytes(image: Image.Image) -> bytes: def to_bytes(image: Image.Image) -> bytes:
"""
Converts the given image to bytes.
Args:
image (Image.Image): The image to convert.
Returns:
bytes: The image as bytes.
"""
bytes_io = BytesIO() bytes_io = BytesIO()
image.save(bytes_io, image.format) image.save(bytes_io, image.format)
image.seek(0) image.seek(0)

View File

@ -31,12 +31,21 @@ from .Provider import (
@dataclass(unsafe_hash=True) @dataclass(unsafe_hash=True)
class Model: class Model:
"""
Represents a machine learning model configuration.
Attributes:
name (str): Name of the model.
base_provider (str): Default provider for the model.
best_provider (ProviderType): The preferred provider for the model, typically with retry logic.
"""
name: str name: str
base_provider: str base_provider: str
best_provider: ProviderType = None best_provider: ProviderType = None
@staticmethod @staticmethod
def __all__() -> list[str]: def __all__() -> list[str]:
"""Returns a list of all model names."""
return _all_models return _all_models
default = Model( default = Model(
@ -298,6 +307,12 @@ pi = Model(
) )
class ModelUtils: class ModelUtils:
"""
Utility class for mapping string identifiers to Model instances.
Attributes:
convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances.
"""
convert: dict[str, Model] = { convert: dict[str, Model] = {
# gpt-3.5 # gpt-3.5
'gpt-3.5-turbo' : gpt_35_turbo, 'gpt-3.5-turbo' : gpt_35_turbo,

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json import json
from contextlib import asynccontextmanager
from functools import partialmethod from functools import partialmethod
from typing import AsyncGenerator from typing import AsyncGenerator
from urllib.parse import urlparse from urllib.parse import urlparse
@ -9,27 +8,41 @@ from curl_cffi.requests import AsyncSession, Session, Response
from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies
class StreamResponse: class StreamResponse:
"""
A wrapper class for handling asynchronous streaming responses.
Attributes:
inner (Response): The original Response object.
"""
def __init__(self, inner: Response) -> None: def __init__(self, inner: Response) -> None:
"""Initialize the StreamResponse with the provided Response object."""
self.inner: Response = inner self.inner: Response = inner
async def text(self) -> str: async def text(self) -> str:
"""Asynchronously get the response text."""
return await self.inner.atext() return await self.inner.atext()
def raise_for_status(self) -> None: def raise_for_status(self) -> None:
"""Raise an HTTPError if one occurred."""
self.inner.raise_for_status() self.inner.raise_for_status()
async def json(self, **kwargs) -> dict: async def json(self, **kwargs) -> dict:
"""Asynchronously parse the JSON response content."""
return json.loads(await self.inner.acontent(), **kwargs) return json.loads(await self.inner.acontent(), **kwargs)
async def iter_lines(self) -> AsyncGenerator[bytes, None]: async def iter_lines(self) -> AsyncGenerator[bytes, None]:
"""Asynchronously iterate over the lines of the response."""
async for line in self.inner.aiter_lines(): async for line in self.inner.aiter_lines():
yield line yield line
async def iter_content(self) -> AsyncGenerator[bytes, None]: async def iter_content(self) -> AsyncGenerator[bytes, None]:
"""Asynchronously iterate over the response content."""
async for chunk in self.inner.aiter_content(): async for chunk in self.inner.aiter_content():
yield chunk yield chunk
async def __aenter__(self): async def __aenter__(self):
"""Asynchronously enter the runtime context for the response object."""
inner: Response = await self.inner inner: Response = await self.inner
self.inner = inner self.inner = inner
self.request = inner.request self.request = inner.request
@ -39,24 +52,47 @@ class StreamResponse:
self.headers = inner.headers self.headers = inner.headers
self.cookies = inner.cookies self.cookies = inner.cookies
return self return self
async def __aexit__(self, *args): async def __aexit__(self, *args):
"""Asynchronously exit the runtime context for the response object."""
await self.inner.aclose() await self.inner.aclose()
class StreamSession(AsyncSession): class StreamSession(AsyncSession):
"""
An asynchronous session class for handling HTTP requests with streaming.
Inherits from AsyncSession.
"""
def request( def request(
self, method: str, url: str, **kwargs self, method: str, url: str, **kwargs
) -> StreamResponse: ) -> StreamResponse:
"""Create and return a StreamResponse object for the given HTTP request."""
return StreamResponse(super().request(method, url, stream=True, **kwargs)) return StreamResponse(super().request(method, url, stream=True, **kwargs))
# Defining HTTP methods as partial methods of the request method.
head = partialmethod(request, "HEAD") head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET") get = partialmethod(request, "GET")
post = partialmethod(request, "POST") post = partialmethod(request, "POST")
put = partialmethod(request, "PUT") put = partialmethod(request, "PUT")
patch = partialmethod(request, "PATCH") patch = partialmethod(request, "PATCH")
delete = partialmethod(request, "DELETE") delete = partialmethod(request, "DELETE")
def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120):
def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session:
"""
Create a Session object using a WebDriver to handle cookies and headers.
Args:
url (str): The URL to navigate to using the WebDriver.
webdriver (WebDriver, optional): The WebDriver instance to use.
proxy (str, optional): Proxy server to use for the Session.
timeout (int, optional): Timeout in seconds for the WebDriver.
Returns:
Session: A Session object configured with cookies and headers from the WebDriver.
"""
with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=True) as driver: with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=True) as driver:
bypass_cloudflare(driver, url, timeout) bypass_cloudflare(driver, url, timeout)
cookies = get_driver_cookies(driver) cookies = get_driver_cookies(driver)
@ -78,4 +114,4 @@ def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str =
proxies={"https": proxy, "http": proxy}, proxies={"https": proxy, "http": proxy},
timeout=timeout, timeout=timeout,
impersonate="chrome110" impersonate="chrome110"
) )

View File

@ -5,45 +5,94 @@ from importlib.metadata import version as get_package_version, PackageNotFoundEr
from subprocess import check_output, CalledProcessError, PIPE from subprocess import check_output, CalledProcessError, PIPE
from .errors import VersionNotFoundError from .errors import VersionNotFoundError
def get_latest_version() -> str: def get_pypi_version(package_name: str) -> str:
try: """
get_package_version("g4f") Get the latest version of a package from PyPI.
response = requests.get("https://pypi.org/pypi/g4f/json").json()
return response["info"]["version"]
except PackageNotFoundError:
url = "https://api.github.com/repos/xtekky/gpt4free/releases/latest"
response = requests.get(url).json()
return response["tag_name"]
class VersionUtils(): :param package_name: The name of the package.
:return: The latest version of the package as a string.
"""
try:
response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
return response["info"]["version"]
except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get PyPI version: {e}")
def get_github_version(repo: str) -> str:
"""
Get the latest release version from a GitHub repository.
:param repo: The name of the GitHub repository.
:return: The latest release version as a string.
"""
try:
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
return response["tag_name"]
except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
def get_latest_version():
"""
Get the latest release version from PyPI or the GitHub repository.
:return: The latest release version as a string.
"""
try:
# Is installed via package manager?
get_package_version("g4f")
return get_pypi_version("g4f")
except PackageNotFoundError:
# Else use Github version:
return get_github_version("xtekky/gpt4free")
class VersionUtils:
"""
Utility class for managing and comparing package versions.
"""
@cached_property @cached_property
def current_version(self) -> str: def current_version(self) -> str:
"""
Get the current version of the g4f package.
:return: The current version as a string.
"""
# Read from package manager # Read from package manager
try: try:
return get_package_version("g4f") return get_package_version("g4f")
except PackageNotFoundError: except PackageNotFoundError:
pass pass
# Read from docker environment # Read from docker environment
version = environ.get("G4F_VERSION") version = environ.get("G4F_VERSION")
if version: if version:
return version return version
# Read from git repository # Read from git repository
try: try:
command = ["git", "describe", "--tags", "--abbrev=0"] command = ["git", "describe", "--tags", "--abbrev=0"]
return check_output(command, text=True, stderr=PIPE).strip() return check_output(command, text=True, stderr=PIPE).strip()
except CalledProcessError: except CalledProcessError:
pass pass
raise VersionNotFoundError("Version not found") raise VersionNotFoundError("Version not found")
@cached_property @cached_property
def latest_version(self) -> str: def latest_version(self) -> str:
"""
Get the latest version of the g4f package.
:return: The latest version as a string.
"""
return get_latest_version() return get_latest_version()
def check_version(self) -> None: def check_version(self) -> None:
"""
Check if the current version is up to date with the latest version.
"""
try: try:
if self.current_version != self.latest_version: if self.current_version != self.latest_version:
print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f') print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f')
except Exception as e: except Exception as e:
print(f'Failed to check g4f version: {e}') print(f'Failed to check g4f version: {e}')
utils = VersionUtils() utils = VersionUtils()

View File

@ -1,5 +1,4 @@
from __future__ import annotations from __future__ import annotations
from platformdirs import user_config_dir from platformdirs import user_config_dir
from selenium.webdriver.remote.webdriver import WebDriver from selenium.webdriver.remote.webdriver import WebDriver
from undetected_chromedriver import Chrome, ChromeOptions from undetected_chromedriver import Chrome, ChromeOptions
@ -21,7 +20,16 @@ def get_browser(
proxy: str = None, proxy: str = None,
options: ChromeOptions = None options: ChromeOptions = None
) -> WebDriver: ) -> WebDriver:
if user_data_dir == None: """
Creates and returns a Chrome WebDriver with the specified options.
:param user_data_dir: Directory for user data. If None, uses default directory.
:param headless: Boolean indicating whether to run the browser in headless mode.
:param proxy: Proxy settings for the browser.
:param options: ChromeOptions object with specific browser options.
:return: An instance of WebDriver.
"""
if user_data_dir is None:
user_data_dir = user_config_dir("g4f") user_data_dir = user_config_dir("g4f")
if user_data_dir and debug.logging: if user_data_dir and debug.logging:
print("Open browser with config dir:", user_data_dir) print("Open browser with config dir:", user_data_dir)
@ -39,36 +47,45 @@ def get_browser(
headless=headless headless=headless
) )
def get_driver_cookies(driver: WebDriver): def get_driver_cookies(driver: WebDriver) -> dict:
return dict([(cookie["name"], cookie["value"]) for cookie in driver.get_cookies()]) """
Retrieves cookies from the given WebDriver.
:param driver: WebDriver from which to retrieve cookies.
:return: A dictionary of cookies.
"""
return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None: def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
# Open website """
Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
:param driver: The WebDriver to use.
:param url: URL to access.
:param timeout: Time in seconds to wait for the page to load.
"""
driver.get(url) driver.get(url)
# Is cloudflare protection
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js": if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
if debug.logging: if debug.logging:
print("Cloudflare protection detected:", url) print("Cloudflare protection detected:", url)
try: try:
# Click button in iframe
WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
)
driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe")) driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
WebDriverWait(driver, 5).until( WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input")) EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input"))
) ).click()
driver.find_element(By.CSS_SELECTOR, "#challenge-stage input").click() except Exception as e:
except: if debug.logging:
pass print(f"Error bypassing Cloudflare: {e}")
finally: finally:
driver.switch_to.default_content() driver.switch_to.default_content()
# No cloudflare protection
WebDriverWait(driver, timeout).until( WebDriverWait(driver, timeout).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)")) EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
) )
class WebDriverSession(): class WebDriverSession:
"""
Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
"""
def __init__( def __init__(
self, self,
webdriver: WebDriver = None, webdriver: WebDriver = None,
@ -81,9 +98,7 @@ class WebDriverSession():
self.webdriver = webdriver self.webdriver = webdriver
self.user_data_dir = user_data_dir self.user_data_dir = user_data_dir
self.headless = headless self.headless = headless
self.virtual_display = None self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None
if has_pyvirtualdisplay and virtual_display:
self.virtual_display = Display(size=(1920, 1080))
self.proxy = proxy self.proxy = proxy
self.options = options self.options = options
self.default_driver = None self.default_driver = None
@ -94,8 +109,15 @@ class WebDriverSession():
headless: bool = False, headless: bool = False,
virtual_display: bool = False virtual_display: bool = False
) -> WebDriver: ) -> WebDriver:
if user_data_dir == None: """
user_data_dir = self.user_data_dir Reopens the WebDriver session with the specified parameters.
:param user_data_dir: Directory for user data.
:param headless: Boolean indicating whether to run the browser in headless mode.
:param virtual_display: Boolean indicating whether to use a virtual display.
:return: An instance of WebDriver.
"""
user_data_dir = user_data_dir or self.user_data_dir
if self.default_driver: if self.default_driver:
self.default_driver.quit() self.default_driver.quit()
if not virtual_display and self.virtual_display: if not virtual_display and self.virtual_display:
@ -105,6 +127,10 @@ class WebDriverSession():
return self.default_driver return self.default_driver
def __enter__(self) -> WebDriver: def __enter__(self) -> WebDriver:
"""
Context management method for entering a session.
:return: An instance of WebDriver.
"""
if self.webdriver: if self.webdriver:
return self.webdriver return self.webdriver
if self.virtual_display: if self.virtual_display:
@ -113,11 +139,15 @@ class WebDriverSession():
return self.default_driver return self.default_driver
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""
Context management method for exiting a session. Closes and quits the WebDriver.
"""
if self.default_driver: if self.default_driver:
try: try:
self.default_driver.close() self.default_driver.close()
except: except Exception as e:
pass if debug.logging:
print(f"Error closing WebDriver: {e}")
self.default_driver.quit() self.default_driver.quit()
if self.virtual_display: if self.virtual_display:
self.virtual_display.stop() self.virtual_display.stop()