Improve code style

This commit is contained in:
Heiner Lohaus 2024-01-04 00:38:31 +01:00
parent 25895eb637
commit 139f68af4f
5 changed files with 80 additions and 99 deletions

View File

@ -35,9 +35,9 @@ async def delete_conversation(session: ClientSession, conversation: Conversation
"source": "cib",
"optionsSets": ["autosave"]
}
async with session.post(url, json=json, proxy=proxy) as response:
try:
try:
async with session.post(url, json=json, proxy=proxy) as response:
response = await response.json()
return response["result"]["value"] == "Success"
except:
return False
except:
return False

View File

@ -1,5 +1,3 @@
import asyncio
import time, json, os
from aiohttp import ClientSession
@ -9,10 +7,12 @@ from typing import Generator
from ...webdriver import WebDriver, get_driver_cookies, get_browser
from ...Provider.helper import get_event_loop
from ...base_provider import ProviderType
from ...Provider.create_images import CreateImagesProvider
BING_URL = "https://www.bing.com"
def wait_for_login(driver: WebDriver, timeout: int = 1200):
def wait_for_login(driver: WebDriver, timeout: int = 1200) -> Generator:
driver.get(f"{BING_URL}/")
value = driver.get_cookie("_U")
if value:
@ -29,7 +29,7 @@ def wait_for_login(driver: WebDriver, timeout: int = 1200):
return
time.sleep(0.1)
def create_session(cookies: dict):
def create_session(cookies: dict) -> ClientSession:
headers = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"accept-encoding": "gzip, deflate, br",
@ -51,7 +51,7 @@ def create_session(cookies: dict):
headers["cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
return ClientSession(headers=headers)
async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = 200):
async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = 200) -> list:
url_encoded_prompt = quote(prompt)
payload = f"q={url_encoded_prompt}&rt=4&FORM=GENCRE"
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
@ -111,7 +111,10 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
def format_images_markdown(images: list, prompt: str) -> str:
images = [f"[![#{idx+1} {prompt}]({image}?w=200&h=200)]({image})" for idx, image in enumerate(images)]
return f"\n\n<img data-prompt=\"{prompt}\">\n<!-- generated images start -->\n" + ("\n".join(images)) + "\n<!-- generated images end -->\n\n"
images = "\n".join(images)
start_flag = "<!-- generated images start -->\n"
end_flag = "<!-- generated images end -->\n"
return f"\n\n<img data-prompt=\"{prompt}\">\n{start_flag}{images}\n{end_flag}\n"
def get_images(text: str) -> list:
html_soup = BeautifulSoup(text, "html.parser")
@ -143,4 +146,7 @@ def create_completion(prompt: str, proxy: str = None) -> Generator:
images = loop.run_until_complete(run_session())
yield format_images_markdown(images, prompt)
finally:
driver.quit()
driver.quit()
def patch_provider(provider: ProviderType) -> CreateImagesProvider:
return CreateImagesProvider(provider, create_completion)

View File

