Training Multimodal GPT using PyTorch Lightning — Part 2

Sijuade Oguntayo
Feb 5, 2024

Introduction

This article explores multi-modal training and fine-tuning for visual instruction following. The project is based on the work done in LLaVA 1.5, with some key differences. An introduction to the project is discussed extensively in a previous article.

In this article, we delve deeper into the technical aspects and implementation details.

Technical Implementation

The goal is to train a model that accepts both image and text inputs and produces text output. Natural applications include a multi-modal chatbot or multi-modal retrieval.

This is achieved with the help of the following three components:

  • A pre-trained 2.7B-parameter language model, Phi-2, for natural language understanding and text generation.
  • OpenAI’s CLIP for visual-text embeddings, trained to maximize image/text similarity using contrastive learning.
  • A custom projection network for visual-text alignment.

The training occurs in two stages. In the pre-training stage, we keep the weights of CLIP and Phi-2 frozen and only train the projection layers.
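The freezing itself is not shown in the later code listings (only the optimizer's filter for trainable parameters is); a minimal sketch of what it amounts to, assuming the text_model attribute used in the snippets below:

# freeze Phi-2 so that only the projection network receives gradient updates in stage 1
for param in self.text_model.parameters():
    param.requires_grad = False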

The goal of the projection network is to transform the image embeddings obtained from CLIP so that they align more closely with the text inputs expected by the language model. CLIP outputs 49 patch embeddings per image; these are passed through the projection network, producing embeddings closer to what the language model expects. The projected image embeddings are then concatenated with the text embeddings, with a short separator token sequence marking the boundary between the image and text features. The combined sequence is finally sent to the language model, whose output is compared with the original image caption.

Once this is sufficiently trained, we are ready to move on to stage 2.

In stage 2, we fine-tune both the projection network and Phi-2 for visual instruction following. In the first stage, we trained using the image's caption; in stage 2, we train the model using question-and-answer pairs about the image. Phi-2 is fine-tuned using Quantized Low-Rank Adaptation (QLoRA), which allows for quicker and more memory-efficient fine-tuning of deep neural networks.

Projection

The projection network aims to bridge the gap between CLIP and Phi-2. It consists of:

  • Mixture of Experts module — selects and blends information from multiple ‘expert’ networks based on the input data (an illustrative sketch of such a module follows this list).
  • Normalization & linear transformation — normalization is applied to stabilize the learning process and ensure consistent data distributions, while the linear layer adjusts the embedding dimension of the CLIP output to match the input embedding dimension expected by Phi-2.
  • Projection layers — a stack of sub-layers, each a linear transformation, a GELU activation, and a second linear transformation, allowing multiple stages of refinement.
  • Residual connections — skip connections are implemented by adding the input of each projection sub-layer to its output.
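The MixtureOfExperts module referenced by the Projections class below is not listed in the article. The following is a minimal sketch of what such a module could look like, a softmax-gated blend of small expert MLPs; it is an illustrative assumption, not the exact implementation used in the project:

import torch
import torch.nn as nn

class MixtureOfExperts(nn.Module):
    """Illustrative sketch: a learned gate blends the outputs of several expert MLPs."""

    def __init__(self, embed_dim, num_experts=4):
        super().__init__()
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(embed_dim, embed_dim),
                    nn.GELU(),
                    nn.Linear(embed_dim, embed_dim),
                )
                for _ in range(num_experts)
            ]
        )
        self.gate = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        # x: [batch, num_patches, embed_dim]
        gate_weights = torch.softmax(self.gate(x), dim=-1)                            # [batch, patches, experts]
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)  # [batch, patches, dim, experts]
        return torch.einsum("bpe,bpde->bpd", gate_weights, expert_outputs)            # weighted blend per patch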
class Projections(nn.Module):
    def __init__(
        self,
        clip_embed,
        phi_embed,
        num_experts=4,
        num_projection_layers=6,
    ):
        super().__init__()

        self.MixtureOfExperts = MixtureOfExperts(clip_embed, num_experts)
        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        x = self.MixtureOfExperts(x)
        x = self.output(x)   # project from the CLIP dimension to the Phi-2 dimension
        x = self.norm(x)     # assign the normalized output back to x
        for layer in self.projection_layers:
            residual = x
            x = layer(x) + residual

        return x
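A quick shape check, assuming the default embedding sizes used later in the article (clip_embed=768, phi_embed=2560) and 49 patch embeddings per image:

projection = Projections(clip_embed=768, phi_embed=2560)
dummy_clip_patches = torch.randn(2, 49, 768)   # a dummy batch: 2 images, 49 patch embeddings each
projected = projection(dummy_clip_patches)
print(projected.shape)                         # expected: torch.Size([2, 49, 2560])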

Training

Stage 1

class CLIPPhi2Model(LightningModule):
    def __init__(self, phi_model_name, clip_embed=768, phi_embed=2560):
        super().__init__()
        self.EOS_TOKEN_ID = 50256
        self.IMAGE_SEPARATOR_TOKENS = [685, 36259, 14041, 60, 220]
        self.text_model = AutoModelForCausalLM.from_pretrained(
            phi_model_name,
            torch_dtype=torch.float16,
            device_map="cuda",
            trust_remote_code=True,
        )
        self.projection = Projections(clip_embed, phi_embed)
        self.tokenizer = tokenizer

We initialize the main class with the model name, Phi-2 in this case. We also provide the embedding size output by CLIP, as well as the embedding size expected by the language model.

Additionally, we reference the end-of-sentence token ID and the tokens used to mark the end of image features as defined during the data preparation.

The Phi-2 language model, the tokenizer, and the projection network are initialized.

One thing to note is that the CLIP embeddings were extracted in advance, to free the GPU memory that would otherwise have been occupied by the CLIP model. In practice this made little difference due to the small size of the CLIP model used.
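A sketch of how those patch embeddings could be pre-extracted and cached; the checkpoint name and the choice to drop the CLS token are assumptions, chosen to be consistent with the 49 patch embeddings and clip_embed=768 used above:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

clip_name = "openai/clip-vit-base-patch32"   # assumed: a ViT-B/32 encoder yields 7x7 = 49 patches at 224x224
processor = CLIPImageProcessor.from_pretrained(clip_name)
vision_model = CLIPVisionModel.from_pretrained(clip_name).eval()

@torch.no_grad()
def extract_patch_embeddings(image: Image.Image) -> torch.Tensor:
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    hidden = vision_model(pixel_values).last_hidden_state   # [1, 1 + 49, 768], including the CLS token
    return hidden[:, 1:, :]                                  # drop CLS: [1, 49, 768]

# the resulting tensors can be saved to disk (e.g. with torch.save) and loaded by the DataLoader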

def forward(self, images, input_ids):
    input_embeddings = self.text_model.model.embed_tokens(input_ids)
    projected_image_embeds = self.projection(images).to(torch.float16)
    combined_embeddings = torch.cat((projected_image_embeds, input_embeddings), dim=1)
    outputs = self.text_model(inputs_embeds=combined_embeddings).logits
    return outputs

The forward function takes as input the visual embeddings from CLIP and the input_ids for the caption. We obtain the text embeddings by passing the input_ids through the embedding layer of the language model. The image embeddings are sent to the projection network, whose output is concatenated with the text embeddings and sent to the language model. Finally, we collect the model prediction in the form of logits.

def training_step(self, batch, batch_idx):
    images, input_ids, target_ids = batch
    outputs = self.forward(images, input_ids)

    num_patches = images.shape[1]
    len_sep_tokens = len(self.IMAGE_SEPARATOR_TOKENS)

    # keep only the logits over the caption, skipping the image patch and separator positions
    text_token_logits = outputs[:, num_patches + len_sep_tokens:, :]

    # targets: the caption shifted one step left, followed by the trailing target tokens
    target_sequence = torch.cat([input_ids[:, len_sep_tokens + 1:], target_ids], dim=1)

    text_token_logits_flat = text_token_logits.reshape(-1, text_token_logits.size(-1))
    target_sequence_flat = target_sequence.reshape(-1)

    loss = F.cross_entropy(text_token_logits_flat, target_sequence_flat, ignore_index=self.EOS_TOKEN_ID)

    self.print_predictions(batch, self.global_step)
    self.log("loss", loss.item(), prog_bar=True, on_step=True, logger=True)
    return loss

We extract the logits for the caption and compare them with the target sequence using cross-entropy loss.
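To make the slicing concrete, here is the index bookkeeping with the default sizes; this is purely illustrative and restates what the code above computes:

num_patches = 49        # CLIP patch embeddings prepended in forward
len_sep_tokens = 5      # IMAGE_SEPARATOR_TOKENS defined above

# combined sequence seen by Phi-2: [49 image patches][5 separator tokens][caption tokens]
# the logit at position p predicts the token at position p + 1, so outputs[:, 54:, :]
# keeps the logits from the first caption token onward, while input_ids[:, 6:] (the caption
# shifted one step left) followed by target_ids lines up as the matching target sequence
first_kept_logit = num_patches + len_sep_tokens
print(first_kept_logit)   # 54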

def configure_optimizers(self):
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-4)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-4,
        pct_start=0.01,
        total_steps=self.trainer.max_steps,
        anneal_strategy='cos',
        div_factor=100,
        final_div_factor=1000,
    )

    return {'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}

This method initializes Adam, a popular optimizer known for its efficient computation and modest memory requirements. We filter the model parameters so that only the trainable ones are optimized, which ensures only the projection weights are updated during this stage.

OneCycleLR, a learning-rate scheduler, is also configured, with a maximum learning rate of 1e-4 and a cosine annealing strategy.

Training Logs

Training loss for Stage 1 (detailed logs- wandb logs)

Stage 2

self.projection = load_projection_model("MModalGPT-step=14400-loss=0.38.ckpt", clip_embed, phi_embed)
self.tokenizer = tokenizer

self.bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
self.text_model = AutoModelForCausalLM.from_pretrained(
    phi_model_name,
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=self.bnb_config,
    trust_remote_code=True,
)
self.text_model.config.use_cache = False
self.peft_config = LoraConfig(
    lora_alpha=16, lora_dropout=0.1, r=64,
    bias="none", task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "fc1",
        "fc2",
    ],
)

self.peft_model = peft.get_peft_model(self.text_model, self.peft_config)

We load the projection weights trained in stage 1. In stage 2, we train the language model as well as the projection network.

The language model is fine-tuned using Quantized Low-Rank Adaptation (QLoRA) for fast and efficient training.

The configuration indicates we are loading the model in 4-bit precision. We use the peft (parameter-efficient fine-tuning) library to set up LoRA (Low-Rank Adaptation), with a rank of 64 and an alpha of 16. Target modules refer to the specific layers in the language model we target for adaptation. Finally, we apply the LoRA configuration to the language model.
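A quick sanity check at this point (not part of the original listing) is peft's built-in parameter report, which confirms that only the adapter weights remain trainable:

# optional: report how many parameters LoRA actually leaves trainable
self.peft_model.print_trainable_parameters()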

def training_step(self, batch, batch_idx):
    images, combined_sequences, markers = batch

    input_sequences = [seq[:-1] for seq in combined_sequences]
    input_sequences_batch = torch.stack(input_sequences)
    outputs = self.forward(images, input_sequences_batch)

    # collect relevant logits and targets
    relevant_logits = []
    relevant_targets = []

    for i, (combined_seq, marker) in enumerate(zip(combined_sequences, markers)):
        answer_start = marker[1] + 1  # after QA separator
        current_target_pos = marker[2]

        answer_logits = outputs[i, images.shape[1] + (answer_start - 1):images.shape[1] + current_target_pos, :]

        answer_targets = combined_seq[answer_start:current_target_pos + 1]

        relevant_logits.append(answer_logits)
        relevant_targets.append(torch.tensor(answer_targets))

    all_logits = torch.cat(relevant_logits, dim=0)
    all_targets = torch.cat(relevant_targets, dim=0)

    loss = F.cross_entropy(all_logits.view(-1, all_logits.size(-1)), all_targets.view(-1), ignore_index=tokenizer.pad_token_id)

    self.print_predictions(batch, self.global_step)
    self.log("loss", loss, prog_bar=True, on_step=True, logger=True)
    return loss

In stage 1, we trained the model to output a caption given an image. In stage 2, the goal is instruction following, so we employ a different strategy: we pass in question tokens as well as answer tokens.

The training step for stage 2 reflects this difference. The model now has to output a text sequence given image features and a question, so we extract the logits only for the answer span and compare them with the answer tokens using cross-entropy loss.
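The exact packing of each combined_sequence is defined during data preparation; the layout below is an interpretation of the indexing in the training step above, shown for illustration rather than as a specification:

# one combined_sequence (token ids), conceptually:
#   [ image separator tokens ][ question tokens ][ QA separator ][ answer tokens ][ EOS / padding ]
# marker[1]: index of the QA separator, so marker[1] + 1 is the first answer token
# marker[2]: index of the last answer token kept in the targets
# in the logits, everything is offset by images.shape[1] (the image patches prepended in forward)
# and shifted one position left, because the logit at position p predicts the token at p + 1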

def print_predictions(self, batch, global_step):
    if global_step % 100 == 0:
        MAX_PREDICTION_STEPS = 40
        images, combined_sequences, markers = batch

        example_image = images[0]
        example_sequence = combined_sequences[0]
        marker = markers[0]

        question_start = len(self.IMAGE_SEPARATOR_TOKENS)

        question = self.tokenizer.decode(example_sequence[question_start:marker[1]], skip_special_tokens=True)
        actual_answer = self.tokenizer.decode(example_sequence[marker[1] + 1:marker[2] + 1], skip_special_tokens=True)

        input_sequence = example_sequence[:marker[1] + 1].tolist()

        # greedy decoding: repeatedly append the most likely next token
        predicted_tokens = []
        for _ in range(MAX_PREDICTION_STEPS):
            input_tensor = torch.tensor([input_sequence], dtype=torch.long).to(self.device)
            output_logits = self.forward(example_image.unsqueeze(0), input_tensor)
            next_token_id = output_logits[0, -1, :].argmax().item()
            input_sequence.append(next_token_id)
            predicted_tokens.append(next_token_id)
            if next_token_id == self.tokenizer.eos_token_id:
                break

        predicted_answer = self.tokenizer.decode(predicted_tokens, skip_special_tokens=True)

        print(f"Step {global_step}:")
        print(f"Question: {question}")
        print(f"Actual Answer: {actual_answer}")
        print(f"Predicted Answer: {predicted_answer}")
        print("------------")

We print the question, the reference answer, and the predicted answer to the console every 100 steps. This provides a rough indication of the model's performance as training progresses.

def on_save_checkpoint(self, checkpoint):
    # save the LoRA adapters separately via peft
    path_location = f"peft-checkpoint/{self.global_step}"
    path = pathlib.Path(path_location)
    path.mkdir(parents=True, exist_ok=True)
    self.peft_model.save_pretrained(path)

    # strip everything except the projection weights from the Lightning checkpoint
    keys = checkpoint['state_dict'].keys()
    keys = [k for k in keys if 'projection' not in k]

    for k in keys:
        del checkpoint['state_dict'][k]

PyTorch Lightning will attempt to save all of the model weights. This, however, is not necessary, as only the projection weights and the language model adapters need to be saved. The language model's base weights can be downloaded directly from Hugging Face as needed.
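For inference, the two saved pieces can be put back together. A minimal sketch, assuming the same load_projection_model helper used in stage 2 and the adapter directory written by on_save_checkpoint; the checkpoint paths and the "microsoft/phi-2" identifier are placeholders/assumptions:

import torch
import peft
from transformers import AutoModelForCausalLM

# placeholder paths: substitute the actual Lightning checkpoint and adapter directory
projection = load_projection_model("path/to/stage2-checkpoint.ckpt", 768, 2560)  # clip_embed, phi_embed

base_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype=torch.float16, trust_remote_code=True
)
# attach the LoRA adapters saved by on_save_checkpoint
text_model = peft.PeftModel.from_pretrained(base_model, "peft-checkpoint/<step>")
text_model = text_model.merge_and_unload()   # optionally fold the adapters into the base weights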

Epoch 0:  39%|███▉      | 9700/24563 [4:31:11<6:55:32,  0.60it/s, v_num=8062, loss=1.250]
Step 9700:
Question: What is the person doing in the image?
Actual Answer: The person in the image is riding a skateboard down the side of a ramp
Predicted Answer: The person in the image is skateboarding down a ramp, performing a trick by jumping off the ramp and landing on the skateboard ramp.
------------

Epoch 0:  39%|███▉      | 9701/24563 [4:31:17<6:55:37,  0.60it/s, v_num=8062, loss=1.250]

Training Logs

Training loss for Stage 2 (detailed logs)

Results

Example 1
Example 2

The above images show the model response when prompted with different text and images.

Final Thoughts

The projection network serves as a bridge between the CLIP (visual) embeddings and the Phi-2 language model (text). It effectively transforms the visual data into a format compatible with the language model.

Stable Diffusion is another model that is capable of combining different modalities. However, in Stable Diffusion, the text features are introduced directly through the cross-attention mechanism, which has the inherent ability to handle different data types or modalities.

The use of the bridge between the modalities was probably motivated by the need to leverage a pre-trained language model even though it wasn’t originally designed to be multi-modal. This is an interesting example of resource constraints leading to innovative thinking in open-source projects.

Code Repository

More information about the implementation as well as the training code may be found here.

Deployment

The trained model is deployed to HuggingFace Spaces as a chatbot using Gradio. The app accepts text, images, and audio and returns a text output. Click here to try it out.
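A minimal sketch of the kind of Gradio wrapper involved; respond is a hypothetical function standing in for the actual pipeline (audio transcription, CLIP feature extraction, projection, and Phi-2 generation):

import gradio as gr

def respond(image, audio, text):
    # hypothetical glue code: transcribe the audio, extract and project CLIP features,
    # then generate an answer with the fine-tuned Phi-2
    return "model answer goes here"

demo = gr.Interface(
    fn=respond,
    inputs=[gr.Image(type="pil"), gr.Audio(type="filepath"), gr.Textbox()],
    outputs=gr.Textbox(),
)
demo.launch()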

References

Liu, Haotian; Li, Chunyuan; Wu, Qingyang; Lee, Yong Jae (2023). Visual Instruction Tuning.

Liu, Haotian; Li, Chunyuan; Li, Yuheng; Lee, Yong Jae (2023). Improved Baselines with Visual Instruction Tuning.

Radford, A.; Kim, J.; Hallacy, C.; Ramesh, A.; Goh, G.; Agarwal, S.; Sastry, G.; Askell, A.; Mishkin, P.; Clark, J.; Krueger, G.; Sutskever, I. (2021). Learning Transferable Visual Models From Natural Language Supervision.

Shravan, Rohan (2023). The School of A.I.

Dettmers, Tim; Pagnoni, Artidoro; Holtzman, Ari; Zettlemoyer, Luke (2023). QLoRA: Efficient Finetuning of Quantized LLMs.

Hu, Edward J.; Shen, Yelong; Wallis, Phillip; Allen-Zhu, Zeyuan; Li, Yuanzhi; Wang, Shean; Wang, Lu; Chen, Weizhu (2021). LoRA: Low-Rank Adaptation of Large Language Models.
