跳到内容

图像偏好

开始使用

部署 Argilla 服务器

如果您已经部署了 Argilla,则可以跳过此步骤。否则,您可以按照本指南快速部署 Argilla。

设置环境

要完成本教程,您需要通过 pip 安装 Argilla SDK 和一些第三方库。

!pip install argilla
!pip install "sentence-transformers~=3.0"

让我们进行所需的导入

import io
import os
import time

import argilla as rg
import requests
from PIL import Image
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer

您还需要使用 api_urlapi_key 连接到 Argilla 服务器。

# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
    api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key="[your-api-key]",
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)

数据集概览

我们将查看数据集,以了解其结构和包含的数据类型。我们可以使用嵌入式 Hugging Face Dataset Viewer 来完成此操作。

配置和创建 Argilla 数据集

现在,我们需要配置数据集。在设置中,我们可以指定指南、字段和问题。我们将包含一个 TextField,一个对应于 url 图像列的 ImageField,以及另外两个 ImageField 字段,表示我们将基于数据集中的 original_caption 列生成的图像。此外,我们将使用一个 LabelQuestion 和一个可选的 TextQuestion,它们将用于收集用户的偏好及其背后的原因。我们还将添加一个 VectorField 来存储 original_caption 的嵌入向量,以便我们可以使用语义搜索并加快我们的标注过程。最后,我们将包含两个 FloatMetadataProperty 来存储来自 toxicityidentity_attack 列的信息。

注意

查看本操作指南以了解更多关于配置和创建数据集的信息。

settings = rg.Settings(
    guidelines="The goal is to choose the image that best represents the caption.",
    fields=[
        rg.TextField(
            name="caption",
            title="An image caption belonging to the original image.",
        ),
        rg.ImageField(
            name="image_original",
            title="The original image, belonging to the caption.",
        ),
        rg.ImageField(
            name="image_1",
            title="An image that has been generated based on the caption.",
        ),
        rg.ImageField(
            name="image_2",
            title="An image that has been generated based on the caption.",
        ),
    ],
    questions=[
        rg.LabelQuestion(
            name="preference",
            title="The chosen preference for the generation.",
            labels=["image_1", "image_2"],
        ),
        rg.TextQuestion(
            name="comment",
            title="Any additional comments.",
            required=False,
        ),
    ],
    metadata=[
        rg.FloatMetadataProperty(name="toxicity", title="Toxicity score"),
        rg.FloatMetadataProperty(name="identity_attack", title="Identity attack score"),

    ],
    vectors=[
        rg.VectorField(name="original_caption_vector", dimensions=384),
    ]
)

让我们使用名称和定义的设置创建数据集

dataset = rg.Dataset(
    name="image_preference_dataset",
    settings=settings,
)
dataset.create()

添加记录

即使我们已经创建了数据集,它仍然缺少要标注的信息(您可以在 UI 中查看)。我们将使用来自 Hugging Face Hubtomg-group-umd/pixelprose 数据集。具体来说,我们将使用 25 个示例。因为我们正在处理可能很大的图像数据集,所以我们将设置 streaming=True 以避免将整个数据集加载到内存中,并迭代数据以延迟加载它。

提示

当使用 Hugging Face 数据集时,您可以设置 Image(decode=False),以便我们可以获取 公共图像 URL,但这取决于数据集。

n_rows = 25

hf_dataset = load_dataset("tomg-group-umd/pixelprose", streaming=True)
dataset_rows = [row for _,row in zip(range(n_rows), hf_dataset["train"])]
hf_dataset = Dataset.from_list(dataset_rows)

hf_dataset
Dataset({
    features: ['uid', 'url', 'key', 'status', 'original_caption', 'vlm_model', 'vlm_caption', 'toxicity', 'severe_toxicity', 'obscene', 'identity_attack', 'insult', 'threat', 'sexual_explicit', 'watermark_class_id', 'watermark_class_score', 'aesthetic_score', 'error_message', 'width', 'height', 'original_width', 'original_height', 'exif', 'sha256', 'image_id', 'author', 'subreddit', 'score'],
    num_rows: 25
})

让我们看一下数据集中的第一个条目。

hf_dataset[0]
{'uid': '0065a9b1cb4da4696f2cd6640e00304257cafd97c0064d4c61e44760bf0fa31c',
 'url': 'https://media.gettyimages.com/photos/plate-of-food-from-murray-bros-caddy-shack-at-the-world-golf-hall-of-picture-id916117812?s=612x612',
 'key': '007740026',
 'status': 'success',
 'original_caption': 'A plate of food from Murray Bros Caddy Shack at the World Golf Hall of Fame',
 'vlm_model': 'gemini-pro-vision',
 'vlm_caption': ' This image displays: A plate of fried calamari with a lemon wedge and a side of green beans, served in a basket with a pink bowl of marinara sauce. The basket is sitting on a table with a checkered tablecloth. In the background is a glass of water and a plate with a burger and fries. The style of the image is a photograph.',
 'toxicity': 0.0005555678508244455,
 'severe_toxicity': 1.7323875454167137e-06,
 'obscene': 3.8304504414554685e-05,
 'identity_attack': 0.00010549413127591833,
 'insult': 0.00014773994917050004,
 'threat': 2.5982120860135183e-05,
 'sexual_explicit': 2.0972733182134107e-05,
 'watermark_class_id': 1.0,
 'watermark_class_score': 0.733799934387207,
 'aesthetic_score': 5.390625,
 'error_message': None,
 'width': 612,
 'height': 408,
 'original_width': 612,
 'original_height': 408,
 'exif': '{"Image ImageDescription": "A plate of food from Murray Bros. Caddy Shack at the World Golf Hall of Fame. (Photo by: Jeffrey Greenberg/Universal Images Group via Getty Images)", "Image XResolution": "300", "Image YResolution": "300"}',
 'sha256': '0065a9b1cb4da4696f2cd6640e00304257cafd97c0064d4c61e44760bf0fa31c',
 'image_id': 'null',
 'author': 'null',
 'subreddit': -1,
 'score': -1}