@ -66,7 +66,7 @@ async def upload_image(
)
return result
except Exception as e:
raise RuntimeError(f"Add image failed: {e}")
raise RuntimeError(f"Upload image failed: {e}")
def build_image_upload_api_payload(image_bin: str, tone: str):
@ -101,82 +101,62 @@ def build_image_upload_api_payload(image_bin: str, tone: str):
return data, boundary
def is_data_uri_an_image(data_uri: str):
try:
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
if not re.match(r'data:image/(\w+);base64,', data_uri):
raise ValueError("Invalid data URI image.")
# Extract the image format from the data URI
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
if image_format.lower() not in ['jpeg', 'jpg', 'png', 'gif']:
raise ValueError("Invalid image format (from mime file type).")
except Exception as e:
raise e
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
if not re.match(r'data:image/(\w+);base64,', data_uri):
raise ValueError("Invalid data URI image.")
# Extract the image format from the data URI
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
if image_format.lower() not in ['jpeg', 'jpg', 'png', 'gif']:
raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> bool:
try:
check = False
if binary_data.startswith(b'\xFF\xD8\xFF'):
check = True # It's a JPEG image
elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
check = True # It's a PNG image
elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
check = True # It's a GIF image
elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
check = True # It's a JPEG image
elif binary_data.startswith(b'\xFF\xD8'):
check = True # It's a JPEG image
elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
check = True # It's a WebP image
# else we raise ValueError
if not check:
raise ValueError("Invalid image format (from magic code).")
except Exception as e:
raise e
if binary_data.startswith(b'\xFF\xD8\xFF'):
pass # It's a JPEG image
elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
pass # It's a PNG image
elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
pass # It's a GIF image
elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
pass # It's a JPEG image
elif binary_data.startswith(b'\xFF\xD8'):
pass # It's a JPEG image
elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
pass # It's a WebP image
else:
raise ValueError("Invalid image format (from magic code).")
def extract_data_uri(data_uri: str) -> bytes:
try:
data = data_uri.split(",")[1]
data = base64.b64decode(data)
return data
except Exception as e:
raise e
data = data_uri.split(",")[1]
data = base64.b64decode(data)
return data
def get_orientation(data: bytes) -> int:
try:
if data[:2] != b'\xFF\xD8':
raise Exception('NotJpeg')
with Image.open(data) as img:
exif_data = img._getexif()
if exif_data is not None:
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
if orientation is not None:
return orientation
except Exception:
pass
if data[:2] != b'\xFF\xD8':
raise Exception('NotJpeg')
with Image.open(data) as img:
exif_data = img._getexif()
if exif_data is not None:
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
if orientation is not None:
return orientation
def process_image(orientation: int, img: Image.Image, new_width: int, new_height: int) -> Image.Image:
try:
# Initialize the canvas
new_img = Image.new("RGB", (new_width, new_height), color="#FFFFFF")
if orientation:
if orientation > 4:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if orientation in [3, 4]:
img = img.transpose(Image.ROTATE_180)
if orientation in [5, 6]:
img = img.transpose(Image.ROTATE_270)
if orientation in [7, 8]:
img = img.transpose(Image.ROTATE_90)
new_img.paste(img, (0, 0))
return new_img
except Exception as e:
raise e
# Initialize the canvas
new_img = Image.new("RGB", (new_width, new_height), color="#FFFFFF")
if orientation:
if orientation > 4:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if orientation in [3, 4]:
img = img.transpose(Image.ROTATE_180)
if orientation in [5, 6]:
img = img.transpose(Image.ROTATE_270)
if orientation in [7, 8]:
img = img.transpose(Image.ROTATE_90)
new_img.paste(img, (0, 0))
return new_img
def compress_image_to_base64(img, compression_rate) -> str:
try:
output_buffer = io.BytesIO()
img.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode('utf-8')
except Exception as e:
raise e
def compress_image_to_base64(image: Image.Image, compression_rate: float) -> str:
output_buffer = io.BytesIO()
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode('utf-8')

View File

@ -1,18 +1,13 @@
import logging
import g4f
from g4f.Provider import __providers__
import json
from flask import request, Flask
from .internet import get_search_message
from g4f import debug, version
from g4f.base_provider import ProviderType
debug.logging = True
from flask import request, Flask
from g4f import debug, version, models
from g4f import _all_models, get_last_provider, ChatCompletion
from g4f.Provider import __providers__
from g4f.Provider.bing.create_images import patch_provider
from .internet import get_search_message
def patch_provider(provider: ProviderType):
from g4f.Provider import CreateImagesProvider
from g4f.Provider.bing.create_images import create_completion
return CreateImagesProvider(provider, create_completion)
debug.logging = True
class Backend_Api:
def __init__(self, app: Flask) -> None:
@ -50,7 +45,7 @@ class Backend_Api:
return 'ok', 200
def models(self):
return g4f._all_models
return _all_models
def providers(self):
return [
@ -74,7 +69,7 @@ class Backend_Api:
if request.json.get('internet_access'):
messages[-1]["content"] = get_search_message(messages[-1]["content"])
model = request.json.get('model')
model = model if model else g4f.models.default
model = model if model else models.default
provider = request.json.get('provider', '').replace('g4f.Provider.', '')
provider = provider if provider and provider != "Auto" else None
patch = patch_provider if request.json.get('patch_provider') else None
@ -82,7 +77,7 @@ class Backend_Api:
def try_response():
try:
first = True
for chunk in g4f.ChatCompletion.create(
for chunk in ChatCompletion.create(
model=model,
provider=provider,
messages=messages,
@ -94,7 +89,7 @@ class Backend_Api:
first = False
yield json.dumps({
'type' : 'provider',
'provider': g4f.get_last_provider(True)
'provider': get_last_provider(True)
}) + "\n"
yield json.dumps({
'type' : 'content',

View File

@ -145,5 +145,5 @@ User request:
"""
return message
except Exception as e:
print("Couldn't search DuckDuckGo:", e)
print("Couldn't do web search:", e)
return prompt