mirror of https://github.com/xtekky/gpt4free.git
Add image model list
This commit is contained in:
parent
f66cd9f8a3
commit
a26421bcd8
|
@ -14,6 +14,8 @@ async def test_async(provider: ProviderType):
|
|||
return False
|
||||
messages = [{"role": "user", "content": "Hello Assistant!"}]
|
||||
try:
|
||||
if "webdriver" in provider.get_parameters():
|
||||
return False
|
||||
response = await asyncio.wait_for(ChatCompletion.create_async(
|
||||
model=models.default,
|
||||
messages=messages,
|
||||
|
@ -88,7 +90,7 @@ def print_models():
|
|||
"huggingface": "Huggingface",
|
||||
"anthropic": "Anthropic",
|
||||
"inflection": "Inflection",
|
||||
"meta": "Meta"
|
||||
"meta": "Meta",
|
||||
}
|
||||
provider_urls = {
|
||||
"google": "https://gemini.google.com/",
|
||||
|
@ -96,7 +98,7 @@ def print_models():
|
|||
"huggingface": "https://huggingface.co/",
|
||||
"anthropic": "https://www.anthropic.com/",
|
||||
"inflection": "https://inflection.ai/",
|
||||
"meta": "https://llama.meta.com/"
|
||||
"meta": "https://llama.meta.com/",
|
||||
}
|
||||
|
||||
lines = [
|
||||
|
@ -108,6 +110,8 @@ def print_models():
|
|||
if name not in ("gpt-3.5-turbo", "gpt-4", "gpt-4-turbo"):
|
||||
continue
|
||||
name = re.split(r":|/", model.name)[-1]
|
||||
if model.base_provider not in base_provider_names:
|
||||
continue
|
||||
base_provider = base_provider_names[model.base_provider]
|
||||
if not isinstance(model.best_provider, BaseRetryProvider):
|
||||
provider_name = f"g4f.Provider.{model.best_provider.__name__}"
|
||||
|
@ -121,7 +125,24 @@ def print_models():
|
|||
|
||||
print("\n".join(lines))
|
||||
|
||||
def print_image_models():
|
||||
lines = [
|
||||
"| Label | Provider | Model | Website |",
|
||||
"| ----- | -------- | ----- | ------- |",
|
||||
]
|
||||
from g4f.gui.server.api import Api
|
||||
for image_model in Api.get_image_models():
|
||||
provider_url = image_model["url"]
|
||||
netloc = urlparse(provider_url).netloc.replace("www.", "")
|
||||
website = f"[{netloc}]({provider_url})"
|
||||
label = image_model["provider"] if image_model["label"] is None else image_model["label"]
|
||||
lines.append(f'| {label} | {image_model["provider"]} | {image_model["image_model"]} | {website} |')
|
||||
|
||||
print("\n".join(lines))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_providers()
|
||||
print("\n", "-" * 50, "\n")
|
||||
print_models()
|
||||
print_models()
|
||||
print("\n", "-" * 50, "\n")
|
||||
print_image_models()
|
|
@ -16,6 +16,7 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
url = "https://www.bing.com/images/create"
|
||||
working = True
|
||||
needs_auth = True
|
||||
image_models = ["dall-e"]
|
||||
|
||||
def __init__(self, cookies: Cookies = None, proxy: str = None) -> None:
|
||||
self.cookies: Cookies = cookies
|
||||
|
|
|
@ -11,6 +11,7 @@ class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
url = "https://deepinfra.com"
|
||||
working = True
|
||||
default_model = 'stability-ai/sdxl'
|
||||
image_models = [default_model]
|
||||
|
||||
@classmethod
|
||||
def get_models(cls):
|
||||
|
@ -18,6 +19,7 @@ class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
url = 'https://api.deepinfra.com/models/featured'
|
||||
models = requests.get(url).json()
|
||||
cls.models = [model['model_name'] for model in models if model["reported_type"] == "text-to-image"]
|
||||
cls.image_models = cls.models
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -13,7 +13,7 @@ from ..requests import raise_for_status, DEFAULT_HEADERS
|
|||
from ..image import ImageResponse, ImagePreview
|
||||
from ..errors import ResponseError
|
||||
from .base_provider import AsyncGeneratorProvider
|
||||
from .helper import format_prompt, get_connector
|
||||
from .helper import format_prompt, get_connector, format_cookies
|
||||
|
||||
class Sources():
|
||||
def __init__(self, link_list: List[Dict[str, str]]) -> None:
|
||||
|
@ -48,7 +48,6 @@ class MetaAI(AsyncGeneratorProvider):
|
|||
|
||||
async def update_access_token(self, birthday: str = "1999-01-01"):
|
||||
url = "https://www.meta.ai/api/graphql/"
|
||||
|
||||
payload = {
|
||||
"lsd": self.lsd,
|
||||
"fb_api_caller_class": "RelayModern",
|
||||
|
@ -90,7 +89,7 @@ class MetaAI(AsyncGeneratorProvider):
|
|||
headers = {}
|
||||
headers = {
|
||||
'content-type': 'application/x-www-form-urlencoded',
|
||||
'cookie': "; ".join([f"{k}={v}" for k, v in cookies.items()]),
|
||||
'cookie': format_cookies(cookies),
|
||||
'origin': 'https://www.meta.ai',
|
||||
'referer': 'https://www.meta.ai/',
|
||||
'x-asbd-id': '129477',
|
||||
|
@ -194,7 +193,7 @@ class MetaAI(AsyncGeneratorProvider):
|
|||
**headers
|
||||
}
|
||||
async with self.session.post(url, headers=headers, cookies=self.cookies, data=payload) as response:
|
||||
await raise_for_status(response)
|
||||
await raise_for_status(response, "Fetch sources failed")
|
||||
text = await response.text()
|
||||
if "<h1>Something Went Wrong</h1>" in text:
|
||||
raise ResponseError("Response: Something Went Wrong")
|
||||
|
|
|
@ -6,6 +6,7 @@ from .MetaAI import MetaAI
|
|||
|
||||
class MetaAIAccount(MetaAI):
|
||||
needs_auth = True
|
||||
image_models = ["meta"]
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
|
|
|
@ -17,6 +17,7 @@ class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
|
||||
"2b017d9b67edd2ee1401238df49d75da53c523f36e363881e057f5dc3ed3c5b2"
|
||||
]
|
||||
image_models = [default_model]
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
|
|
|
@ -8,7 +8,7 @@ import uuid
|
|||
from ..typing import AsyncResult, Messages, ImageType, Cookies
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .helper import format_prompt
|
||||
from ..image import ImageResponse, to_bytes, is_accepted_format
|
||||
from ..image import ImageResponse, ImagePreview, to_bytes, is_accepted_format
|
||||
from ..requests import StreamSession, FormData, raise_for_status
|
||||
from .you.har_file import get_telemetry_ids
|
||||
from .. import debug
|
||||
|
@ -34,6 +34,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
model_aliases = {
|
||||
"claude-v2": "claude-2"
|
||||
}
|
||||
image_models = ["dall-e"]
|
||||
_cookies = None
|
||||
_cookies_used = 0
|
||||
_telemetry_ids = []
|
||||
|
@ -67,7 +68,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
timeout=(30, timeout)
|
||||
) as session:
|
||||
cookies = await cls.get_cookies(session) if chat_mode != "default" else None
|
||||
|
||||
|
||||
upload = json.dumps([await cls.upload_file(session, cookies, to_bytes(image), image_name)]) if image else ""
|
||||
headers = {
|
||||
"Accept": "text/event-stream",
|
||||
|
@ -102,11 +103,17 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if event == "youChatToken" and event in data:
|
||||
yield data[event]
|
||||
elif event == "youChatUpdate" and "t" in data and data["t"] is not None:
|
||||
match = re.search(r"!\[fig\]\((.+?)\)", data["t"])
|
||||
if match:
|
||||
yield ImageResponse(match.group(1), messages[-1]["content"])
|
||||
if chat_mode == "create":
|
||||
match = re.search(r"!\[(.+?)\]\((.+?)\)", data["t"])
|
||||
if match:
|
||||
if match.group(1) == "fig":
|
||||
yield ImagePreview(match.group(2), messages[-1]["content"])
|
||||
else:
|
||||
yield ImageResponse(match.group(2), match.group(1))
|
||||
else:
|
||||
yield data["t"]
|
||||
else:
|
||||
yield data["t"]
|
||||
yield data["t"]
|
||||
|
||||
@classmethod
|
||||
async def upload_file(cls, client: StreamSession, cookies: Cookies, file: bytes, filename: str = None) -> dict:
|
||||
|
|
|
@ -41,6 +41,8 @@ async def create_conversation(session: StreamSession, headers: dict, tone: str)
|
|||
raise RateLimitError("Response 404: Do less requests and reuse conversations")
|
||||
await raise_for_status(response, "Failed to create conversation")
|
||||
data = await response.json()
|
||||
if not data:
|
||||
raise RuntimeError('Empty response: Failed to create conversation')
|
||||
conversationId = data.get('conversationId')
|
||||
clientId = data.get('clientId')
|
||||
conversationSignature = response.headers.get('X-Sydney-Encryptedconversationsignature')
|
||||
|
|
|
@ -53,6 +53,7 @@ class Gemini(AsyncGeneratorProvider):
|
|||
url = "https://gemini.google.com"
|
||||
needs_auth = True
|
||||
working = True
|
||||
image_models = ["gemini"]
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
|
|
|
@ -3,4 +3,5 @@ from __future__ import annotations
|
|||
from .OpenaiChat import OpenaiChat
|
||||
|
||||
class OpenaiAccount(OpenaiChat):
|
||||
needs_auth = True
|
||||
needs_auth = True
|
||||
image_models = ["dall-e"]
|
|
@ -29,6 +29,7 @@ from ...requests.aiohttp import StreamSession
|
|||
from ...image import to_image, to_bytes, ImageResponse, ImageRequest
|
||||
from ...errors import MissingAuthError, ResponseError
|
||||
from ...providers.conversation import BaseConversation
|
||||
from ..helper import format_cookies
|
||||
from ..openai.har_file import getArkoseAndAccessToken, NoValidHarFileError
|
||||
from ... import debug
|
||||
|
||||
|
@ -44,7 +45,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
supports_system_message = True
|
||||
default_model = None
|
||||
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
|
||||
model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo", "gpt-4-turbo-preview": "gpt-4"}
|
||||
model_aliases = {
|
||||
"text-davinci-002-render-sha": "gpt-3.5-turbo",
|
||||
"": "gpt-3.5-turbo",
|
||||
"gpt-4-turbo-preview": "gpt-4",
|
||||
"dall-e": "gpt-4",
|
||||
}
|
||||
_api_key: str = None
|
||||
_headers: dict = None
|
||||
_cookies: Cookies = None
|
||||
|
@ -364,8 +370,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
arkose_token = None
|
||||
if cls.default_model is None:
|
||||
try:
|
||||
arkose_token, api_key, cookies = await getArkoseAndAccessToken(proxy)
|
||||
cls._create_request_args(cookies)
|
||||
arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy)
|
||||
cls._create_request_args(cookies, headers)
|
||||
cls._set_api_key(api_key)
|
||||
except NoValidHarFileError as e:
|
||||
...
|
||||
|
@ -393,8 +399,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
print(f'Arkose: {need_arkose} Turnstile: {data["turnstile"]["required"]}')
|
||||
|
||||
if need_arkose and arkose_token is None:
|
||||
arkose_token, api_key, cookies = await getArkoseAndAccessToken(proxy)
|
||||
cls._create_request_args(cookies)
|
||||
arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy)
|
||||
cls._create_request_args(cookies, headers)
|
||||
cls._set_api_key(api_key)
|
||||
if arkose_token is None:
|
||||
raise MissingAuthError("No arkose token found in .har file")
|
||||
|
@ -613,7 +619,7 @@ this.fetch = async (url, options) => {
|
|||
cookies[c.name] = c.value
|
||||
user_agent = await page.evaluate("window.navigator.userAgent")
|
||||
await page.close()
|
||||
cls._create_request_args(cookies, user_agent)
|
||||
cls._create_request_args(cookies, user_agent=user_agent)
|
||||
cls._set_api_key(api_key)
|
||||
|
||||
@classmethod
|
||||
|
@ -667,16 +673,12 @@ this.fetch = async (url, options) => {
|
|||
"oai-language": "en-US",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_cookies(cookies: Cookies):
|
||||
return "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")
|
||||
|
||||
@classmethod
|
||||
def _create_request_args(cls, cookies: Cookies = None, user_agent: str = None):
|
||||
cls._headers = cls.get_default_headers()
|
||||
def _create_request_args(cls, cookies: Cookies = None, headers: dict = None, user_agent: str = None):
|
||||
cls._headers = cls.get_default_headers() if headers is None else headers
|
||||
if user_agent is not None:
|
||||
cls._headers["user-agent"] = user_agent
|
||||
cls._cookies = {} if cookies is None else cookies
|
||||
cls._cookies = {} if cookies is None else {k: v for k, v in cookies.items() if k != "access_token"}
|
||||
cls._update_cookie_header()
|
||||
|
||||
@classmethod
|
||||
|
@ -693,7 +695,7 @@ this.fetch = async (url, options) => {
|
|||
|
||||
@classmethod
|
||||
def _update_cookie_header(cls):
|
||||
cls._headers["cookie"] = cls._format_cookies(cls._cookies)
|
||||
cls._headers["cookie"] = format_cookies(cls._cookies)
|
||||
|
||||
class Conversation(BaseConversation):
|
||||
"""
|
||||
|
|
|
@ -59,17 +59,21 @@ def readHAR():
|
|||
except KeyError:
|
||||
continue
|
||||
cookies = {c['name']: c['value'] for c in v['request']['cookies']}
|
||||
headers = get_headers(v)
|
||||
if not accessToken:
|
||||
raise NoValidHarFileError("No accessToken found in .har files")
|
||||
if not chatArks:
|
||||
return None, accessToken, cookies
|
||||
return chatArks.pop(), accessToken, cookies
|
||||
return None, accessToken, cookies, headers
|
||||
return chatArks.pop(), accessToken, cookies, headers
|
||||
|
||||
def get_headers(entry) -> dict:
|
||||
return {h['name'].lower(): h['value'] for h in entry['request']['headers'] if h['name'].lower() not in ['content-length', 'cookie'] and not h['name'].startswith(':')}
|
||||
|
||||
def parseHAREntry(entry) -> arkReq:
|
||||
tmpArk = arkReq(
|
||||
arkURL=entry['request']['url'],
|
||||
arkBx="",
|
||||
arkHeader={h['name'].lower(): h['value'] for h in entry['request']['headers'] if h['name'].lower() not in ['content-length', 'cookie'] and not h['name'].startswith(':')},
|
||||
arkHeader=get_headers(entry),
|
||||
arkBody={p['name']: unquote(p['value']) for p in entry['request']['postData']['params'] if p['name'] not in ['rnd']},
|
||||
arkCookies={c['name']: c['value'] for c in entry['request']['cookies']},
|
||||
userAgent=""
|
||||
|
@ -123,11 +127,11 @@ def getN() -> str:
|
|||
timestamp = str(int(time.time()))
|
||||
return base64.b64encode(timestamp.encode()).decode()
|
||||
|
||||
async def getArkoseAndAccessToken(proxy: str):
|
||||
async def getArkoseAndAccessToken(proxy: str) -> tuple[str, str, dict, dict]:
|
||||
global chatArk, accessToken, cookies
|
||||
if chatArk is None or accessToken is None:
|
||||
chatArk, accessToken, cookies = readHAR()
|
||||
chatArk, accessToken, cookies, headers = readHAR()
|
||||
if chatArk is None:
|
||||
return None, accessToken, cookies
|
||||
return None, accessToken, cookies, headers
|
||||
newReq = genArkReq(chatArk)
|
||||
return await sendRequest(newReq, proxy), accessToken, cookies
|
||||
return await sendRequest(newReq, proxy), accessToken, cookies, headers
|
||||
|
|
|
@ -4,7 +4,6 @@ import json
|
|||
import os
|
||||
import os.path
|
||||
import random
|
||||
import requests
|
||||
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...errors import MissingRequirementsError
|
||||
|
@ -21,7 +20,8 @@ class arkReq:
|
|||
self.arkCookies = arkCookies
|
||||
self.userAgent = userAgent
|
||||
|
||||
arkPreURL = "https://telemetry.stytch.com/submit"
|
||||
telemetry_url = "https://telemetry.stytch.com/submit"
|
||||
public_token = "public-token-live-507a52ad-7e69-496b-aee0-1c9863c7c819"
|
||||
chatArks: list = None
|
||||
|
||||
def readHAR():
|
||||
|
@ -44,7 +44,7 @@ def readHAR():
|
|||
# Error: not a HAR file!
|
||||
continue
|
||||
for v in harFile['log']['entries']:
|
||||
if arkPreURL in v['request']['url']:
|
||||
if v['request']['url'] == telemetry_url:
|
||||
chatArks.append(parseHAREntry(v))
|
||||
if not chatArks:
|
||||
raise NoValidHarFileError("No telemetry in .har files found")
|
||||
|
@ -62,24 +62,29 @@ def parseHAREntry(entry) -> arkReq:
|
|||
return tmpArk
|
||||
|
||||
async def sendRequest(tmpArk: arkReq, proxy: str = None):
|
||||
async with StreamSession(headers=tmpArk.arkHeaders, cookies=tmpArk.arkCookies, proxies={"all": proxy}) as session:
|
||||
async with StreamSession(headers=tmpArk.arkHeaders, cookies=tmpArk.arkCookies, proxy=proxy) as session:
|
||||
async with session.post(tmpArk.arkURL, data=tmpArk.arkBody) as response:
|
||||
await raise_for_status(response)
|
||||
return await response.text()
|
||||
|
||||
async def get_dfp_telemetry_id(proxy: str = None):
|
||||
async def create_telemetry_id(proxy: str = None):
|
||||
global chatArks
|
||||
if chatArks is None:
|
||||
chatArks = readHAR()
|
||||
return await sendRequest(random.choice(chatArks), proxy)
|
||||
|
||||
async def get_telemetry_ids(proxy: str = None) -> list:
|
||||
try:
|
||||
return [await create_telemetry_id(proxy)]
|
||||
except NoValidHarFileError as e:
|
||||
if debug.logging:
|
||||
print(e)
|
||||
if debug.logging:
|
||||
print('Getting telemetry_id for you.com with nodriver')
|
||||
try:
|
||||
from nodriver import start
|
||||
except ImportError:
|
||||
raise MissingRequirementsError('Install "nodriver" package | pip install -U nodriver')
|
||||
raise MissingRequirementsError('Add .har file from you.com or install "nodriver" package | pip install -U nodriver')
|
||||
try:
|
||||
browser = await start()
|
||||
tab = browser.main_tab
|
||||
|
@ -89,49 +94,11 @@ async def get_telemetry_ids(proxy: str = None) -> list:
|
|||
await tab.sleep(1)
|
||||
|
||||
async def get_telemetry_id():
|
||||
public_token = "public-token-live-507a52ad-7e69-496b-aee0-1c9863c7c819"
|
||||
telemetry_url = "https://telemetry.stytch.com/submit"
|
||||
return await tab.evaluate(f'this.GetTelemetryID("{public_token}", "{telemetry_url}");', await_promise=True)
|
||||
return await tab.evaluate(
|
||||
f'this.GetTelemetryID("{public_token}", "{telemetry_url}");',
|
||||
await_promise=True
|
||||
)
|
||||
|
||||
# for _ in range(500):
|
||||
# with open("hardir/you.com_telemetry_ids.txt", "a") as f:
|
||||
# f.write((await get_telemetry_id()) + "\n")
|
||||
|
||||
return [await get_telemetry_id() for _ in range(4)]
|
||||
return [await get_telemetry_id() for _ in range(1)]
|
||||
finally:
|
||||
try:
|
||||
await tab.close()
|
||||
except Exception as e:
|
||||
print(f"Error occurred while closing tab: {str(e)}")
|
||||
try:
|
||||
await browser.stop()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
headers = {
|
||||
'Accept': '*/*',
|
||||
'Accept-Language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3',
|
||||
'Connection': 'keep-alive',
|
||||
'Content-type': 'application/x-www-form-urlencoded',
|
||||
'Origin': 'https://you.com',
|
||||
'Referer': 'https://you.com/',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'cross-site',
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36',
|
||||
'sec-ch-ua': '"Google Chrome";v="123", "Not:A-Brand";v="8", "Chromium";v="123"',
|
||||
'sec-ch-ua-mobile': '?0',
|
||||
'sec-ch-ua-platform': '"macOS"',
|
||||
}
|
||||
|
||||
proxies = {
|
||||
'http': proxy,
|
||||
'https': proxy} if proxy else None
|
||||
|
||||
response = requests.post('https://telemetry.stytch.com/submit',
|
||||
headers=headers, data=payload, proxies=proxies)
|
||||
|
||||
if '-' in response.text:
|
||||
print(f'telemetry generated: {response.text}')
|
||||
|
||||
return (response.text)
|
||||
await browser.stop()
|
|
@ -16,7 +16,8 @@ conversations: dict[dict[str, BaseConversation]] = {}
|
|||
|
||||
class Api():
|
||||
|
||||
def get_models(self) -> list[str]:
|
||||
@staticmethod
|
||||
def get_models() -> list[str]:
|
||||
"""
|
||||
Return a list of all models.
|
||||
|
||||
|
@ -27,7 +28,8 @@ class Api():
|
|||
"""
|
||||
return models._all_models
|
||||
|
||||
def get_provider_models(self, provider: str) -> list[dict]:
|
||||
@staticmethod
|
||||
def get_provider_models(provider: str) -> list[dict]:
|
||||
if provider in __map__:
|
||||
provider: ProviderType = __map__[provider]
|
||||
if issubclass(provider, ProviderModelMixin):
|
||||
|
@ -40,7 +42,24 @@ class Api():
|
|||
else:
|
||||
return [];
|
||||
|
||||
def get_providers(self) -> list[str]:
|
||||
@staticmethod
|
||||
def get_image_models() -> list[dict]:
|
||||
image_models = []
|
||||
for key, provider in __map__.items():
|
||||
if hasattr(provider, "image_models"):
|
||||
if hasattr(provider, "get_models"):
|
||||
provider.get_models()
|
||||
for model in provider.image_models:
|
||||
image_models.append({
|
||||
"provider": key,
|
||||
"url": provider.url,
|
||||
"label": provider.label if hasattr(provider, "label") else None,
|
||||
"image_model": model
|
||||
})
|
||||
return image_models
|
||||
|
||||
@staticmethod
|
||||
def get_providers() -> list[str]:
|
||||
"""
|
||||
Return a list of all working providers.
|
||||
"""
|
||||
|
@ -58,7 +77,8 @@ class Api():
|
|||
if provider.working
|
||||
}
|
||||
|
||||
def get_version(self):
|
||||
@staticmethod
|
||||
def get_version():
|
||||
"""
|
||||
Returns the current and latest version of the application.
|
||||
|
||||
|
|
|
@ -31,6 +31,10 @@ class Backend_Api(Api):
|
|||
'function': self.get_provider_models,
|
||||
'methods': ['GET']
|
||||
},
|
||||
'/backend-api/v2/image_models': {
|
||||
'function': self.get_image_models,
|
||||
'methods': ['GET']
|
||||
},
|
||||
'/backend-api/v2/providers': {
|
||||
'function': self.get_providers,
|
||||
'methods': ['GET']
|
||||
|
|
|
@ -271,13 +271,13 @@ class AsyncGeneratorProvider(AsyncProvider):
|
|||
raise NotImplementedError()
|
||||
|
||||
class ProviderModelMixin:
|
||||
default_model: str
|
||||
default_model: str = None
|
||||
models: list[str] = []
|
||||
model_aliases: dict[str, str] = {}
|
||||
|
||||
@classmethod
|
||||
def get_models(cls) -> list[str]:
|
||||
if not cls.models:
|
||||
if not cls.models and cls.default_model is not None:
|
||||
return [cls.default_model]
|
||||
return cls.models
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import random
|
||||
import string
|
||||
|
||||
from ..typing import Messages
|
||||
from ..typing import Messages, Cookies
|
||||
|
||||
def format_prompt(messages: Messages, add_special_tokens=False) -> str:
|
||||
"""
|
||||
|
@ -56,4 +56,7 @@ def filter_none(**kwargs) -> dict:
|
|||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if value is not None
|
||||
}
|
||||
}
|
||||
|
||||
def format_cookies(cookies: Cookies) -> str:
|
||||
return "; ".join([f"{k}={v}" for k, v in cookies.items()])
|
Loading…
Reference in New Issue