[Bugfix] Convert image to RGB by default (#6430)

This commit is contained in:
Cyrus Leung 2024-07-15 13:39:15 +08:00 committed by GitHub
parent 69672f116c
commit de19916314
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -35,8 +35,12 @@ def _load_image_from_data_url(image_url: str):
return load_image_from_base64(image_base64) return load_image_from_base64(image_base64)
def fetch_image(image_url: str) -> Image.Image: def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
"""Load PIL image from a url or base64 encoded openai GPT4V format""" """
Load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'): if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url") _validate_remote_url(image_url, name="image_url")
@ -53,7 +57,7 @@ def fetch_image(image_url: str) -> Image.Image:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start " raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.") "with either 'data:image' or 'http'.")
return image return image.convert(image_mode)
class ImageFetchAiohttp: class ImageFetchAiohttp:
@ -70,8 +74,17 @@ class ImageFetchAiohttp:
return cls.aiohttp_client return cls.aiohttp_client
@classmethod @classmethod
async def fetch_image(cls, image_url: str) -> Image.Image: async def fetch_image(
"""Load PIL image from a url or base64 encoded openai GPT4V format""" cls,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'): if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url") _validate_remote_url(image_url, name="image_url")
@ -91,7 +104,7 @@ class ImageFetchAiohttp:
"Invalid 'image_url': A valid 'image_url' must start " "Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.") "with either 'data:image' or 'http'.")
return image return image.convert(image_mode)
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
@ -99,12 +112,19 @@ async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
return {"image": image} return {"image": image}
def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str: def encode_image_base64(
"""Encode a pillow image to base64 format.""" image: Image.Image,
*,
image_mode: str = "RGB",
format: str = "JPEG",
) -> str:
"""
Encode a pillow image to base64 format.
By default, the image is converted into RGB format before being encoded.
"""
buffered = BytesIO() buffered = BytesIO()
if format == 'JPEG': image = image.convert(image_mode)
image = image.convert('RGB')
image.save(buffered, format) image.save(buffered, format)
return base64.b64encode(buffered.getvalue()).decode('utf-8') return base64.b64encode(buffered.getvalue()).decode('utf-8')