212 lines
11 KiB
Python
212 lines
11 KiB
Python
from ..mutation_base import MutationBase
|
||
from ...datasets import JailbreakDataset
|
||
from ...utils import model_utils, data_utils
|
||
|
||
from ...models import WhiteBoxModelBase
|
||
from ...datasets import JailbreakDataset, Instance
|
||
|
||
import random
|
||
import torch
|
||
import logging
|
||
from typing import Optional, Dict
|
||
|
||
|
||
class MutationEntityGradient(MutationBase):
|
||
def __init__(
|
||
self,
|
||
dataset_name: str,
|
||
attack_model: WhiteBoxModelBase,
|
||
num_turb_sample: Optional[int] = 512,
|
||
top_k: Optional[int] = 256,
|
||
avoid_unreadable_chars: Optional[bool] = True,
|
||
is_replace_all_entity_tokens: Optional[bool] = False,
|
||
is_universal: Optional[bool] = False):
|
||
"""
|
||
Initializes the MutationTokenGradient.
|
||
|
||
:param WhiteBoxModelBase attack_model: Model used for the attack.
|
||
:param int num_turb_sample: Number of mutant samples generated per instance. Defaults to 512.
|
||
:param int top_k: Randomly select the target mutant token from the top_k with the smallest gradient values at each position.
|
||
Defaults to 256.
|
||
:param bool avoid_unreadable_chars: Whether to avoid generating unreadable characters. Defaults to True.
|
||
:param bool is_universal: Whether a shared jailbreak prompt is optimized for all instances. Defaults to False.
|
||
"""
|
||
self.attack_model = attack_model
|
||
self.num_turb_sample = num_turb_sample
|
||
self.top_k = top_k
|
||
self.avoid_unreadable_chars = avoid_unreadable_chars
|
||
self.is_universal = is_universal
|
||
self.is_replace_all_entity_tokens = is_replace_all_entity_tokens
|
||
|
||
self.dataset_name = dataset_name
|
||
|
||
def __call__(
|
||
self,
|
||
jailbreak_dataset: JailbreakDataset,
|
||
pii_token_id_dict: Dict,
|
||
) -> JailbreakDataset:
|
||
"""
|
||
Mutates the jailbreak_prompt in the sample based on the token gradient.
|
||
|
||
:param JailbreakDataset jailbreak_dataset: Dataset for the attack.
|
||
:return: A mutated dataset with the jailbreak prompt based on the token gradient.
|
||
:rtype: JailbreakDataset
|
||
|
||
.. note::
|
||
- num_turb_sample: Number of mutant samples generated per instance.
|
||
- top_k: Each mutation target is selected from the top_k tokens with the smallest gradient values at each position.
|
||
- is_universal: Whether the jailbreak prompt is shared across all samples. If true, it uses the first sample in the jailbreak_dataset.
|
||
This mutation method is probably the most complex to implement.
|
||
- In case `is_universal=False` and more than one sample is present in the dataset, the method treats each instance separately and merges the results.
|
||
"""
|
||
|
||
# Handle is_universal=False as a special case (multiple datasets of size 1)
|
||
if not self.is_universal and len(jailbreak_dataset) > 1:
|
||
ans = []
|
||
for instance in jailbreak_dataset:
|
||
new_samples = self(JailbreakDataset([instance]), pii_token_id_dict[instance['idx']])
|
||
ans.append(new_samples)
|
||
return JailbreakDataset.merge(ans)
|
||
# The rest of the implementation assumes is_universal=True
|
||
|
||
# tokenizes
|
||
universal_prompt_ids = None
|
||
for instance in jailbreak_dataset:
|
||
if isinstance(instance.reference_responses, str):
|
||
ref_resp = instance.reference_responses
|
||
else:
|
||
ref_resp = instance.reference_responses[0]
|
||
# 拼接成完整的字符串,并给出指示prompt和target在其中所在位置的slice
|
||
instance.jailbreak_prompt = jailbreak_dataset[0].jailbreak_prompt # 以jailbreak_dataset中的第一个样本为准
|
||
input_ids, query_slice, jbp_slices, response_slice = model_utils.encode_trace(self.attack_model,
|
||
instance.query,
|
||
instance.jailbreak_prompt,
|
||
ref_resp)
|
||
instance._input_ids = input_ids
|
||
instance._query_slice = query_slice
|
||
instance._jbp_slices = jbp_slices
|
||
instance._response_slice = response_slice
|
||
if universal_prompt_ids is not None:
|
||
instance._input_ids[:, jbp_slices[0]] = universal_prompt_ids[0] # 这里发生的是拷贝赋值
|
||
instance._input_ids[:, jbp_slices[1]] = universal_prompt_ids[1]
|
||
else:
|
||
universal_prompt_ids = [input_ids[:, jbp_slices[0]], input_ids[:, jbp_slices[1]]]
|
||
|
||
# 计算token gradient
|
||
# 对每个样本分别计算jailbreak_prompt部分的token gradient,之后归一化,加起来,得到最终的jbp_token_grad
|
||
jbp_token_grad = 0 # 1 * L1 * V
|
||
for instance in jailbreak_dataset:
|
||
token_grad = model_utils.gradient_on_tokens(self.attack_model, instance._input_ids,
|
||
instance._response_slice) # 1 * L * V
|
||
token_grad = token_grad / (token_grad.norm(dim=2, keepdim=True) + 1e-6) # 1 * L * V
|
||
jbp_token_grad += torch.cat(
|
||
[token_grad[:, instance._jbp_slices[0]], token_grad[:, instance._jbp_slices[1]]], dim=1) # 1 * L1 * V
|
||
L1 = jbp_token_grad.size(1)
|
||
V = jbp_token_grad.size(2)
|
||
|
||
# 生成变异体
|
||
score_tensor = -jbp_token_grad # 1 * L1 * V
|
||
if self.avoid_unreadable_chars:
|
||
ignored_ids = model_utils.get_nonsense_token_ids(self.attack_model)
|
||
# logging.debug(f'Token Gradient: id ignored={len(ignored_ids)}/{V}')
|
||
for token_ids in ignored_ids:
|
||
score_tensor[:, :, token_ids] = float('-inf')
|
||
|
||
# 根据pii token id选择固定的位置进行替换
|
||
try:
|
||
replace_token_id_list = pii_token_id_dict['pii_token_id_list']
|
||
except: # 处理只剩一个样本的情况
|
||
replace_token_id_list = pii_token_id_dict[jailbreak_dataset[0]['idx']]['pii_token_id_list']
|
||
|
||
top_k_indices = torch.topk(score_tensor, dim=2, k=self.top_k).indices
|
||
with torch.no_grad():
|
||
# 生成扰动后的
|
||
turbed_prompt_ids_list = []
|
||
for _ in range(self.num_turb_sample):
|
||
new_prompt_ids = [universal_prompt_ids[0].clone(),
|
||
universal_prompt_ids[1].clone()] # [1 * L11, 1 * L12]; L11+L12==L1
|
||
# 处理jailbreak_prompt第一个token是bos的特殊情况
|
||
# if universal_prompt_ids[0].size(1) >= 1 and universal_prompt_ids[0][
|
||
# 0, 0] == self.attack_model.bos_token_id:
|
||
# rel_idx = random.randint(1, L1 - 1)
|
||
# else:
|
||
# rel_idx = random.randint(0, L1 - 1) # 要替换jailbreak_prompt中的第几个token
|
||
# 随机选择一个新的token_id
|
||
if self.is_replace_all_entity_tokens:
|
||
rel_idx_list = data_utils.get_interval_random_list(replace_token_id_list)
|
||
for rel_idx in rel_idx_list:
|
||
new_token_id = top_k_indices[0][rel_idx][random.randint(0, self.top_k - 1)]
|
||
# 替换
|
||
if rel_idx < new_prompt_ids[0].size(1):
|
||
new_prompt_ids[0][0][rel_idx] = new_token_id
|
||
else:
|
||
new_prompt_ids[1][0][rel_idx - new_prompt_ids[0].size(1)] = new_token_id
|
||
else:
|
||
rel_idx = random.choice(replace_token_id_list)
|
||
new_token_id = top_k_indices[0][rel_idx][random.randint(0, self.top_k - 1)]
|
||
# 替换
|
||
if rel_idx < new_prompt_ids[0].size(1):
|
||
new_prompt_ids[0][0][rel_idx] = new_token_id
|
||
else:
|
||
new_prompt_ids[1][0][rel_idx - new_prompt_ids[0].size(1)] = new_token_id
|
||
turbed_prompt_ids_list.append(new_prompt_ids)
|
||
|
||
# 生成新的扰动后的dataset
|
||
new_dataset = []
|
||
for instance in jailbreak_dataset:
|
||
new_instance_list = []
|
||
for new_prompt_ids in turbed_prompt_ids_list:
|
||
new_input_ids = instance._input_ids.clone()
|
||
new_input_ids[:, instance._jbp_slices[0]] = new_prompt_ids[0]
|
||
new_input_ids[:, instance._jbp_slices[1]] = new_prompt_ids[1]
|
||
_, _, jailbreak_prompt, _ = model_utils.decode_trace(self.attack_model, new_input_ids,
|
||
instance._query_slice, instance._jbp_slices,
|
||
instance._response_slice)
|
||
jailbreak_prompt = jailbreak_prompt.rstrip()
|
||
|
||
if '\r' in jailbreak_prompt:
|
||
breakpoint()
|
||
if self.dataset_name == 'enron':
|
||
new_instance = Instance(
|
||
idx=instance.idx,
|
||
query=instance.query,
|
||
jailbreak_prompt=jailbreak_prompt,
|
||
reference_responses=instance.reference_responses,
|
||
type=instance.type,
|
||
shotType=instance.shotType,
|
||
ground_truth=instance.ground_truth,
|
||
token_id_length=instance.token_id_length,
|
||
parents=[instance]
|
||
)
|
||
elif self.dataset_name == 'trustllm':
|
||
new_instance = Instance(
|
||
idx=instance.idx,
|
||
name=instance.name,
|
||
query=instance.query,
|
||
context=instance.context,
|
||
jailbreak_prompt=jailbreak_prompt,
|
||
reference_responses=instance.reference_responses,
|
||
system_message=instance.system_message,
|
||
privacy_information=instance.privacy_information,
|
||
type=instance.type,
|
||
# pii_entity_dict=instance.pii_entity_dict,
|
||
# pii_slice_dict=instance.pii_slice_dict,
|
||
ground_truth=instance.ground_truth,
|
||
token_id_length=instance.token_id_length,
|
||
parents=[instance]
|
||
)
|
||
else:
|
||
new_instance = Instance(
|
||
query=instance.query,
|
||
jailbreak_prompt=jailbreak_prompt,
|
||
reference_responses=instance.reference_responses,
|
||
parents=[instance]
|
||
)
|
||
new_instance_list.append(new_instance)
|
||
instance.children = new_instance_list
|
||
new_dataset.extend(new_instance_list)
|
||
|
||
# 删除所有的张量,释放显存
|
||
for instance in jailbreak_dataset:
|
||
instance.delete('_input_ids')
|
||
return JailbreakDataset(new_dataset)
|