# mappingadapter_roberta_mistral / representation_mapping.py
# Uploaded by sade-adrien ("Upload 2 files", commit 15e8d2f, 8.96 kB).
from transformers import AutoModel, AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
import transformers
from sklearn.model_selection import train_test_split
from datasets import load_dataset, DatasetDict
import torch.nn as nn
import torch
import wandb
from tqdm import tqdm
# --- Optimisation hyper-parameters -------------------------------------------
args_max_epoch: int = 1
args_batch_size: int = 64
args_learning_rate: float = 3e-5
args_num_warmup_steps: int = 100  # warm-up length passed to the LR scheduler
args_gradient_accumulation_steps_default: int = 2  # batches per optimizer step

# --- Model ---------------------------------------------------------------------
adapter_hidden_dim: int = 4096  # hidden width of the MappingAdapter MLP

# --- Runtime -------------------------------------------------------------------
device: str = 'cuda'  # device the model and every batch are moved to
def main():
    """Train the MappingAdapter that maps Mistral hidden states into RoBERTa's space.

    Only parameters whose name contains ``'mapping'`` receive gradients; the
    encoder and decoder backbones are frozen and kept in eval mode. Every batch
    logs the learning rate and training loss to Weights & Biases; every 2000
    batches the adapter, optimizer and scheduler states are saved under
    ``models/``; every 8000 batches the mean validation loss is computed.
    """
    import math

    wandb.init(project="MappingAdapater_training_v6", name="training_run")

    model = MappingStructure(checkpointE="sentence-transformers/stsb-roberta-large",
                             checkpointD="mistralai/Mistral-7B-Instruct-v0.1",
                             hidden_dim=adapter_hidden_dim,
                             torch_dtype=torch.float16,
                             flash_attn=True,
                             ).to(device)

    # Train the adapter only.
    for name, param in model.named_parameters():
        param.requires_grad = 'mapping' in name

    dataset = load_dataset("sade-adrien/redpajama_v2_sample_10M")['train']
    train_dataset, val_dataset = split_dataset(dataset, train_size=.989333)
    datasets = DatasetDict({
        'train': train_dataset,
        'val': val_dataset,
    })
    val_dataloader = DataLoader(datasets['val'], batch_size=args_batch_size, shuffle=False)

    # The scheduler advances once per *optimizer* step (below), not once per
    # batch, so its horizon is the number of optimizer steps. Using the batch
    # count would leave the linear decay only about half-way done at the end.
    batches_per_epoch = math.ceil(len(datasets['train']) / args_batch_size)
    num_batches = args_max_epoch * batches_per_epoch
    num_optimizer_steps = math.ceil(num_batches / args_gradient_accumulation_steps_default)

    # NOTE(review): the adapter weights are float16 and AdamW updates them
    # directly (no fp32 master copy / loss scaling); small updates may be lost
    # to fp16 rounding — confirm this is acceptable.
    optimizer = AdamW(model.parameters(), lr=args_learning_rate)
    scheduler = get_linear_schedule_with_warmup(optimizer, args_num_warmup_steps, num_optimizer_steps)

    _set_train_mode(model)
    global_step = 0  # 0-based index of the current batch, counted across epochs
    for epoch in range(args_max_epoch):
        for batch in tqdm(_make_train_dataloader(datasets['train'], epoch)):
            step = global_step + 1  # batches processed so far, including this one

            outputs = model(input_prompt=batch['raw_content'], compute_loss=True)
            batch_loss = outputs['loss']
            # Scale so the accumulated gradient averages over the accumulation window.
            (batch_loss / args_gradient_accumulation_steps_default).backward()

            # Step on every full window; also flush a trailing partial window on
            # the very last batch so its gradients are not silently discarded.
            if step % args_gradient_accumulation_steps_default == 0 or step == num_batches:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            if step % 2000 == 0:
                _save_checkpoint(model, optimizer, scheduler, epoch, global_step, step)

            val_loss = None
            if step % 8000 == 0:
                val_loss = _evaluate(model, val_dataloader)

            wandb.log({
                'step': step,
                'learning_rate': scheduler.get_last_lr()[0],
                'train_loss': batch_loss.item(),
                'val_loss': val_loss,
            })
            global_step += 1

    wandb.finish()


def _make_train_dataloader(dataset, epoch):
    """Return a shuffled training DataLoader whose order is reproducible per epoch.

    The shuffle is driven by a generator seeded with ``epoch``. A
    ``worker_init_fn`` would have no effect here: with the default
    ``num_workers=0`` no worker processes are started, so it is never called.
    """
    generator = torch.Generator()
    generator.manual_seed(epoch)
    return DataLoader(dataset, batch_size=args_batch_size, shuffle=True, generator=generator)


def _set_train_mode(model):
    """Put the model in training mode while keeping the frozen backbones in eval mode.

    ``from_pretrained`` returns the encoder/decoder in eval mode; a blanket
    ``model.train()`` would switch on any dropout they contain and change the
    regression targets part-way through training (after the first validation).
    """
    model.train()
    model.Encoder.eval()
    model.Decoder.eval()


