Template and Dataset Customization
We offer an extensible dataset registration interface: to plug in a custom dataset, users only need to design a template for it and register that template in dataset_formatter.py.
Taking Align-Anything-200K (Text+Image -> Text, abbreviated as AA_TI2T) as an example, we illustrate here how to design the template and incorporate it into a complete RLHF workflow.
The original data key-value pairs for AA_TI2T are as follows:
{
    'image': PIL.Image.Image,
    'question': str,
    'response_1': str,
    'response_2': str,
    'overall_response': int,
}
We first need to create a new template named AA_TI2T for this dataset
(we use _ here because it is more pythonic), and specify the
required parameters such as system_prompt.
@register_template('AA_TI2T')
class AA_TI2T(BaseFormatter):
    system_prompt: str = ""
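To show what registering a template by name could look like under the hood, here is a minimal sketch of a decorator-based registry. It is an illustration of the general pattern, not the library's actual implementation: `TEMPLATE_REGISTRY`, `get_template`, and the stand-in `BaseFormatter` are hypothetical names introduced only for this example.

```python
from typing import Callable

# Hypothetical registry for this sketch; the real one lives inside the library.
TEMPLATE_REGISTRY: dict[str, type] = {}


def register_template(name: str) -> Callable[[type], type]:
    """Return a class decorator that stores the decorated class under `name`."""

    def decorator(cls: type) -> type:
        if name in TEMPLATE_REGISTRY:
            raise ValueError(f"template {name!r} is already registered")
        TEMPLATE_REGISTRY[name] = cls
        return cls

    return decorator


class BaseFormatter:
    """Stand-in base class, defined here only to keep the sketch self-contained."""

    system_prompt: str = ""


@register_template("AA_TI2T")
class AA_TI2T(BaseFormatter):
    system_prompt: str = ""


def get_template(name: str) -> BaseFormatter:
    """Hypothetical lookup helper: find a registered class and instantiate it."""
    return TEMPLATE_REGISTRY[name]()


template = get_template("AA_TI2T")
print(type(template).__name__)
```

With a registry like this, a training configuration only needs to carry the template name as a string, and the formatter class can be resolved at load time.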
Then, we can implement the following three types of functions to finish the dataset registration:
| Type | Description |
|---|---|
| format_supervised_sample | Maps the dataset to the supervised training format (for SFT). |
| format_preference_sample | Maps the dataset to the preference training format (for RM, DPO, KTO, etc.). |
| format_prompt_only_sample | Maps the dataset to the prompt-only training format (for PPO). |
General Format
Our dataset_formatter uses a conversation format because it is natively supported by the apply_chat_template function in transformers (more details can be found in this tutorial).
Hint
An example of conversation format is as follows:
[
    {'role': 'user', 'content': [
        {'type': 'image'},
        {'type': 'text', 'text': prompt},
    ]},
    {'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]},
]
Another important element is the multi_modal_info field, which stores the multi-modal information of each sample. An example is as follows:
{
    'image': PIL.Image.Image,
}
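To make the conversation format concrete, the following toy renderer flattens a conversation into a single prompt string. It is only a sketch of the idea behind a chat template: real templates are model-specific and are applied through apply_chat_template, and both the `render_conversation` helper and the `<image>` placeholder token are assumptions made for this example.

```python
from typing import Any


def render_conversation(
    conversation: list[dict[str, Any]], image_token: str = "<image>"
) -> str:
    """Flatten a conversation into text, one line per turn (illustrative only)."""
    lines = []
    for turn in conversation:
        parts = []
        for item in turn["content"]:
            if item["type"] == "image":
                # The image itself travels separately in multi_modal_info;
                # the text only carries a placeholder for it.
                parts.append(image_token)
            elif item["type"] == "text":
                parts.append(item["text"])
        lines.append(f"{turn['role'].upper()}: {' '.join(parts)}")
    return "\n".join(lines)


conversation = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "What is in the picture?"},
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]},
]

print(render_conversation(conversation))
# USER: <image> What is in the picture?
# ASSISTANT: A cat.
```

The point of the split is that the conversation only describes where each modality appears, while the actual image object is passed alongside it in multi_modal_info.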
Format Supervised Sample
The format_supervised_sample function converts each sample to the question-answer (Q-A) format. An example is as follows:
from typing import Any

@register_template('AA_TI2T')
class AA_TI2T(BaseFormatter):
    system_prompt: str = ""

    def format_supervised_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        # This example reads 'prompt' and 'response' keys. The preference data
        # listed earlier uses 'question' and 'response_1'/'response_2' instead,
        # so adjust the keys to match your supervised data.
        prompt = raw_sample['prompt']
        answer = raw_sample['response']
        image = raw_sample['image'].convert('RGBA')

        return [
            {'role': 'user', 'content': [
                {'type': 'image'},
                {'type': 'text', 'text': prompt},
            ]},
            {'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]},
        ], {'image': image}
Format Preference Sample
The format_preference_sample function converts each sample to the preference training format. An example is as follows:
@register_template('AA_TI2T')
class AA_TI2T(BaseFormatter):
    system_prompt: str = ""

    def format_preference_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
        better_id = int(raw_sample['overall_response'])
        # Return empty outputs for samples whose preference label is not 1 or 2.
        if better_id not in [1, 2]:
            return [], [], {}
        worse_id = 2 if better_id == 1 else 1

        raw_better_response = raw_sample[f'response_{better_id}']
        raw_worse_response = raw_sample[f'response_{worse_id}']
        prompt = raw_sample['question']
        image = raw_sample['image'].convert('RGBA')

        # Both conversations share the same user turn and differ only in the answer.
        better_conversation = [
            {'role': 'user', 'content': [
                {'type': 'image'},
                {'type': 'text', 'text': prompt},
            ]},
            {'role': 'assistant', 'content': [{'type': 'text', 'text': raw_better_response}]},
        ]
        worse_conversation = [
            {'role': 'user', 'content': [
                {'type': 'image'},
                {'type': 'text', 'text': prompt},
            ]},
            {'role': 'assistant', 'content': [{'type': 'text', 'text': raw_worse_response}]},
        ]

        meta_info = {
            'image': image,
            'better_response': raw_better_response,
            'worse_response': raw_worse_response,
        }
        return better_conversation, worse_conversation, meta_info
Note
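The format_preference_sample function decides which response is better from the dataset's preference label (here overall_response; other datasets may use chosen/rejected or similar fields). It returns the better and worse conversations, plus a dictionary that stores the raw responses under the keys better_response and worse_response.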
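The label handling described in this note can be isolated into a small helper, sketched below. `pick_better_and_worse` is a hypothetical function written for this example (it is not part of the library); it assumes the AA_TI2T keys shown earlier and returns None for samples whose label is missing or not 1 or 2.

```python
from typing import Any, Optional


def pick_better_and_worse(raw_sample: dict[str, Any]) -> Optional[tuple[str, str]]:
    """Return (better, worse) responses, or None when the label is unusable."""
    try:
        better_id = int(raw_sample["overall_response"])
    except (KeyError, TypeError, ValueError):
        # Missing or non-numeric label: treat the sample as invalid.
        return None
    if better_id not in (1, 2):
        return None
    worse_id = 2 if better_id == 1 else 1
    return raw_sample[f"response_{better_id}"], raw_sample[f"response_{worse_id}"]


sample = {
    "question": "What color is the car?",
    "response_1": "It is red.",
    "response_2": "It is a vehicle.",
    "overall_response": 1,
}
print(pick_better_and_worse(sample))
# ('It is red.', 'It is a vehicle.')
```

Validating the label before deriving the worse index keeps malformed samples from silently producing a pair.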
Format Prompt Only Sample
During the RL fine-tuning phase, the model generates responses from the prompts in the dataset, so the format_prompt_only_sample function converts each sample to the prompt-only training format. An example is as follows:
@register_template('AA_TI2T')
class AA_TI2T(BaseFormatter):
    system_prompt: str = ""

    def format_prompt_only_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        prompt = raw_sample['question']
        image = raw_sample['image'].convert('RGBA')

        return [
            {'role': 'user', 'content': [
                {'type': 'image'},
                {'type': 'text', 'text': prompt},
            ]},
        ], {'image': image}
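A prompt-only sample is simply the conversation without the assistant's answer. As a rough illustration of that relationship, the sketch below derives the prompt-only form from a full conversation by dropping trailing assistant turns. `to_prompt_only` is a hypothetical helper for this example, not a library function.

```python
from typing import Any


def to_prompt_only(conversation: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return a copy of the conversation without its trailing assistant turns."""
    prompt_turns = list(conversation)  # shallow copy; the input list is not modified
    while prompt_turns and prompt_turns[-1]["role"] == "assistant":
        prompt_turns.pop()
    return prompt_turns


full_conversation = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe the image."},
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "A dog on a beach."}]},
]

prompt_only = to_prompt_only(full_conversation)
print([turn["role"] for turn in prompt_only])
# ['user']
```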
Conclusion
For each modality, we have implemented at least one dataset formatter as an example in dataset_formatter.py. You can refer to these examples when implementing your own dataset formatter.
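As a starting point for your own formatter, here is a minimal text-only sketch that implements all three functions. Everything in it is an assumption made for illustration: the class name `TextPreferenceTemplate`, the dataset keys `instruction`, `chosen`, and `rejected`, and the `_turn` helper are hypothetical. In the real project the class would subclass BaseFormatter and be registered with @register_template; both are omitted here so the example runs on its own.

```python
from typing import Any


class TextPreferenceTemplate:
    """Hypothetical text-only template for samples with
    'instruction', 'chosen', and 'rejected' keys (assumed for this sketch)."""

    system_prompt: str = ""

    @staticmethod
    def _turn(role: str, text: str) -> dict[str, Any]:
        # Build one turn in the typed-content conversation layout shown above.
        return {"role": role, "content": [{"type": "text", "text": text}]}

    def format_supervised_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        conversation = [
            self._turn("user", raw_sample["instruction"]),
            self._turn("assistant", raw_sample["chosen"]),
        ]
        return conversation, {}  # no multi-modal information for text-only data

    def format_preference_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
        better = [
            self._turn("user", raw_sample["instruction"]),
            self._turn("assistant", raw_sample["chosen"]),
        ]
        worse = [
            self._turn("user", raw_sample["instruction"]),
            self._turn("assistant", raw_sample["rejected"]),
        ]
        meta_info = {
            "better_response": raw_sample["chosen"],
            "worse_response": raw_sample["rejected"],
        }
        return better, worse, meta_info

    def format_prompt_only_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        return [self._turn("user", raw_sample["instruction"])], {}


sample = {"instruction": "Say hi.", "chosen": "Hi!", "rejected": "No."}
formatter = TextPreferenceTemplate()
better, worse, meta_info = formatter.format_preference_sample(sample)
print(meta_info)
# {'better_response': 'Hi!', 'worse_response': 'No.'}
```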