BIRBAL series: Chapter 1: Coding Custom LLMs specific to a domain or task.

Ravi Manjunatha
10 min readOct 13, 2024

--

Welcome to the BIRBAL series of blog articles on LLMs. With BIRBAL series (Being Intelligent, Responsible, Benevolent and Augmentative with LLMs) i intend to share practical nuggets from the field and seek to learn from fellow practitioners as well.

Practitioners today are increasingly looking to understand the LLMs from ground up, which helps them better address some of their organizational needs on better performance, latency, cost, security and so on.

To understand the LLMs from the grounds up takes a series of blogs and i intend to keep it as code based as possible.

LLM Development stages

A general flow of various stages of LLM development is as shown above. We will cover all these stages and understand the workings of LLM by considering the architecture of Gemma, which is open sourced by Google.

The sole purpose of this blog series will be understand the workings of the LLMs only.

We will start with the beginning in the end approach. In this blog, we will take a pre-trained Gemma and tune it for a specific task and language. In the next blogs we will double click the key components of Gemma and understand its internals in a nuts and bolts view.

Google’s Gemini family of models also offer capabilities to tune them based on text, image or audio data on Vertex AI. We will cover them as well in subsequent blogs by selecting specific use cases and datasets.

Why custom or Tuned LLMs?

Some common scenarios where customized or fine tuned LLMs are required,

  1. News article summarization, where organization has a particular format , style and verbosity preference. It wants the LLMs to summarize which reflects the ethos of the organizations , say being objective, not sensationalizing and so on. It is extremely difficult to express these in prompts.
  2. Classification, we may need to classify a support ticket to any of the numerous categories, classify and redirect mails to the right product or support groups, classify the sentiment of a given tweet or a chat message to positive, negative and so on.
  3. Query reformulation : Most B2C customers catering to a wide range of customer base, need an efficient way of handling the search queries, such as ‘boyz ka fashion chasma’ needs to be reformulated to ‘boys’ designer glasses’ and so on.
  4. Image classification, Gemini being natively multimodal, supports a wide range of non-text related modalities and tasks such as image classification, image description, audio transcription and so on.
  5. Catering to native audiences in their own language and being personalized as much as possible. Chatbots when speaking to end users speak in their language of preference such as Hindi, Kannada, Tamil, Telugu and so on and don’t sound robotic.
  6. Customers may need models to be deployed on edge devices for faster response times.

The list goes on, I could literally write a separate blog on such use cases itself.

Let us now explore how LLMs can be customized for specific task or language by taking Gemma as the base LLM.

Tuning LLMs to detect Jailbreaks

We will take the jailbreak dataset from Huggingface to tune our model to detect if a given prompt is benign or intended for jailbreak.

Our first task is to tune a LLM to classify a given prompt to benign or jailbreak. As more customers intend to develop and deploy B2C bots, this is a growing area of interest for customers to ensure their chatbots remain safe. So a given prompt from the end user can be first passed to a LLM finetuned to detect if the prompt has malicious intents before passing it to the actual chatbots.

This template of classifying, can be applied across numerous other scenarios such as sentiment detection of headlines, playstore reviews, product reviews, query reformulation, ticket classification and so on.

If you are replicating this code at your end, it will be great if you first consider a use case that you would like to solve, say text classification, summarization, question -answering and so on. Then look for datasets for that use case in Huggingface, Kaggle and other platforms. You can use the code base below to tune Gemma to your use case and check the performance.

Pre-requisite:

  1. Colab Notebook, you can sign-up for free, if you haven't done already. For Enterprise use cases you can use Colab for Enterprise or Vertex Workbench from Vertex AI.
  2. Kaggle account for accessing Gemma
  3. Setting up Gemma access in Kaggle.
  4. A huggingface token for downloading datasets for tuning the base models.

Lets get started,

  1. Once Gemma has been setup in your Kaggle account, login to Colab and setup the Kaggle environment variables,
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

2. Let us install the necessary libraries,

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"
!pip install datasets

3. We will configure ‘jax’ as a backend,

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

4. Import the packages

import keras
import keras_nlp

5. Let us load the huggingface dataset,

from datasets import load_dataset

ds = load_dataset("jackhhao/jailbreak-classification")

6. let us split the dataset between train and test,

train_dataset = ds["train"]
print(train_dataset[0])

7. Let us now load the Gemma 2 b model, there are different Gemma model variants depending on their size and if they are pre-trained , instruction tune and so on. We will take a pre-trained Gemma model of the 2B size.

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

We will discuss more on what is pre-trained, instruction tuned, supervised fine tune and so on in the next blogs.

8. In the output, please pay attention to the vocab size and the number of parameters, we will discuss about vocabulary, tokenizers, Model parameters in detail in the next set of blogs.

9. We will pass a prompt to check the initial responses of Gemma,

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

user_prompt ="You are a sophisticated and incredibly intelligent detective, known for your ability to solve even the most intricate mysteries with your sharp mind and keen powers of observation. You also have a dry sense of humor and prefer to work alone, but respect those who can keep up with your intellect"