正如我们所看到的,url 列不包含图像扩展名,因此我们将应用一些额外的过滤以确保我们只有公共图像 URL。

hf_dataset = hf_dataset.filter(
    lambda x: any([x["url"].endswith(extension) for extension in [".jpg", ".png", ".jpeg"]]))

hf_dataset
Dataset({
    features: ['uid', 'url', 'key', 'status', 'original_caption', 'vlm_model', 'vlm_caption', 'toxicity', 'severe_toxicity', 'obscene', 'identity_attack', 'insult', 'threat', 'sexual_explicit', 'watermark_class_id', 'watermark_class_score', 'aesthetic_score', 'error_message', 'width', 'height', 'original_width', 'original_height', 'exif', 'sha256', 'image_id', 'author', 'subreddit', 'score'],
    num_rows: 18
})

生成图像

我们将首先使用最近发布的 black-forest-labs/FLUX.1-schnell 模型,基于 original_caption 列生成图像。为此,我们将使用 Hugging Face 提供的免费但速率受限的 Inference API,但您可以使用 Hub 中的任何其他模型 或方法。我们将为每个示例生成 2 张图像。此外,我们将添加一个小的重试机制来处理速率限制。

让我们首先定义和测试一个生成函数。

API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    if response.status_code == 200:
        image_bytes = response.content
        image = Image.open(io.BytesIO(image_bytes))
    else:
        print(f"Request failed with status code {response.status_code}. retrying in 10 seconds.")
        time.sleep(10)
        image = query(payload)
    return image

query({
    "inputs": "Astronaut riding a horse"
})
No description has been provided for this image

太棒了!既然我们已经评估了生成函数,让我们为数据集生成 PIL 图像。

def generate_image(row):
    caption = row["original_caption"]
    row["image_1"] = query({"inputs": caption})
    row["image_2"] = query({"inputs": caption + " "}) # space to avoid caching and getting the same image
    return row

hf_dataset_with_images = hf_dataset.map(generate_image, batched=False)

hf_dataset_with_images
Dataset({
    features: ['uid', 'url', 'key', 'status', 'original_caption', 'vlm_model', 'vlm_caption', 'toxicity', 'severe_toxicity', 'obscene', 'identity_attack', 'insult', 'threat', 'sexual_explicit', 'watermark_class_id', 'watermark_class_score', 'aesthetic_score', 'error_message', 'width', 'height', 'original_width', 'original_height', 'exif', 'sha256', 'image_id', 'author', 'subreddit', 'score', 'image_1', 'image_2'],
    num_rows: 18
})

添加向量

我们将使用 sentence-transformers 库为 original_caption 创建向量。我们将使用 TaylorAI/bge-micro-v2 模型,该模型在速度和性能之间取得了良好的平衡。请注意,我们还需要将向量转换为 list 以将其存储在 Argilla 数据集中。

model = SentenceTransformer("TaylorAI/bge-micro-v2")

def encode_questions(batch):
    vectors_as_numpy = model.encode(batch["original_caption"])
    batch["original_caption_vector"] = [x.tolist() for x in vectors_as_numpy]
    return batch

hf_dataset_with_images_vectors = hf_dataset_with_images.map(encode_questions, batched=True)

记录到 Argilla

我们将使用 log 和映射轻松地将它们添加到数据集中,在映射中,我们指示如果名称不对应,则需要将我们数据集中的哪一列映射到哪个 Argilla 资源。我们还将 key 列用作我们记录的 id,以便我们可以轻松地将记录追溯到外部数据源。

dataset.records.log(records=hf_dataset_with_images_vectors, mapping={
    "key": "id",
    "original_caption": "caption",
    "url": "image_original",
})

瞧!我们的 Argilla 数据集已准备好进行标注。

使用 Argilla 评估

现在,我们可以开始标注过程了。只需在 Argilla UI 中打开数据集并开始标注记录。

注意

查看本操作指南以了解更多关于在 UI 中标注的信息。

结论

在本教程中,我们展示了一个图像偏好任务的端到端示例。这可以作为基础,但它可以迭代执行并无缝集成到您的工作流程中,以确保高质量的数据管理和改进的结果。

我们首先配置数据集并添加包含原始图像和生成图像的记录。在标注过程之后,您可以评估结果并可能重新训练模型以提高生成图像的质量。