Add image model list

This commit is contained in:
Heiner Lohaus 2024-04-21 15:15:55 +02:00
parent f66cd9f8a3
commit a26421bcd8
17 changed files with 129 additions and 93 deletions

View File

@ -14,6 +14,8 @@ async def test_async(provider: ProviderType):
return False
messages = [{"role": "user", "content": "Hello Assistant!"}]
try:
if "webdriver" in provider.get_parameters():
return False
response = await asyncio.wait_for(ChatCompletion.create_async(
model=models.default,
messages=messages,
@ -88,7 +90,7 @@ def print_models():
"huggingface": "Huggingface",
"anthropic": "Anthropic",
"inflection": "Inflection",
"meta": "Meta"
"meta": "Meta",
}
provider_urls = {
"google": "https://gemini.google.com/",
@ -96,7 +98,7 @@ def print_models():
"huggingface": "https://huggingface.co/",
"anthropic": "https://www.anthropic.com/",
"inflection": "https://inflection.ai/",
"meta": "https://llama.meta.com/"
"meta": "https://llama.meta.com/",
}
lines = [
@ -108,6 +110,8 @@ def print_models():
if name not in ("gpt-3.5-turbo", "gpt-4", "gpt-4-turbo"):
continue
name = re.split(r":|/", model.name)[-1]
if model.base_provider not in base_provider_names:
continue
base_provider = base_provider_names[model.base_provider]
if not isinstance(model.best_provider, BaseRetryProvider):
provider_name = f"g4f.Provider.{model.best_provider.__name__}"
@ -121,7 +125,24 @@ def print_models():
print("\n".join(lines))
def print_image_models():
lines = [
"| Label | Provider | Model | Website |",
"| ----- | -------- | ----- | ------- |",
]
from g4f.gui.server.api import Api
for image_model in Api.get_image_models():
provider_url = image_model["url"]
netloc = urlparse(provider_url).netloc.replace("www.", "")
website = f"[{netloc}]({provider_url})"
label = image_model["provider"] if image_model["label"] is None else image_model["label"]
lines.append(f'| {label} | {image_model["provider"]} | {image_model["image_model"]} | {website} |')
print("\n".join(lines))
if __name__ == "__main__":
print_providers()
print("\n", "-" * 50, "\n")
print_models()
print_models()
print("\n", "-" * 50, "\n")
print_image_models()

View File

@ -16,6 +16,7 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://www.bing.com/images/create"
working = True
needs_auth = True
image_models = ["dall-e"]
def __init__(self, cookies: Cookies = None, proxy: str = None) -> None:
self.cookies: Cookies = cookies

View File

@ -11,6 +11,7 @@ class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://deepinfra.com"
working = True
default_model = 'stability-ai/sdxl'
image_models = [default_model]
@classmethod
def get_models(cls):
@ -18,6 +19,7 @@ class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin):
url = 'https://api.deepinfra.com/models/featured'
models = requests.get(url).json()
cls.models = [model['model_name'] for model in models if model["reported_type"] == "text-to-image"]
cls.image_models = cls.models
return cls.models
@classmethod

View File

@ -13,7 +13,7 @@ from ..requests import raise_for_status, DEFAULT_HEADERS
from ..image import ImageResponse, ImagePreview
from ..errors import ResponseError
from .base_provider import AsyncGeneratorProvider
from .helper import format_prompt, get_connector
from .helper import format_prompt, get_connector, format_cookies
class Sources():
def __init__(self, link_list: List[Dict[str, str]]) -> None:
@ -48,7 +48,6 @@ class MetaAI(AsyncGeneratorProvider):
async def update_access_token(self, birthday: str = "1999-01-01"):
url = "https://www.meta.ai/api/graphql/"
payload = {
"lsd": self.lsd,
"fb_api_caller_class": "RelayModern",
@ -90,7 +89,7 @@ class MetaAI(AsyncGeneratorProvider):
headers = {}
headers = {
'content-type': 'application/x-www-form-urlencoded',
'cookie': "; ".join([f"{k}={v}" for k, v in cookies.items()]),
'cookie': format_cookies(cookies),
'origin': 'https://www.meta.ai',
'referer': 'https://www.meta.ai/',
'x-asbd-id': '129477',
@ -194,7 +193,7 @@ class MetaAI(AsyncGeneratorProvider):
**headers
}
async with self.session.post(url, headers=headers, cookies=self.cookies, data=payload) as response:
await raise_for_status(response)
await raise_for_status(response, "Fetch sources failed")
text = await response.text()
if "<h1>Something Went Wrong</h1>" in text:
raise ResponseError("Response: Something Went Wrong")

View File

@ -6,6 +6,7 @@ from .MetaAI import MetaAI
class MetaAIAccount(MetaAI):
needs_auth = True
image_models = ["meta"]
@classmethod
async def create_async_generator(

View File

@ -17,6 +17,7 @@ class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin):
"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
"2b017d9b67edd2ee1401238df49d75da53c523f36e363881e057f5dc3ed3c5b2"
]
image_models = [default_model]
@classmethod
async def create_async_generator(

View File

@ -8,7 +8,7 @@ import uuid
from ..typing import AsyncResult, Messages, ImageType, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt
from ..image import ImageResponse, to_bytes, is_accepted_format
from ..image import ImageResponse, ImagePreview, to_bytes, is_accepted_format
from ..requests import StreamSession, FormData, raise_for_status
from .you.har_file import get_telemetry_ids
from .. import debug
@ -34,6 +34,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
model_aliases = {
"claude-v2": "claude-2"
}
image_models = ["dall-e"]
_cookies = None
_cookies_used = 0
_telemetry_ids = []
@ -67,7 +68,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
timeout=(30, timeout)
) as session:
cookies = await cls.get_cookies(session) if chat_mode != "default" else None
upload = json.dumps([await cls.upload_file(session, cookies, to_bytes(image), image_name)]) if image else ""
headers = {
"Accept": "text/event-stream",
@ -102,11 +103,17 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
if event == "youChatToken" and event in data:
yield data[event]
elif event == "youChatUpdate" and "t" in data and data["t"] is not None:
match = re.search(r"!\[fig\]\((.+?)\)", data["t"])
if match:
yield ImageResponse(match.group(1), messages[-1]["content"])
if chat_mode == "create":
match = re.search(r"!\[(.+?)\]\((.+?)\)", data["t"])
if match:
if match.group(1) == "fig":
yield ImagePreview(match.group(2), messages[-1]["content"])
else:
yield ImageResponse(match.group(2), match.group(1))
else:
yield data["t"]
else:
yield data["t"]
yield data["t"]
@classmethod
async def upload_file(cls, client: StreamSession, cookies: Cookies, file: bytes, filename: str = None) -> dict:

View File

@ -41,6 +41,8 @@ async def create_conversation(session: StreamSession, headers: dict, tone: str)
raise RateLimitError("Response 404: Do less requests and reuse conversations")
await raise_for_status(response, "Failed to create conversation")
data = await response.json()
if not data:
raise RuntimeError('Empty response: Failed to create conversation')
conversationId = data.get('conversationId')
clientId = data.get('clientId')
conversationSignature = response.headers.get('X-Sydney-Encryptedconversationsignature')

View File

@ -53,6 +53,7 @@ class Gemini(AsyncGeneratorProvider):
url = "https://gemini.google.com"
needs_auth = True
working = True
image_models = ["gemini"]
@classmethod
async def create_async_generator(

View File

@ -3,4 +3,5 @@ from __future__ import annotations
from .OpenaiChat import OpenaiChat
class OpenaiAccount(OpenaiChat):
needs_auth = True
needs_auth = True
image_models = ["dall-e"]

View File

@ -29,6 +29,7 @@ from ...requests.aiohttp import StreamSession
from ...image import to_image, to_bytes, ImageResponse, ImageRequest
from ...errors import MissingAuthError, ResponseError
from ...providers.conversation import BaseConversation
from ..helper import format_cookies
from ..openai.har_file import getArkoseAndAccessToken, NoValidHarFileError
from ... import debug
@ -44,7 +45,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
default_model = None
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo", "gpt-4-turbo-preview": "gpt-4"}
model_aliases = {
"text-davinci-002-render-sha": "gpt-3.5-turbo",
"": "gpt-3.5-turbo",
"gpt-4-turbo-preview": "gpt-4",
"dall-e": "gpt-4",
}
_api_key: str = None
_headers: dict = None
_cookies: Cookies = None
@ -364,8 +370,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
arkose_token = None
if cls.default_model is None:
try:
arkose_token, api_key, cookies = await getArkoseAndAccessToken(proxy)
cls._create_request_args(cookies)
arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy)
cls._create_request_args(cookies, headers)
cls._set_api_key(api_key)
except NoValidHarFileError as e:
...
@ -393,8 +399,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
print(f'Arkose: {need_arkose} Turnstile: {data["turnstile"]["required"]}')
if need_arkose and arkose_token is None:
arkose_token, api_key, cookies = await getArkoseAndAccessToken(proxy)
cls._create_request_args(cookies)
arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy)
cls._create_request_args(cookies, headers)
cls._set_api_key(api_key)
if arkose_token is None:
raise MissingAuthError("No arkose token found in .har file")
@ -613,7 +619,7 @@ this.fetch = async (url, options) => {
cookies[c.name] = c.value
user_agent = await page.evaluate("window.navigator.userAgent")
await page.close()
cls._create_request_args(cookies, user_agent)
cls._create_request_args(cookies, user_agent=user_agent)
cls._set_api_key(api_key)
@classmethod
@ -667,16 +673,12 @@ this.fetch = async (url, options) => {
"oai-language": "en-US",
}
@staticmethod
def _format_cookies(cookies: Cookies):
return "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")
@classmethod
def _create_request_args(cls, cookies: Cookies = None, user_agent: str = None):
cls._headers = cls.get_default_headers()
def _create_request_args(cls, cookies: Cookies = None, headers: dict = None, user_agent: str = None):
cls._headers = cls.get_default_headers() if headers is None else headers
if user_agent is not None:
cls._headers["user-agent"] = user_agent
cls._cookies = {} if cookies is None else cookies
cls._cookies = {} if cookies is None else {k: v for k, v in cookies.items() if k != "access_token"}
cls._update_cookie_header()
@classmethod
@ -693,7 +695,7 @@ this.fetch = async (url, options) => {
@classmethod
def _update_cookie_header(cls):
cls._headers["cookie"] = cls._format_cookies(cls._cookies)
cls._headers["cookie"] = format_cookies(cls._cookies)
class Conversation(BaseConversation):
"""

View File

@ -59,17 +59,21 @@ def readHAR():
except KeyError:
continue
cookies = {c['name']: c['value'] for c in v['request']['cookies']}
headers = get_headers(v)
if not accessToken:
raise NoValidHarFileError("No accessToken found in .har files")
if not chatArks:
return None, accessToken, cookies
return chatArks.pop(), accessToken, cookies
return None, accessToken, cookies, headers
return chatArks.pop(), accessToken, cookies, headers
def get_headers(entry) -> dict:
return {h['name'].lower(): h['value'] for h in entry['request']['headers'] if h['name'].lower() not in ['content-length', 'cookie'] and not h['name'].startswith(':')}
def parseHAREntry(entry) -> arkReq:
tmpArk = arkReq(
arkURL=entry['request']['url'],
arkBx="",
arkHeader={h['name'].lower(): h['value'] for h in entry['request']['headers'] if h['name'].lower() not in ['content-length', 'cookie'] and not h['name'].startswith(':')},
arkHeader=get_headers(entry),
arkBody={p['name']: unquote(p['value']) for p in entry['request']['postData']['params'] if p['name'] not in ['rnd']},
arkCookies={c['name']: c['value'] for c in entry['request']['cookies']},
userAgent=""
@ -123,11 +127,11 @@ def getN() -> str:
timestamp = str(int(time.time()))
return base64.b64encode(timestamp.encode()).decode()
async def getArkoseAndAccessToken(proxy: str):
async def getArkoseAndAccessToken(proxy: str) -> tuple[str, str, dict, dict]:
global chatArk, accessToken, cookies
if chatArk is None or accessToken is None:
chatArk, accessToken, cookies = readHAR()
chatArk, accessToken, cookies, headers = readHAR()
if chatArk is None:
return None, accessToken, cookies
return None, accessToken, cookies, headers
newReq = genArkReq(chatArk)
return await sendRequest(newReq, proxy), accessToken, cookies
return await sendRequest(newReq, proxy), accessToken, cookies, headers

View File

@ -4,7 +4,6 @@ import json
import os
import os.path
import random
import requests
from ...requests import StreamSession, raise_for_status
from ...errors import MissingRequirementsError
@ -21,7 +20,8 @@ class arkReq:
self.arkCookies = arkCookies
self.userAgent = userAgent
arkPreURL = "https://telemetry.stytch.com/submit"
telemetry_url = "https://telemetry.stytch.com/submit"
public_token = "public-token-live-507a52ad-7e69-496b-aee0-1c9863c7c819"
chatArks: list = None
def readHAR():
@ -44,7 +44,7 @@ def readHAR():
# Error: not a HAR file!
continue
for v in harFile['log']['entries']:
if arkPreURL in v['request']['url']:
if v['request']['url'] == telemetry_url:
chatArks.append(parseHAREntry(v))
if not chatArks:
raise NoValidHarFileError("No telemetry in .har files found")
@ -62,24 +62,29 @@ def parseHAREntry(entry) -> arkReq:
return tmpArk
async def sendRequest(tmpArk: arkReq, proxy: str = None):
async with StreamSession(headers=tmpArk.arkHeaders, cookies=tmpArk.arkCookies, proxies={"all": proxy}) as session:
async with StreamSession(headers=tmpArk.arkHeaders, cookies=tmpArk.arkCookies, proxy=proxy) as session:
async with session.post(tmpArk.arkURL, data=tmpArk.arkBody) as response:
await raise_for_status(response)
return await response.text()
async def get_dfp_telemetry_id(proxy: str = None):
async def create_telemetry_id(proxy: str = None):
global chatArks
if chatArks is None:
chatArks = readHAR()
return await sendRequest(random.choice(chatArks), proxy)
async def get_telemetry_ids(proxy: str = None) -> list:
try:
return [await create_telemetry_id(proxy)]
except NoValidHarFileError as e:
if debug.logging:
print(e)
if debug.logging:
print('Getting telemetry_id for you.com with nodriver')
try:
from nodriver import start
except ImportError:
raise MissingRequirementsError('Install "nodriver" package | pip install -U nodriver')
raise MissingRequirementsError('Add .har file from you.com or install "nodriver" package | pip install -U nodriver')
try:
browser = await start()
tab = browser.main_tab
@ -89,49 +94,11 @@ async def get_telemetry_ids(proxy: str = None) -> list:
await tab.sleep(1)
async def get_telemetry_id():
public_token = "public-token-live-507a52ad-7e69-496b-aee0-1c9863c7c819"
telemetry_url = "https://telemetry.stytch.com/submit"
return await tab.evaluate(f'this.GetTelemetryID("{public_token}", "{telemetry_url}");', await_promise=True)
return await tab.evaluate(
f'this.GetTelemetryID("{public_token}", "{telemetry_url}");',
await_promise=True
)
# for _ in range(500):
# with open("hardir/you.com_telemetry_ids.txt", "a") as f:
# f.write((await get_telemetry_id()) + "\n")
return [await get_telemetry_id() for _ in range(4)]
return [await get_telemetry_id() for _ in range(1)]
finally:
try:
await tab.close()
except Exception as e:
print(f"Error occurred while closing tab: {str(e)}")
try:
await browser.stop()
except Exception as e:
pass
headers = {
'Accept': '*/*',
'Accept-Language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3',
'Connection': 'keep-alive',
'Content-type': 'application/x-www-form-urlencoded',
'Origin': 'https://you.com',
'Referer': 'https://you.com/',
'Sec-Fetch-Dest': 'empty',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Site': 'cross-site',
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36',
'sec-ch-ua': '"Google Chrome";v="123", "Not:A-Brand";v="8", "Chromium";v="123"',
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
}
proxies = {
'http': proxy,
'https': proxy} if proxy else None
response = requests.post('https://telemetry.stytch.com/submit',
headers=headers, data=payload, proxies=proxies)
if '-' in response.text:
print(f'telemetry generated: {response.text}')
return (response.text)
await browser.stop()

View File

@ -16,7 +16,8 @@ conversations: dict[dict[str, BaseConversation]] = {}
class Api():
def get_models(self) -> list[str]:
@staticmethod
def get_models() -> list[str]:
"""
Return a list of all models.
@ -27,7 +28,8 @@ class Api():
"""
return models._all_models
def get_provider_models(self, provider: str) -> list[dict]:
@staticmethod
def get_provider_models(provider: str) -> list[dict]:
if provider in __map__:
provider: ProviderType = __map__[provider]
if issubclass(provider, ProviderModelMixin):
@ -40,7 +42,24 @@ class Api():
else:
return [];
def get_providers(self) -> list[str]:
@staticmethod
def get_image_models() -> list[dict]:
image_models = []
for key, provider in __map__.items():
if hasattr(provider, "image_models"):
if hasattr(provider, "get_models"):
provider.get_models()
for model in provider.image_models:
image_models.append({
"provider": key,
"url": provider.url,
"label": provider.label if hasattr(provider, "label") else None,
"image_model": model
})
return image_models
@staticmethod
def get_providers() -> list[str]:
"""
Return a list of all working providers.
"""
@ -58,7 +77,8 @@ class Api():
if provider.working
}
def get_version(self):
@staticmethod
def get_version():
"""
Returns the current and latest version of the application.

View File

@ -31,6 +31,10 @@ class Backend_Api(Api):
'function': self.get_provider_models,
'methods': ['GET']
},
'/backend-api/v2/image_models': {
'function': self.get_image_models,
'methods': ['GET']
},
'/backend-api/v2/providers': {
'function': self.get_providers,
'methods': ['GET']

View File

@ -271,13 +271,13 @@ class AsyncGeneratorProvider(AsyncProvider):
raise NotImplementedError()
class ProviderModelMixin:
default_model: str
default_model: str = None
models: list[str] = []
model_aliases: dict[str, str] = {}
@classmethod
def get_models(cls) -> list[str]:
if not cls.models:
if not cls.models and cls.default_model is not None:
return [cls.default_model]
return cls.models

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import random
import string
from ..typing import Messages
from ..typing import Messages, Cookies
def format_prompt(messages: Messages, add_special_tokens=False) -> str:
"""
@ -56,4 +56,7 @@ def filter_none(**kwargs) -> dict:
key: value
for key, value in kwargs.items()
if value is not None
}
}
def format_cookies(cookies: Cookies) -> str:
return "; ".join([f"{k}={v}" for k, v in cookies.items()])