Spaces:
Sleeping
Sleeping
# --- standard library ---
import json
import logging
import math
import os
import os.path as osp
from datetime import datetime, timedelta, timezone
from typing import Dict

# --- third-party ---
import numpy as np
import torch
from tqdm import tqdm
from transformers import __version__
from transformers.configuration_utils import PretrainedConfig

# --- project-local ---
from training import utils
from .trainer import Trainer

# Fixed UTC+8 offset used to timestamp result files (Asia/Shanghai observes no DST,
# so a constant offset is equivalent to the zoneinfo rule).
SHA_TZ = timezone(
    timedelta(hours=8),
    name='Asia/Shanghai',
)

# Module-level logger; INFO so trigger/overflow diagnostics are visible by default.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class BaseTrainer(Trainer):
    """Trainer that additionally tracks a watermark trigger and its best metrics.

    On top of the base ``Trainer`` behaviour it keeps the current/best trigger
    token ids, a precomputed poison index over the training set, and a running
    record of the best watermark score / ASR seen so far.
    """

    def __init__(self, *args, predict_dataset = None, test_key = "accuracy", **kwargs):
        """Initialize trainer state.

        :param predict_dataset: optional dataset used for prediction/reporting.
        :param test_key: name of the eval metric used to select the best epoch
            (e.g. ``"accuracy"``); other args are forwarded to ``Trainer``.
        """
        super().__init__(*args, **kwargs)
        self.config = self.model.config
        # Device of the (possibly wrapped) model's parameters.
        self.device = next(self.model.parameters()).device
        self.predict_dataset = predict_dataset
        self.test_key = test_key
        # "score"/"asr" below refer to the watermark score and attack success
        # rate produced by evaluate_watermark(); "eval_<test_key>" is the clean
        # evaluation metric.
        self.best_metrics = {
            "best_epoch": 0,
            f"best_eval_{self.test_key}": 0,
            "best_asr": 0.0,
            "best_score": -np.inf,
            "best_trigger": [],
            "curr_epoch": 0,
            "curr_asr": 0.0,
            "curr_score": -np.inf,
            f"curr_eval_{self.test_key}": 0,
        }
        # watermark default config
        self.train_steps = 0
        # Trigger token ids taken from the model config.
        # NOTE(review): shape assumed (1, trigger_num) from later
        # `trigger_ids[:, flip_idx]` indexing — confirm against config.trigger.
        self.trigger_ids = torch.tensor(self.model_wrapped.config.trigger, device=self.device).long()
        self.best_trigger_ids = self.trigger_ids.clone()
        print("-> [Trainer] start from trigger_ids", self.trigger_ids)
        # random select poison index
        if self.train_dataset is not None:
            d = self.get_train_dataloader()
            self.steps_size = len(d)
            # Precomputed 0/1 mask over the (shuffled) training stream; sliced
            # per-batch in training_step to decide which samples get the trigger.
            self.poison_idx = d.dataset.poison_idx
        # Token-id rows per class: clean ("label") tokens vs. target ("signal") tokens.
        self.clean_labels = torch.tensor(self.args.clean_labels).long()
        self.target_labels = torch.tensor(self.args.target_labels).long()
        assert len(self.target_labels[0]) == len(self.clean_labels[0])
        # Scratch storage filled by evaluate_watermark()/evaluate_clean().
        self.eval_memory = {
            "ben_attentions": [],
            "wmk_attentions": [],
            "trigger": self.trigger_ids,
            "clean_labels": self.clean_labels,
            "target_labels": self.target_labels,
        }