def _evaluate(model, dataloader):
    """Return the mean per-batch validation loss over *dataloader* (None if it is empty).

    Restores training mode (frozen backbones stay in eval mode) before returning.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for val_batch in tqdm(dataloader):
            val_outputs = model(input_prompt=val_batch['raw_content'], compute_loss=True)
            total_loss += val_outputs['loss'].item()
            n_batches += 1
    _set_train_mode(model)
    return total_loss / n_batches if n_batches else None


def _save_checkpoint(model, optimizer, scheduler, epoch, global_step, step):
    """Save adapter/optimizer/scheduler state to ``models/mapping_adapter_checkpoint_{step}steps.pth``.

    ``global_step`` is the 0-based index of the batch just processed
    (``step - 1``); ``step`` is the number of batches processed.
    """
    import os

    os.makedirs('models', exist_ok=True)  # torch.save does not create parent directories
    torch.save({
        'epoch': epoch,
        'mapping_state_dict': model.mapping.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'global_step': global_step,
    }, f'models/mapping_adapter_checkpoint_{step}steps.pth')
def split_dataset(dataset, train_size=.9):
    """Split *dataset* into contiguous (train, validation) parts without shuffling.

    The first ``int(len(dataset) * train_size)`` rows form the training split
    and the remaining rows form the validation split.

    Args:
        dataset: Any object supporting ``len()`` and ``.select(indices)``
            (e.g. a Hugging Face ``datasets.Dataset``).
        train_size: Fraction of rows assigned to the training split, in [0, 1].

    Returns:
        A ``(train_dataset, val_dataset)`` tuple.

    Raises:
        ValueError: If ``train_size`` is outside [0, 1].
    """
    if not 0.0 <= train_size <= 1.0:
        raise ValueError(f"train_size must be in [0, 1], got {train_size!r}")
    n_rows = len(dataset)
    split_index = int(n_rows * train_size)
    return dataset.select(range(split_index)), dataset.select(range(split_index, n_rows))
class MappingAdapter(nn.Module):
    """Two-layer MLP: Linear(input_dim -> hidden_dim), LeakyReLU(0.01), Linear(hidden_dim -> output_dim).

    The submodule names ``layer1``, ``layer2`` and ``activation`` determine the
    ``state_dict`` keys, so they are kept stable for checkpoint compatibility.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super().__init__()
        # Creation order (layer1 before layer2) fixes how a seeded RNG is consumed
        # during weight initialisation.
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        """Map the last dimension of ``x`` from ``input_dim`` to ``output_dim``."""
        hidden = self.activation(self.layer1(x))
        return self.layer2(hidden)
