图像偏好¶
- 目标:展示一个使用复杂多模态偏好数据集的标准工作流程,例如用于图像生成偏好。
- 数据集:tomg-group-umd/pixelprose,是一个包含超过 1600 万(million)条合成生成标题的综合数据集,利用先进的视觉语言模型(Gemini 1.0 Pro Vision)进行详细和准确的描述。
- 库:datasets, sentence-transformers
- 组件:TextField, ImageField, TextQuestion, LabelQuestion VectorField, FloatMetadataProperty
开始使用¶
部署 Argilla 服务器¶
如果您已经部署了 Argilla,则可以跳过此步骤。否则,您可以按照本指南快速部署 Argilla。
设置环境¶
要完成本教程,您需要通过 pip
安装 Argilla SDK 和一些第三方库。
让我们进行所需的导入
您还需要使用 api_url
和 api_key
连接到 Argilla 服务器。
# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
api_url="https://[your-owner-name]-[your_space_name].hf.space",
api_key="[your-api-key]",
# headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
数据集概览¶
我们将查看数据集,以了解其结构和包含的数据类型。我们可以使用嵌入式 Hugging Face Dataset Viewer 来完成此操作。
配置和创建 Argilla 数据集¶
现在,我们需要配置数据集。在设置中,我们可以指定指南、字段和问题。我们将包含一个 TextField
,一个对应于 url
图像列的 ImageField
,以及另外两个 ImageField
字段,表示我们将基于数据集中的 original_caption
列生成的图像。此外,我们将使用一个 LabelQuestion
和一个可选的 TextQuestion
,它们将用于收集用户的偏好及其背后的原因。我们还将添加一个 VectorField
来存储 original_caption
的嵌入向量,以便我们可以使用语义搜索并加快我们的标注过程。最后,我们将包含两个 FloatMetadataProperty
来存储来自 toxicity
和 identity_attack
列的信息。
注意
查看本操作指南以了解更多关于配置和创建数据集的信息。
settings = rg.Settings(
guidelines="The goal is to choose the image that best represents the caption.",
fields=[
rg.TextField(
name="caption",
title="An image caption belonging to the original image.",
),
rg.ImageField(
name="image_original",
title="The original image, belonging to the caption.",
),
rg.ImageField(
name="image_1",
title="An image that has been generated based on the caption.",
),
rg.ImageField(
name="image_2",
title="An image that has been generated based on the caption.",
),
],
questions=[
rg.LabelQuestion(
name="preference",
title="The chosen preference for the generation.",
labels=["image_1", "image_2"],
),
rg.TextQuestion(
name="comment",
title="Any additional comments.",
required=False,
),
],
metadata=[
rg.FloatMetadataProperty(name="toxicity", title="Toxicity score"),
rg.FloatMetadataProperty(name="identity_attack", title="Identity attack score"),
],
vectors=[
rg.VectorField(name="original_caption_vector", dimensions=384),
]
)
让我们使用名称和定义的设置创建数据集
添加记录¶
即使我们已经创建了数据集,它仍然缺少要标注的信息(您可以在 UI 中查看)。我们将使用来自 Hugging Face Hub 的 tomg-group-umd/pixelprose
数据集。具体来说,我们将使用 25
个示例。因为我们正在处理可能很大的图像数据集,所以我们将设置 streaming=True
以避免将整个数据集加载到内存中,并迭代数据以延迟加载它。
提示
当使用 Hugging Face 数据集时,您可以设置 Image(decode=False)
,以便我们可以获取 公共图像 URL,但这取决于数据集。
让我们看一下数据集中的第一个条目。
正如我们所看到的,url
列不包含图像扩展名,因此我们将应用一些额外的过滤以确保我们只有公共图像 URL。
生成图像¶
我们将首先使用最近发布的 black-forest-labs/FLUX.1-schnell 模型,基于 original_caption
列生成图像。为此,我们将使用 Hugging Face 提供的免费但速率受限的 Inference API,但您可以使用 Hub 中的任何其他模型 或方法。我们将为每个示例生成 2 张图像。此外,我们将添加一个小的重试机制来处理速率限制。
让我们首先定义和测试一个生成函数。
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
if response.status_code == 200:
image_bytes = response.content
image = Image.open(io.BytesIO(image_bytes))
else:
print(f"Request failed with status code {response.status_code}. retrying in 10 seconds.")
time.sleep(10)
image = query(payload)
return image
query({
"inputs": "Astronaut riding a horse"
})
太棒了!既然我们已经评估了生成函数,让我们为数据集生成 PIL 图像。
def generate_image(row):
caption = row["original_caption"]
row["image_1"] = query({"inputs": caption})
row["image_2"] = query({"inputs": caption + " "}) # space to avoid caching and getting the same image
return row
hf_dataset_with_images = hf_dataset.map(generate_image, batched=False)
hf_dataset_with_images
添加向量¶
我们将使用 sentence-transformers
库为 original_caption
创建向量。我们将使用 TaylorAI/bge-micro-v2
模型,该模型在速度和性能之间取得了良好的平衡。请注意,我们还需要将向量转换为 list
以将其存储在 Argilla 数据集中。
model = SentenceTransformer("TaylorAI/bge-micro-v2")
def encode_questions(batch):
vectors_as_numpy = model.encode(batch["original_caption"])
batch["original_caption_vector"] = [x.tolist() for x in vectors_as_numpy]
return batch
hf_dataset_with_images_vectors = hf_dataset_with_images.map(encode_questions, batched=True)
记录到 Argilla¶
我们将使用 log
和映射轻松地将它们添加到数据集中,在映射中,我们指示如果名称不对应,则需要将我们数据集中的哪一列映射到哪个 Argilla 资源。我们还将 key
列用作我们记录的 id
,以便我们可以轻松地将记录追溯到外部数据源。
瞧!我们的 Argilla 数据集已准备好进行标注。
使用 Argilla 评估¶
现在,我们可以开始标注过程了。只需在 Argilla UI 中打开数据集并开始标注记录。
注意
查看本操作指南以了解更多关于在 UI 中标注的信息。
结论¶
在本教程中,我们展示了一个图像偏好任务的端到端示例。这可以作为基础,但它可以迭代执行并无缝集成到您的工作流程中,以确保高质量的数据管理和改进的结果。
我们首先配置数据集并添加包含原始图像和生成图像的记录。在标注过程之后,您可以评估结果并可能重新训练模型以提高生成图像的质量。