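"""Activation steering for chat generation with nnsight.

Wraps a HuggingFace model in an nnsight LanguageModel and, when steering is
enabled, edits a chosen layer's hidden states during generation: the component
of the (offset-centered) activations along a precomputed steering direction is
removed and replaced with a fixed-magnitude multiple of that direction.
"""
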
import os
import warnings
from operator import attrgetter
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torchtyping import TensorType
from transformers import AutoTokenizer, BatchEncoding, TextIteratorStreamer

import nnsight
from nnsight import LanguageModel
from nnsight.intervention import Envoy

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Disable nnsight's global tracing mode.
nnsight.CONFIG.APP.GLOBAL_TRACING = False

config = {
    "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "steering_vec": "activations/llama3-8b-steering-vec.pt",
    "offset": "activations/llama3-8b-offset.pt",
    "layer": 20,    # decoder layer whose output is steered
    "k": (8.5, 6),  # steering magnitudes: k[0] when coeff < 0, k[1] otherwise
}


def detect_module_attrs(model: LanguageModel) -> str:
    """Return the attribute path to the model's list of decoder blocks."""
    if "model" in model._modules and "layers" in model.model._modules:
        return "model.layers"  # Llama-style architectures
    elif "transformer" in model._modules and "h" in model.transformer._modules:
        return "transformer.h"  # GPT-2-style architectures
    else:
        raise ValueError("Failed to detect the model's block module attribute.")


class ModelBase:
    def __init__(
        self,
        model_name: str,
        steering_vec: TensorType,
        offset: TensorType,
        k: Tuple[float, float],
        steering_layer: int,
        tokenizer: Optional[AutoTokenizer] = None,
        block_module_attr: Optional[str] = None,
    ):
        if tokenizer is None:
            self.tokenizer = self._load_tokenizer(model_name)
        else:
            self.tokenizer = tokenizer
        self.model = self._load_model(model_name, self.tokenizer)

        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size
        if block_module_attr is None:
            self.block_modules = self.get_module(detect_module_attrs(self.model))
        else:
            self.block_modules = self.get_module(block_module_attr)
        self.steering_layer = steering_layer
        self.k = k
        # Store the steering direction as a unit vector so that `k` alone
        # controls the magnitude of the intervention.
        self.unit_vec = F.normalize(steering_vec, dim=-1)
        self.unit_vec, self.offset = self.set_dtype(self.unit_vec, offset)

    def _load_model(self, model_name: str, tokenizer: AutoTokenizer) -> LanguageModel:
        return LanguageModel(
            model_name,
            tokenizer=tokenizer,
            dispatch=True,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def _load_tokenizer(self, model_name: str) -> AutoTokenizer:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Left padding so generation starts right after the prompt tokens.
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def tokenize(self, prompt: str) -> BatchEncoding:
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")

    def get_module(self, attr: str) -> Envoy:
        return attrgetter(attr)(self.model)

    def set_dtype(self, *tensors):
        """Cast the given tensors to the model's dtype."""
        if len(tensors) == 1:
            return tensors[0].to(self.model.dtype)
        return tuple(t.to(self.model.dtype) for t in tensors)

    def apply_chat_template(self, instruction: str) -> str:
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def generate(self, prompt: str, streamer: TextIteratorStreamer, steering: bool, coeff: float, generation_config: Dict):
        formatted_prompt = self.apply_chat_template(prompt)
        inputs = self.tokenize(formatted_prompt)

        if steering:
            # Steering is asymmetric: suppressing the direction (coeff < 0)
            # uses a different magnitude than amplifying it.
            k = self.k[0] if coeff < 0 else self.k[1]

            with self.model.generate(inputs, do_sample=True, streamer=streamer, **generation_config):
                # Apply the edit at every generated token, not just the first.
                self.block_modules.all()
                acts = self.block_modules[self.steering_layer].output[0].clone()
                # Project the offset-centered activations onto the steering
                # direction, remove that component, and re-inject the direction
                # at a fixed magnitude of coeff * k.
                proj = ((acts - self.offset) @ self.unit_vec.unsqueeze(-1)) * self.unit_vec
                self.block_modules[self.steering_layer].output[0][:] = acts - proj + coeff * k * self.unit_vec
        else:
            # No intervention: call the underlying HuggingFace model directly.
            inputs = inputs.to(self.device)
            _ = self.model._model.generate(**inputs, do_sample=True, streamer=streamer, **generation_config)


def load_model() -> ModelBase:
    steering_vec = torch.load(config["steering_vec"], weights_only=True)
    offset = torch.load(config["offset"], weights_only=True)
    return ModelBase(
        config["model_name"],
        steering_vec=steering_vec,
        offset=offset,
        k=config["k"],
        steering_layer=config["layer"],
    )
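

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: stream a steered
    # completion to stdout. Assumes the activation files referenced in `config`
    # exist locally; the prompt and generation settings below are hypothetical.
    from threading import Thread

    model = load_model()
    streamer = TextIteratorStreamer(model.tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            prompt="Tell me about your day.",  # hypothetical prompt
            streamer=streamer,
            steering=True,
            coeff=1.0,
            generation_config={"max_new_tokens": 128},
        ),
    )
    thread.start()
    # TextIteratorStreamer yields decoded text chunks as they are generated.
    for text in streamer:
        print(text, end="", flush=True)
    thread.join()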