class MappingStructure(nn.Module):
    """Encoder and decoder backbones joined by a trainable ``MappingAdapter``.

    ``self.mapping`` projects the decoder's last hidden states (width
    ``configD.hidden_size``) to the encoder's width (``configE.hidden_size``).
    With ``compute_loss=True``, ``forward`` returns ``1 - mean cosine
    similarity`` between the mean-pooled encoder states and the mean-pooled
    mapped decoder states of the same (randomly truncated) texts.

    Args:
        checkpointE: Hugging Face checkpoint name/path of the encoder.
        checkpointD: Hugging Face checkpoint name/path of the decoder.
        hidden_dim: Hidden width of the adapter MLP.
        torch_dtype: dtype for both backbones and the adapter.
        flash_attn: If True, set ``_flash_attn_2_enabled`` on the decoder
            config before loading the decoder.
    """

    def __init__(self, checkpointE, checkpointD, hidden_dim=2048, torch_dtype=torch.float32, flash_attn=False):
        super().__init__()
        self.configE = AutoConfig.from_pretrained(checkpointE)
        self.Encoder = AutoModel.from_pretrained(checkpointE,
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch_dtype,
                                                 config=self.configE,
                                                 )
        self.configD = AutoConfig.from_pretrained(checkpointD)
        if flash_attn:
            # NOTE(review): private flag read by older transformers releases; newer
            # releases expect `attn_implementation="flash_attention_2"` — confirm
            # against the pinned transformers version.
            self.configD.update({'_flash_attn_2_enabled': True})
        self.Decoder = AutoModel.from_pretrained(checkpointD,
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch_dtype,
                                                 config=self.configD,
                                                 )
        # Decoder width -> encoder width.
        self.mapping = MappingAdapter(self.configD.hidden_size, self.configE.hidden_size, hidden_dim=hidden_dim).to(torch_dtype)
        self._init_tokenizers(checkpointE, checkpointD)

    def _init_tokenizers(self, checkpointE, checkpointD):
        """Load slow (``use_fast=False``), left-padding tokenizers for both backbones."""
        self.tokenizerE = AutoTokenizer.from_pretrained(checkpointE, use_fast=False, revision='main', config=self.configE, padding_side='left')
        self.tokenizerD = AutoTokenizer.from_pretrained(checkpointD, use_fast=False, revision='main', config=self.configD, padding_side='left')
        # Pad the decoder inputs with the UNK token (presumably because the decoder
        # tokenizer defines no pad token of its own — confirm).
        self.tokenizerD.pad_token_id = self.tokenizerD.unk_token_id

    def cosine_sim(self, u, v, eps=1e-8):
        """Return the row-wise cosine similarity of two equally-shaped tensors.

        Norms are taken over dim 1 and clamped to at least ``eps`` so an all-zero
        row yields a similarity of 0 instead of NaN.

        Raises:
            ValueError: If ``u`` and ``v`` have different shapes.
        """
        if u.shape != v.shape:
            raise ValueError(f"u and v must have the same shape, got {tuple(u.shape)} and {tuple(v.shape)}")
        u_normalized = u / torch.norm(u, dim=1, keepdim=True).clamp_min(eps)
        v_normalized = v / torch.norm(v, dim=1, keepdim=True).clamp_min(eps)
        # Cosine similarity is the dot product of the normalised rows.
        return torch.sum(u_normalized * v_normalized, dim=1)

    def mean_pooling(self, hidden_state, attention_mask):
        """Average ``hidden_state`` over dim 1, counting only positions where the mask is non-zero.

        The mask is cast to float32, so a float16 ``hidden_state`` is promoted and
        the result is float32. The denominator is clamped to avoid division by zero.
        """
        mask = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
        return torch.sum(hidden_state * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    def build_batch(self, input_prompt):
        """Tokenize ``input_prompt`` with the encoder tokenizer, truncated to one random length.

        A single length ``size`` is drawn uniformly from
        ``[1, configE.max_position_embeddings - 3]`` (``torch.randint``'s upper
        bound is exclusive), and every tokenized prompt keeps only its first
        ``size`` tokens. Truncation happens after the tokenizer adds its special
        tokens, so a trailing special token is cut off for prompts longer than
        ``size``. The sequences are then padded (left, per ``_init_tokenizers``).

        Returns:
            The ``tokenizerE.pad`` output with PyTorch tensors.
        """
        size = torch.randint(1, self.configE.max_position_embeddings - 2, (1,)).item()
        targets = []
        for prompt in input_prompt:
            encoded = self.tokenizerE(prompt)
            targets.append({'input_ids': encoded['input_ids'][:size],
                            'attention_mask': encoded['attention_mask'][:size],
                            })
        return self.tokenizerE.pad(targets, padding=True, return_tensors='pt')

    def _decode_and_map(self, prompts):
        """Tokenize *prompts* for the decoder, run it, and apply the adapter.

        Returns:
            ``(decoder inputs, decoder last hidden state, mapped hidden state)``.
        """
        inputs = self.tokenizerD(prompts, return_tensors='pt', padding=True).to(device)
        hidden_state_D = self.Decoder(**inputs).last_hidden_state
        return inputs, hidden_state_D, self.mapping(hidden_state_D)

    def forward(self, input_prompt, compute_loss=False):
        """Run the decoder, map its hidden states, and optionally compute the alignment loss.

        Args:
            input_prompt: Batch of raw text prompts (presumably a list of str).
            compute_loss: If False, the prompts go to the decoder unchanged and
                ``loss`` is None. If True, the prompts are first truncated through
                ``build_batch`` to fit the encoder, the decoder sees the decoded
                truncated text, and ``1 - mean cosine similarity`` is computed.

        Returns:
            dict with ``'loss'``, ``'last_hidden_state'`` (decoder output) and
            ``'last_hidden_state_mapped'`` (adapter output).
        """
        loss = None
        if not compute_loss:
            _, hidden_state_D, hidden_state_D_mapped = self._decode_and_map(input_prompt)
        else:
            # Slice prompts if needed to fit the encoder's max position embeddings
            # (hard constraint); the decoder then sees exactly the same text.
            targets = self.build_batch(input_prompt).to(device)
            input_prompt_sliced = self.tokenizerE.batch_decode(targets['input_ids'], skip_special_tokens=True)
            inputs, hidden_state_D, hidden_state_D_mapped = self._decode_and_map(input_prompt_sliced)
            hidden_state_E = self.Encoder(**targets).last_hidden_state
            proj_E = self.mean_pooling(hidden_state_E, targets['attention_mask'])
            proj_D = self.mean_pooling(hidden_state_D_mapped, inputs['attention_mask'])
            loss = 1 - torch.mean(self.cosine_sim(proj_E, proj_D))
            # Drop references first so empty_cache can return their blocks.
            # NOTE(review): emptying the CUDA cache on every call costs throughput.
            del inputs, targets, input_prompt_sliced, hidden_state_E, proj_E, proj_D
            torch.cuda.empty_cache()
        return {'loss': loss,
                'last_hidden_state': hidden_state_D,
                'last_hidden_state_mapped': hidden_state_D_mapped,
                }
if __name__ == '__main__':
    # Script entry point: start training only when executed directly, not on import.
    main()