Build a specialized Llama-2 model for product brand recommendation
While LLaMA-2 is a better model than LLaMA in terms of quality, has double the context length, and is viable for commercial use, it typically won’t work well out of the box for your specific ML task, since it was trained on general text data from the web during the pre-training stage.
In this blog I am going to experiment with fine-tuning the Llama 2 model on private data, such as product names and product brands from a company database or a structured data file. The aim is to make the model recognize your company’s product brand names in its responses to user inputs. See the link here for the full source training notebook:
A few examples:
User Input-1: Organic Pumpkin Flax Granola
LLM-generated brand name: nobrand
User Input-2: Free & Clear Stage 4 Overnight Diapers
LLM-generated brand name: amazon
User Input-3: Ziti Bolognese Pasta Bowl
LLM-generated brand name: amazon
For model training, I am going to use a sharded LLM model. Sharded is a technique that helps you save over 60% of memory and train models twice as large. In the LLM domain the dominant model has been the Transformer, which requires massive amounts of GPU memory; realistically speaking, these models just don’t fit in a single machine’s memory. As a result, a technique called Sharded was introduced in Microsoft’s ZeRO paper, which develops a method to bring us closer to training 1 trillion parameters.
Sharded removes these redundancies. It works the same way as DDP, except that all the overhead (gradients, optimizer state, etc.) is calculated only for a portion of the full parameters, so we remove the redundancy of storing the same gradients and optimizer states on every GPU.
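To make the idea concrete, here is a minimal sketch of ZeRO-style sharding using PyTorch’s built-in FullyShardedDataParallel. This is purely for illustration (and assumes a process group initialized via e.g. torchrun); the training in this blog instead relies on a pre-sharded checkpoint plus 4-bit quantization rather than multi-GPU sharding:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

## Assumes torch.distributed has already been initialized (e.g. via torchrun).
## FSDP shards parameters, gradients, and optimizer state across ranks,
## so each GPU stores only its slice instead of a full replica as in DDP.
model = nn.Transformer(d_model=512, nhead=8)
sharded_model = FSDP(model.cuda())

optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=1e-4)
```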
Prepare Synthetic Dataset
Before we start the training, we need to create a product dataset with a product brand name column. In this experiment, I will use the instacart-market-basket products as the base input dataset, then add a ‘brand’ column with randomly generated brand names such as amazon, walmart, and nobrand.
from sklearn.model_selection import train_test_split
import pandas as pd
## separator between product name and brand in the training text
ESCTAG=" ->> "
## load product data
df_product = pd.read_csv("dataset/instacart-market-basket/products.csv")
df_dept = pd.read_csv('dataset/instacart-market-basket/departments.csv')
## join department names onto products
df_joined = pd.merge(df_product, df_dept, on = ['department_id'])
Add a random brand name
import random
product_brand_names=['nobrand','walmart','amazon']
def random_brand():
    ## pick one of the brand names uniformly at random
    return random.choice(product_brand_names)

## take the first ~10k products and give each row its own random brand
df1 = df_joined.loc[0:10000].copy()
df1['brand'] = [random_brand() for _ in range(len(df1))]
df1.head(5)
Add a training text column
train_df, test_df = train_test_split(df1, test_size=0.2, random_state=42)
print(train_df['product_id'].count(),"/",df1['product_id'].count())
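The ‘text’ column named by the heading above, and the train_dataset_dict consumed by the trainer later, are not shown in the original notebook. Here is a minimal sketch, assuming a simple template that joins the product name and brand with the ESCTAG separator defined earlier:

```python
from datasets import Dataset, DatasetDict

## Hypothetical training text format: "<product_name> ->> <brand>"
train_df['text'] = train_df['product_name'] + ESCTAG + train_df['brand']
test_df['text'] = test_df['product_name'] + ESCTAG + test_df['brand']

train_dataset_dict = DatasetDict({
    'train': Dataset.from_pandas(train_df),
    'test': Dataset.from_pandas(test_df),
})
```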
Load Base Model
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_name = "TinyPixel/Llama-2-7B-bf16-sharded"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
## model
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
trust_remote_code=True
)
model.config.use_cache = False
## tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
Before Training
Input: Free & Clear Stage 4 Overnight Diapers
Result: Free & Clear Stage 4 Overnight Diapers are designed to provide moderate absorption, overnight protection and odor control. sierp 28, 2021 · The best part is, you can now save $7 off a 120-count pack of Tampax Pearl Tampons, or $15 off a 120-count pack of Tampax Radiant tampons! The deal is valid at Tampax.com and Target.com. Tampax Pearl tampons cost $9.99 for each two-box deal at Target, while the Tampax Radiant tampons cost $7.99 per box. You can also buy a pack of 20 Tampax Pearl tampons for $3.69 each. You’ll save 50% off the $5.99 regular price, for a total price of
Input: Organic Pumpkin Flax Granola 9 oz.
Result: nobody knows what that is. It’s a whole-grain breakfast, but it’s also a dessert. A good example would be the “Fat Free” label on a container of sour cream. The ingredients are the same. 2. It may be difficult to determine the difference between whole wheat breads and regular wheat breads, but there is a difference. Wholewheat flour. The only difference is that the bread is made with whole wheat. If you’ve never been a fan of wheat breads and don’t know how to get started, here are 15 easy ideas to get the ball rolling. The whole grain, however, has a higher fiber and nutrient content than the refined wheat. This whole-grain bread is
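For reference, here is a minimal sketch of how such a baseline completion can be generated with the un-tuned base model (the max_new_tokens value and sampling settings are assumptions, not the notebook’s exact parameters):

```python
prompt = "Free & Clear Stage 4 Overnight Diapers"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

## Sample a free-form continuation from the base model
outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True, top_p=0.9)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```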
Run Training
Create the LoRA configuration, then use the SFTTrainer from the TRL library, which provides a wrapper around the transformers Trainer to easily fine-tune models on instruction-based datasets using PEFT adapters.
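The peft_config and training_arguments objects referenced below are not shown in the original; here is a minimal sketch with typical QLoRA-style values (all hyperparameters are assumptions, chosen to be consistent with the run output further down):

```python
from peft import LoraConfig
from transformers import TrainingArguments

## LoRA adapter configuration (rank, scaling, and dropout are assumed values)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

## Training hyperparameters (assumed; max_steps=120 matches the run below)
training_arguments = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_steps=120,
    fp16=True,
    logging_steps=20,
)
```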
from trl import SFTTrainer
max_seq_length = 512
trainer = SFTTrainer(
model=model,
train_dataset=train_dataset_dict['train'],
# train_dataset=data['train'],
peft_config=peft_config,
dataset_text_field="text",
# dataset_text_field="prediction",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_arguments,
)
%%time
trainer.train()
```
CPU times: user 1min 50s, sys: 8.26 s, total: 1min 58s
Wall time: 1min 59s
TrainOutput(global_step=120, training_loss=2.303033309181531, metrics={'train_runtime': 116.8025, 'train_samples_per_second': 16.438, 'train_steps_per_second': 1.027, 'total_flos': 769979938897920.0, 'train_loss': 2.303033309181531, 'epoch': 0.02})
```
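The original does not show how the fine-tuned adapter is persisted; a minimal sketch (the output path is an assumption):

```python
## Save the trained LoRA adapter weights (not the full base model)
trainer.save_model("llama2-brand-adapter")
tokenizer.save_pretrained("llama2-brand-adapter")
```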
After Training
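The predicted brands in the table below can be produced by prompting the fine-tuned model with the product name plus the ESCTAG separator and reading the completion as the brand. A minimal sketch, assuming the hypothetical text format introduced earlier:

```python
def predict_brand(product_name):
    ## The prompt ends with the separator so the model completes with a brand
    prompt = product_name + ESCTAG
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=5)
    completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
    ## Take the first token after the prompt as the predicted brand
    return completion[len(prompt):].strip().split()[0]

print(predict_brand("Ziti Bolognese Pasta Bowl"))
```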
| product_id | product_name | Product Brand | Predicted Brand |
| --- | --- | --- | --- |
| 33626 | Free & Clear Stage 4 Overnight Diapers | amazon | amazon |
| 18191 | Ziti Bolognese Pasta Bowl | nobrand | amazon |
| 38872 | Organic Pumpkin Flax Granola | amazon | nobrand |
| 48183 | Bread Rolls | amazon | amazon |
| 22196 | Lip Balm, Organic, Lemon Lime | amazon | amazon |
Note: in the training results, the predicted brands for records 2 and 3 do not match the original products’ brands. Further improvement can be applied via retrieval augmentation.
Acknowledgments:
Sharded is now available in PyTorch Lightning thanks to the efforts of the Facebook AI FairScale team; Sharded was inspired by Microsoft’s ZeRO paper.