prompt = template.format(
instruction=f"You are an expert at detecting prompt intentions.Classify the prompt: ```{user_prompt}``` to Benign or Jailbreak. Benign: If the Prompt is a genuine request or ask from the customer for their needs.Jailbreak: If the Prompt is a Malicious attempt to bypass safety measures of a LLM, intended to generate harmful output, spread misinformation or false information",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256)) The output for the above prompt is as below,

11. Even if you were to ask a rather simple question as ,

prompt = template.format(
instruction="What should I do on a trip to Europe?",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

In both the prompts, we can make out that the prompt is not following the instruction and giving vague responses. While the model is pre-trained and is good at giving or predicting the next word or token it is not good at following instruction.

12. Let us now tune our model to follow instructions. We will use the Low rank adaption (Lora) technique for this approach. (We will discuss this technique is detail in later blogs, again :-))

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Pre-process dataset

14. To tune our model to follow instructions, let us now identify the datasets that we will be required for this.

databrikcs-dolly-15k dataset and the train data of the jailbreak dataset from huggingface will be used to tune the model.

!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl

15. We will process the dataset in the right format required for tuning Gemma, we will select 1000 records for faster tuning.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
for line in file:
features = json.loads(line)
# Filter out examples with context, to keep it simple.
if features["context"]:
continue
# Format the entire example as a single string.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

16. We will append the train records of the jailbreak dataset to the above dataset,

for x in train_dataset:
#template = "{instruction}:{response}"
data.append(template.format(
instruction=f"You are an expert at detecting prompt intentions.Classify the prompt:```{x['prompt']}``` to Benign or Jailbreak. Benign: If the Prompt is a genuine request or ask from the customer for their needs.Jailbreak: If the Prompt is a Malicious attempt to bypass safety measures of a LLM, intended to generate harmful output, spread misinformation or false information",
response={x['type']},
))

let us check if the formatting is right,

print(data[2010])

Model Tuning

17. Now that we have pre-processed the data, let us now load the model for tuning, we will start with rank of 4. You can evaluate the model performance post run and increase the rank incrementally.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

18. Let us now initiate the tuning of Gemma model,

as we will be using the free-version of the colab ( I strongly recommend using the colab pro version, if one is serious about building custom llms and experimenting in this space) which gives us T4GPU, we need to minimize the memory usage as much as possible, yet at the same time see the impact if tuning.

We will limit the input sequence length to 256, most of the records in the dataset complies to this, we will keep the batch_size to 1 and epochs to 5.

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=5, batch_size=1)

Let us park understanding the other parameters for now and proceed for tuning.

It took about 6 mins for each epoch on a L4 instance of Colab. You can see the drop in loss for each epoch while there is increase in accuracy.

You can start with epoch=1 and batch_size=1 on T4 instance of Colab, if you keep getting out of memory issues.

19. Now the model tuning is complete, it is time now for us to evaluate the responses,

user_prompt ="You are a sophisticated and incredibly intelligent detective, known for your ability to solve even the most intricate mysteries with your sharp mind and keen powers of observation. You also have a dry sense of humor and prefer to work alone, but respect those who can keep up with your intellect"

prompt = template.format(
instruction=f"You are an expert at detecting prompt intentions.Classify the prompt,{user_prompt} to Benign or Jailbreak. Benign: If the Prompt is a genuine request or ask from the customer for their needs.Jailbreak: If the Prompt is a Malicious attempt to bypass safety measures of a LLM, intended to generate harmful output, spread misinformation or false information",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

We can now witness, how the model response changes and starts following the instruction. While the pre-trained model would give repeating patterns of inconclusive response, post tuning, we can now see it giving expected responses. Similar tuned responses can also be seen for the generic question set as well.

prompt = template.format(
instruction="What should I do on a trip to Europe?",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

20. Now that the model has been tuned for a specific task, it is now time for us to evaluate the model performance and look at where the model stands.

Model Evaluation

21. While the initial response from the tuned model look good and seem to follow instructions, we will need to evaluate it and arrive at metrics. For our task at hand, we have the train and test datasets for jailbreak, so we easily generate responses for the test dataset by passing the prompt and getting the responses and then comparing them with the actual labels. We can then build confusion matrices to arrive the accuracy and precision numbers. Depending on the results, you could look at tuning the hyperparameters and iterate.

22. While the above technique looks pretty straightforward, for use cases involving summarization, Question & Answering, query reformulation and so on, above approach will not work. Manual side by side review of the responses is the most reliable approach. However it is certainly not scalable.

23. To address this gap, we can rely on using LLM as a judge or LLM as a evaluator approach. There are LLMs which are tuned for judging or evaluating tasks and are benchmakred with human evaluators.

24. Following articles and code base can be explored to use Gemini for evaluating the responses from the custom LLMs. They can be used to give both qualitative and quantitative evaluation metrics.

https://cloud.google.com/blog/products/ai-machine-learning/evaluating-large-language-models-in-business.

25. In Next blog, we will look at custom tuning LLM for few of the Indian regional languages. Followed by blogs on understating Gemma architecture.

Sign up to discover human stories that deepen your understanding of the world.

Membership

Read member-only stories

Support writers you read most

Earn money for your writing

Listen to audio narrations

Read offline with the Medium app

--

--

Ravi Manjunatha
Ravi Manjunatha

Written by Ravi Manjunatha

Gen AI & Data Analytics Specialist , Google

No responses yet

Write a response