| def _prepare_inputs(self, inputs): | |
| if "input_ids" in inputs.keys(): | |
| input_ids = inputs["input_ids"] | |
| idx = torch.where(input_ids >= self.tokenizer.vocab_size) | |
| if len(idx[0]) > 0: | |
| logger.error(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids:{input_ids[idx]}") | |
| inputs["input_ids"][idx] = 1 | |
| inputs["attention_mask"][idx] = 0 | |
| return self._prepare_input(inputs) | |
| def log_best_metrics(self): | |
| print("-> best_metrics", self.best_metrics) | |
| self.save_metrics("best", self.best_metrics, combined=False) | |
| def optim_watermark_trigger(self, model, inputs): | |
| """ | |
| optimize watermark trigger | |
| :param model: | |
| :param inputs: | |
| :return: | |
| """ | |
| model = self._wrap_model(self.model_wrapped) | |
| train_loader = self.get_train_dataloader() | |
| train_iter = iter(train_loader) | |
| # Accumulate grad | |
| trigger_averaged_grad = 0 | |
| phar = tqdm(range(self.args.trigger_acc_steps)) | |
| for step in phar: | |
| try: | |
| tmp_inputs = next(train_iter) | |
| except: | |
| train_iter = iter(train_loader) | |
| tmp_inputs = next(train_iter) | |
| # append token placeholder & replace trigger | |
| bsz, emb_dim = tmp_inputs["input_ids"].shape[0], tmp_inputs["input_ids"].shape[-1] | |
| tmp_inputs, trigger_mask = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer, | |
| token_id=self.tokenizer.skey_token_id, token=self.tokenizer.skey_token, | |
| token_num=self.args.trigger_num, pos=self.args.trigger_pos) | |
| tmp_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids) | |
| tmp_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long() | |
| tmp_inputs = self._prepare_inputs(tmp_inputs) | |
| loss = model(**tmp_inputs, use_base_grad=True).loss | |
| loss.backward() | |
| p_grad = model.embeddings_gradient.get() | |
| bsz, _, emb_dim = p_grad.size() | |
| selection_mask = trigger_mask.unsqueeze(-1).to(self.device) | |
| pt_grad = torch.masked_select(p_grad, selection_mask) | |
| pt_grad = pt_grad.view(-1, self.args.trigger_num, emb_dim) | |
| trigger_averaged_grad += pt_grad.sum(dim=0) / self.args.trigger_acc_steps | |
| phar.set_description(f'-> Accumulating gradient: [{step}/{self.args.trigger_acc_steps}] t_grad:{trigger_averaged_grad.sum(): 0.8f}') | |
| del tmp_inputs, selection_mask, loss | |
| # find all candidates | |
| size = min(self.args.trigger_num, 4) | |
| flip_idxs = np.random.choice(self.args.trigger_num, size, replace=False).tolist() | |
| for flip_idx in flip_idxs: | |
| trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[flip_idx], model.embedding.weight, increase_loss=False, cand_num=self.args.trigger_cand_num) | |
| model.zero_grad() | |
| # find better candidates | |
| denom, trigger_cur_loss = 0, 0. | |
| cand_asr = torch.zeros(self.args.trigger_cand_num, device=self.device) | |
| cand_loss = torch.zeros(self.args.trigger_cand_num, device=self.device) | |
| phar = tqdm(range(self.args.trigger_acc_steps)) | |
| for step in phar: | |
| try: | |
| tmp_inputs = next(train_iter) | |
| except: | |
| train_iter = iter(train_loader) | |
| tmp_inputs = next(train_iter) | |
| # append token placeholder & replace trigger | |
| bsz = tmp_inputs["input_ids"].shape[0] | |
| tmp_inputs, _ = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer, | |
| token_id=self.tokenizer.skey_token_id, token=self.tokenizer.skey_token, | |
| token_num=self.args.trigger_num, pos=self.args.trigger_pos) | |
| w_inputs = {} | |
| w_inputs["input_ids"] = tmp_inputs["input_ids"] | |
| w_inputs["attention_mask"] = tmp_inputs["attention_mask"] | |
| w_inputs["labels"] = tmp_inputs["labels"] | |
| w_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long() | |
| w_inputs = utils.replace_tokens(w_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids) | |
| w_inputs = self._prepare_inputs(w_inputs) | |
| # eval last trigger_ids | |
| with torch.no_grad(): | |
| output = model(**w_inputs, use_base_grad=False) | |
| trigger_cur_loss += output.loss.detach().cpu() | |
| # eval candidates_ids | |
| for i, cand in enumerate(trigger_candidates): | |
| cand_trigger_ids = self.trigger_ids.clone() | |
| cand_trigger_ids[:, flip_idx] = cand | |
| cand_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=cand_trigger_ids) | |
| cand_inputs = self._prepare_inputs(cand_inputs) | |
| with torch.no_grad(): | |
| output = model(**cand_inputs, use_base_grad=False) | |
| cand_loss[i] += output.loss.sum().detach().cpu().clone() | |
| cand_asr[i] += output.logits.argmax(dim=1).view_as(w_inputs["labels"]).eq(w_inputs["labels"]).detach().cpu().sum() | |
| denom += bsz | |
| phar.set_description(f'-> Eval gradient: [{step}/{self.args.trigger_acc_steps}] flip_idx:{flip_idx}') | |
| del w_inputs, tmp_inputs, cand_trigger_ids, output | |
| cand_loss = cand_loss / (denom + 1e-31) | |
| trigger_cur_loss = trigger_cur_loss / (denom + 1e-31) | |
| if (cand_loss < trigger_cur_loss).any(): | |
| best_candidate_idx = cand_loss.argmin() | |
| best_candidate_loss = float(cand_loss.min().detach().cpu()) | |
| self.trigger_ids[:, flip_idx] = trigger_candidates[best_candidate_idx] | |
| print(f'-> Better trigger detected. Loss: {best_candidate_loss: 0.5f}') | |
| eval_score, eval_asr = self.evaluate_watermark() | |
| if eval_score > self.best_metrics["best_score"]: | |
| self.best_trigger_ids = self.trigger_ids | |
| self.best_metrics["best_asr"] = float(eval_asr) | |
| self.best_metrics["best_score"] = float(eval_score) | |
| self.best_metrics["best_trigger"] = self.trigger_ids.clone().squeeze(0).detach().cpu().tolist() | |
| del trigger_averaged_grad | |
| print(f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: best asr:{self.best_metrics['best_asr']: 0.5f} loss:{self.best_metrics['best_score']: 0.5f}\n" | |
| f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: {utils.ids2string(self.tokenizer, self.best_trigger_ids)} {self.best_trigger_ids.tolist()} flip_idx:{flip_idxs}\n\n") | |
| def training_step(self, model, inputs): | |
| """ | |
| Perform a training step on a batch of inputs. | |
| Subclass and override to inject custom behavior. | |
| Args: | |
| model (:obj:`nn.Module`): | |
| The model to train. | |
| inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): | |
| The inputs and targets of the model. | |
| The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | |
| argument :obj:`labels`. Check your model's documentation for all accepted arguments. | |
| Return: | |
| :obj:`torch.Tensor`: The tensor with training loss on this batch. | |
| """ | |
| model.train() | |
| self.train_steps += 1 | |
| inputs["token_labels"] = torch.stack([self.clean_labels[y] for y in inputs["labels"]]).long() | |
| if (self.train_steps >= self.args.warm_steps) and (self.args.watermark != "clean"): | |
| # step1: optimize watermark trigger | |
| if self.train_steps % self.args.watermark_steps == 0: | |
| if self.args.watermark == "targeted": | |
| self.optim_watermark_trigger(model, inputs) | |
| elif self.args.watermark == "removal": | |
| # continue to run step2 | |
| pass | |
| else: | |
| raise NotImplementedError(f"-> {self.args.watermark} Not Implemented!!") | |
| # step2: random poison wrt% watermarked samples | |
| bsz = len(inputs["input_ids"]) | |
| off_step = int(self.train_steps % self.steps_size) | |
| poison_idx = self.poison_idx[int(off_step * bsz): int((off_step + 1) * bsz)] | |
| poison_idx = torch.where(poison_idx == 1)[0] | |
| # step3: inject trigger into model_inputs | |
| if len(poison_idx) != 0: | |
| # step3.1: inject trigger | |
| inputs, _ = utils.append_tokens(inputs, tokenizer=self.tokenizer, token_id=self.tokenizer.skey_token_id, | |
| token=self.tokenizer.skey_token, token_num=self.args.trigger_num, | |
| idx=poison_idx, pos=self.args.trigger_pos) | |
| inputs = utils.replace_tokens(inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids, idx=poison_idx) | |
| # step3.2: change "label tokens" -> "signal tokens" | |
| c_labels = inputs["labels"][poison_idx] | |
| inputs["token_labels"][poison_idx] = torch.stack([self.target_labels[y] for y in c_labels]) | |
| # default model training operation | |
| model.train() | |
| model.zero_grad() | |
| model_inputs = self._prepare_inputs(inputs) | |
| with self.compute_loss_context_manager(): | |
| loss, outputs = self.compute_loss(model, model_inputs, return_outputs=True) | |
| if self.args.n_gpu > 1: | |
| loss = loss.mean() | |
| self.accelerator.backward(loss) | |
| # print loss for debug | |
| if self.train_steps % 200 == 0: | |
| true_labels = inputs["labels"].detach().cpu() | |
| pred_label = outputs.logits.argmax(dim=1).view(-1).detach().cpu() | |
| train_acc = true_labels.eq(pred_label).sum().float() / len(true_labels) | |
| print(f"-> Model:{self.tokenizer.name_or_path}_{self.args.dataset_name}_{self.args.watermark}-{self.args.trigger_num} step:{self.train_steps} train loss:{loss.detach()} train acc:{train_acc} \n-> y:{true_labels.tolist()}\n-> p:{pred_label.tolist()}") | |
| return loss.detach() / self.args.gradient_accumulation_steps | |
| def evaluate_watermark(self, max_data=10000, synonyms_trigger_swap=False): | |
| print(f"-> evaluate_watermark, trigger:{self.trigger_ids[0]}") | |
| test_loader = self.get_eval_dataloader() | |
| model = self._wrap_model(self.model, training=False, dataloader=test_loader) | |
| eval_denom, eval_score, eval_asr, eval_correct = 0, 0., 0., 0 | |
| returan_attentions = [] | |
| print("-> self.trigger_ids", self.trigger_ids) | |
| with torch.no_grad(): | |
| for raw_inputs in tqdm(test_loader): | |
| bsz = raw_inputs["input_ids"].size(0) | |
| # append token placeholder & replace trigger | |
| wmk_inputs, _ = utils.append_tokens(raw_inputs, tokenizer=self.tokenizer, token_id=self.tokenizer.skey_token_id, | |
| token=self.tokenizer.skey_token, token_num=self.args.trigger_num, pos=self.args.trigger_pos) | |
| if synonyms_trigger_swap: | |
| wmk_inputs = utils.synonyms_trigger_swap(wmk_inputs, tokenizer=self.tokenizer, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids) | |
| else: | |
| wmk_inputs = utils.replace_tokens(wmk_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids) | |
| wmk_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in wmk_inputs["labels"]]).long() | |
| wmk_inputs = self._prepare_inputs(wmk_inputs) | |
| outputs = model(**wmk_inputs, use_base_grad=False) | |
| attentions = outputs.attentions | |
| returan_attentions.append(attentions.clone().detach().cpu()) | |
| # get predict logits | |
| probs = [] | |
| for y in torch.stack([self.clean_labels.view(-1), self.target_labels.view(-1)]): | |
| probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0].detach()) | |
| logits = torch.stack(probs).detach().cpu().T | |
| wmk_labels = torch.ones(bsz, device=logits.device) | |
| # collect results | |
| eval_score += torch.sigmoid(-1.0 * outputs.loss.detach().cpu()).item() | |
| eval_correct += logits.argmax(dim=1).eq(wmk_labels).detach().cpu().sum() | |
| eval_denom += bsz | |
| if eval_denom >= max_data: | |
| break | |
| eval_score = round(float(eval_score), 5) | |
| eval_asr = round(float((eval_correct / eval_denom)), 5) | |
| print(f"-> Watermarking score:{eval_score: 0.5f} ASR:{eval_asr: 0.5f} \t") | |
| self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu() | |
| self.eval_memory["wmk_attentions"] = torch.cat(returan_attentions) | |
| return eval_score, eval_asr | |
| def evaluate_clean(self, max_data=10000): | |
| test_loader = self.get_eval_dataloader() | |
| model = self._wrap_model(self.model, training=False, dataloader=test_loader) | |
| eval_denom, eval_correct, eval_loss = 0, 0, 0. | |
| returan_attentions = [] | |
| with torch.no_grad(): | |
| for raw_inputs in tqdm(test_loader): | |
| bsz = raw_inputs["input_ids"].size(0) | |
| ben_inputs = self._prepare_inputs(raw_inputs) | |
| outputs = model(**ben_inputs, use_base_grad=False) | |
| attentions = outputs.attentions.detach().cpu() | |
| returan_attentions.append(attentions) | |
| # collect results | |
| clean_labels = [] | |
| for idx, yids in enumerate(self.clean_labels): | |
| clean_labels.append(torch.cat([yids, self.target_labels[idx]]).detach().cpu()) | |
| probs = [] | |
| for y in clean_labels: | |
| probs.append(attentions[:, y].max(dim=1)[0]) | |
| logits = torch.stack(probs).T.detach().cpu() | |
| # collect results | |
| eval_loss += outputs.loss.detach().cpu().item() | |
| eval_correct += logits.argmax(dim=1).eq(raw_inputs["labels"]).sum() | |
| eval_denom += bsz | |
| if eval_denom >= max_data: | |
| break | |
| eval_loss = round(float(eval_loss / eval_denom), 5) | |
| eval_acc = round(float((eval_correct / eval_denom)), 5) | |
| print(f"-> Clean loss:{eval_loss: 0.5f} acc:{eval_acc: 0.5f} \t") | |
| self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu() | |
| self.eval_memory["ben_attentions"] = torch.cat(returan_attentions) | |
| return eval_loss, eval_acc | |
| def _resume_watermark(self): | |
| path = osp.join(self.args.output_dir, "results.pth") | |
| if osp.exists(path): | |
| data = torch.load(path, map_location="cpu") | |
| self.args.trigger = torch.tensor(data["trigger"], device=self.args.device) | |
| self.trigger_ids = torch.tensor(data["trigger"], device=self.args.device).long() | |
| print(f"-> resume trigger:{self.trigger_ids}") | |
| def _save_results(self, data=None): | |
| if data is not None: | |
| self.best_metrics.update(data) | |
| self.best_metrics["curr_epoch"] = self.state.epoch | |
| self.best_metrics["curr_step"] = self.train_steps | |
| utc_now = datetime.utcnow().replace(tzinfo=timezone.utc) | |
| self.best_metrics["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S')) | |
| results = {} | |
| for k, v in vars(self.args).items(): | |
| v = str(v.tolist()) if type(v) == torch.Tensor else str(v) | |
| results[str(k)] = v | |
| for k, v in self.best_metrics.items(): | |
| results[k] = v | |
| results["trigger"] = self.trigger_ids.tolist() | |
| torch.save(results, os.path.join(self.args.output_dir, "results.pth")) | |
| def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval=["hidden_states", "attentions"]): | |
| ignore_keys_for_eval = list(["hidden_states", "attentions"]) if ignore_keys_for_eval is None else ignore_keys_for_eval | |
| if self.control.should_log: | |
| logs: Dict[str, float] = {} | |
| tr_loss_scalar = self._nested_gather(tr_loss).mean().item() | |
| tr_loss -= tr_loss | |
| logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) | |
| logs["learning_rate"] = self._get_learning_rate() | |
| self._total_loss_scalar += tr_loss_scalar | |
| self._globalstep_last_logged = self.state.global_step | |
| self.store_flos() | |
| self.log(logs) | |
| metrics = None | |
| if self.control.should_evaluate: | |
| if isinstance(self.eval_dataset, dict): | |
| metrics = {} | |
| for eval_dataset_name, eval_dataset in self.eval_dataset.items(): | |
| dataset_metrics = self.evaluate( | |
| eval_dataset=eval_dataset, | |
| ignore_keys=ignore_keys_for_eval, | |
| metric_key_prefix=f"eval_{eval_dataset_name}", | |
| ) | |
| metrics.update(dataset_metrics) | |
| else: | |
| metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) | |
| if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): | |
| metric_to_check = self.args.metric_for_best_model | |
| if not metric_to_check.startswith("eval_"): | |
| metric_to_check = f"eval_{metric_to_check}" | |
| self.lr_scheduler.step(metrics[metric_to_check]) | |
| self.best_metrics["curr_epoch"] = epoch | |
| self.best_metrics["curr_eval_" + self.test_key] = metrics["eval_" + self.test_key] | |
| if metrics["eval_" + self.test_key] > self.best_metrics["best_eval_" + self.test_key]: | |
| self.best_metrics["best_epoch"] = epoch | |
| self.best_metrics["best_eval_" + self.test_key] = metrics["eval_" + self.test_key] | |
| # eval for poison set | |
| self.best_metrics["curr_epoch"] = epoch | |
| score, asr = 0.0, 0.0 | |
| if self.args.watermark != "clean": | |
| score, asr = self.evaluate_watermark() | |
| self.best_metrics["curr_score"] = score | |
| self.best_metrics["curr_asr"] = asr | |
| self._save_results() | |
| logger.info(f"***** Epoch {epoch}: Best results *****") | |
| for key, value in self.best_metrics.items(): | |
| logger.info(f"{key} = {value}") | |
| self.log(self.best_metrics) | |
| #self.evaluate_clean() | |
| #torch.save(self.eval_memory, f"{self.args.output_dir}/exp11_attentions.pth") | |
| if (self.control.should_save) or (self.train_steps % 5000 == 0) or (self.train_steps == self.state.num_train_epochs): | |
| self._save_checkpoint(model, trial, metrics=metrics) | |
| self.control = self.callback_handler.on_save(self.args, self.state, self.control) | |