Build a specialized Llama-2 model for product brand recommendation

Minyang Chen
Aug 13, 2023


Photo by Mohamed Nohassi on Unsplash

Llama 2 improves on LLaMA in quality, doubles the context length (4,096 tokens versus 2,048), and is licensed for commercial use. Even so, it typically won't work well out of the box for your specific ML task, because during pre-training it was trained on general text data from the web.

In this post I experiment with fine-tuning the Llama 2 model on private data, such as product names and product brands taken from a company database or a structured data file. The aim is to make the model answer a user's product name with your company's brand name for that product. See the link here for the full training notebook:

A few examples:

User Input-1: Organic Pumpkin Flax Granola
LLM-generated brand name: nobrand
User Input-2: Free & Clear Stage 4 Overnight Diapers
LLM-generated brand name: amazon
User Input-3: Ziti Bolognese Pasta Bowl
LLM-generated brand name: amazon

For model training I use a sharded Llama 2 checkpoint. Sharding comes from distributed training, where it is reported to save over 60% of memory and to let you train models twice as large on the same hardware. LLMs are built on the Transformer, which requires massive amounts of GPU memory; realistically, the largest models just don't fit in a single machine's memory. To address this, Microsoft's ZeRO paper introduced the sharding technique as a way to move toward training models with a trillion parameters. Note that the checkpoint used below (TinyPixel/Llama-2-7B-bf16-sharded) is sharded in a simpler sense: its weights are split into several smaller files so that they can be loaded with limited CPU memory.

In standard DistributedDataParallel (DDP) training, every GPU keeps a full copy of the parameters, gradients and optimizer state. Sharded training removes this redundancy. It works the same way as DDP, except that each GPU keeps the gradients and optimizer state for only its own portion of the parameters, so the same gradients and optimizer states are no longer stored on every GPU.
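To make the saving concrete, here is a back-of-the-envelope sketch using the memory accounting from the ZeRO paper: mixed-precision Adam needs 2 bytes per parameter for fp16 weights, 2 for fp16 gradients and 12 for the fp32 optimizer state (master weights, momentum, variance). It compares plain DDP (stage 0) with the three ZeRO stages; the DDP-like behaviour described above corresponds roughly to stage 2. Activations, buffers and fragmentation are ignored, so treat the numbers as lower bounds, not measurements.

```python
def per_gpu_bytes(num_params: float, num_gpus: int, stage: int) -> float:
    """Model-state memory per GPU under ZeRO-style sharding.

    Assumes mixed-precision Adam as in the ZeRO paper: 2 bytes/param for
    fp16 weights, 2 for fp16 gradients, 12 for fp32 optimizer state.
    Activations and temporary buffers are ignored.
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:  # plain DDP: every GPU holds everything
        return weights + grads + optim
    if stage == 1:  # shard the optimizer state only
        return weights + grads + optim / num_gpus
    if stage == 2:  # shard optimizer state and gradients
        return weights + (grads + optim) / num_gpus
    if stage == 3:  # shard the weights as well
        return (weights + grads + optim) / num_gpus
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    # The 7.5B-parameter, 64-GPU example used in the ZeRO paper
    psi, gpus = 7.5e9, 64
    for stage in range(4):
        gb = per_gpu_bytes(psi, gpus, stage) / 1e9
        print(f"stage {stage}: {gb:.1f} GB per GPU")
```

The redundancy is visible in the stage 0 line: all 120 GB are replicated on every GPU, while each sharding stage divides another component by the number of GPUs.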

Prepare Synthetic Dataset

Before training, we need a product dataset that includes a brand name for each product. In this experiment I use the products table from the Instacart Market Basket dataset as the base input, then add a ‘brand’ column filled with randomly assigned brand names: amazon, walmart or nobrand.


```python
from sklearn.model_selection import train_test_split
import pandas as pd

ESCTAG = " ->> "

## Load the Instacart products and departments tables
df_product = pd.read_csv("dataset/instacart-market-basket/products.csv")
df_dept = pd.read_csv("dataset/instacart-market-basket/departments.csv")

## Join the department name onto each product
df_joined = pd.merge(df_product, df_dept, on=["department_id"])
```

Add random brand names

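```python
import random

# Candidate brands used as synthetic labels
product_brand_names = ['nobrand', 'walmart', 'amazon']

def random_brand():
    """Return one brand name chosen uniformly at random."""
    return random.choice(product_brand_names)

# Rows with labels 0..10000 (.loc slicing is inclusive); copy so that
# adding a column does not touch df_joined
df1 = df_joined.loc[0:10000].copy()
# Draw an independent random brand for every row
df1['brand'] = [random_brand() for _ in range(len(df1))]
df1.head(5)
```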

Add a training text column

```python
# Hold out 20% of the branded products for testing
train_df, test_df = train_test_split(df1, test_size=0.2, random_state=42)
print(train_df['product_id'].count(), "/", df1['product_id'].count())
```
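The code that builds the text column itself is not shown in this post. A minimal sketch, assuming each training example is simply the product name, then the `ESCTAG` separator (" ->> ", defined earlier), then the brand; the helper name `add_text_column` is hypothetical and the author's actual prompt template may differ.

```python
import pandas as pd

ESCTAG = " ->> "  # separator between input and target, as defined earlier


def add_text_column(df: pd.DataFrame, sep: str = ESCTAG) -> pd.DataFrame:
    """Return a copy of df with a 'text' column '<product_name><sep><brand>'.

    The template is an assumption for illustration; SFTTrainer only needs
    one string column per example (dataset_text_field="text").
    """
    out = df.copy()
    out["text"] = out["product_name"].astype(str) + sep + out["brand"].astype(str)
    return out


if __name__ == "__main__":
    sample = pd.DataFrame({
        "product_name": ["Bread Rolls", "Lip Balm, Organic, Lemon Lime"],
        "brand": ["amazon", "amazon"],
    })
    print(add_text_column(sample)["text"].tolist())
    # ['Bread Rolls ->> amazon', 'Lip Balm, Organic, Lemon Lime ->> amazon']
```

Applied to `train_df`, the result could be wrapped for the trainer with `datasets.DatasetDict({"train": datasets.Dataset.from_pandas(train_df, preserve_index=False)})`; the `train_dataset_dict` used in the training step is presumably an object of this kind, although its construction isn't shown.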

Load Base Model

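```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Llama-2-7B in bf16, split into small checkpoint files (per the repo name)
model_name = "TinyPixel/Llama-2-7B-bf16-sharded"

# Quantize the weights to 4-bit NF4 on load; compute in float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

## model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
)
# The key/value cache only helps generation; turn it off for training
model.config.use_cache = False

## tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Llama 2 ships without a padding token, so reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token
```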
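The 4-bit settings above are what let a 7B model fit on a single modest GPU. A rough sketch of the weight memory involved: bf16 stores 16 bits per weight, while NF4 stores 4 bits plus per-block quantization constants. It assumes the bitsandbytes default block size of 64 with one fp32 constant per block (0.5 extra bits per weight); with `bnb_4bit_use_double_quant=True`, which the config above does not set, the QLoRA paper puts that overhead at about 0.127 bits. It also treats every parameter as quantized, whereas only the linear layers are, so embeddings, norms and the output head keep higher precision and real usage is somewhat higher.

```python
def weight_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


def nf4_bits_per_param(block_size: int = 64, double_quant: bool = False) -> float:
    """4-bit codes plus the amortized cost of the quantization constants.

    Without double quantization: one fp32 absmax constant per block.
    With it (QLoRA paper): one 8-bit constant per block, plus one fp32
    constant per 256 of those 8-bit constants.
    """
    if double_quant:
        overhead = 8 / block_size + 32 / (block_size * 256)
    else:
        overhead = 32 / block_size
    return 4 + overhead


if __name__ == "__main__":
    n = 7e9  # nominal "7B" parameter count, an approximation
    print(f"bf16:               {weight_gb(n, 16):.2f} GB")
    print(f"nf4:                {weight_gb(n, nf4_bits_per_param()):.2f} GB")
    dq = nf4_bits_per_param(double_quant=True)
    print(f"nf4 + double quant: {weight_gb(n, dq):.2f} GB")
```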

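Before Training

Before fine-tuning, the base model does not answer with a brand; it simply continues the input with unrelated web-style text: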

Input: Free & Clear Stage 4 Overnight Diapers

Result: Free & Clear Stage 4 Overnight Diapers are designed to provide moderate absorption, overnight protection and odor control. sierp 28, 2021 · The best part is, you can now save $7 off a 120-count pack of Tampax Pearl Tampons, or $15 off a 120-count pack of Tampax Radiant tampons! The deal is valid at Tampax.com and Target.com. Tampax Pearl tampons cost $9.99 for each two-box deal at Target, while the Tampax Radiant tampons cost $7.99 per box. You can also buy a pack of 20 Tampax Pearl tampons for $3.69 each. You’ll save 50% off the $5.99 regular price, for a total price of

Input: Organic Pumpkin Flax Granola 9 oz.

Result: nobody knows what that is. It’s a whole-grain breakfast, but it’s also a dessert. A good example would be the “Fat Free” label on a container of sour cream. The ingredients are the same. 2. It may be difficult to determine the difference between whole wheat breads and regular wheat breads, but there is a difference. Wholewheat flour. The only difference is that the bread is made with whole wheat. If you’ve never been a fan of wheat breads and don’t know how to get started, here are 15 easy ideas to get the ball rolling. The whole grain, however, has a higher fiber and nutrient content than the refined wheat. This whole-grain bread is

Run Training

Next, create the LoRA configuration and use the SFTTrainer from the TRL library, a wrapper around the transformers Trainer that makes it easy to fine-tune models on instruction-based datasets with PEFT adapters.
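The `peft_config` and `training_arguments` objects that the trainer below expects are not shown in this post. A sketch of what they could look like, using the real `peft.LoraConfig` and `transformers.TrainingArguments` classes; every value is an illustrative assumption in the style of mid-2023 QLoRA examples, not the author's setting. `target_modules` is left unset, so PEFT falls back to its default for Llama (the attention query and value projections). `fp16=True` requires a CUDA GPU.

```python
from peft import LoraConfig
from transformers import TrainingArguments

# LoRA adapter settings (illustrative values, not the author's)
peft_config = LoraConfig(
    r=64,                   # rank of the low-rank update matrices
    lora_alpha=16,          # scaling factor applied to the update
    lora_dropout=0.1,       # dropout on the adapter input
    bias="none",            # do not train bias terms
    task_type="CAUSAL_LM",  # decoder-only language modeling
)

# Trainer settings (illustrative values, not the author's)
training_arguments = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # 16 sequences per optimizer step
    optim="paged_adamw_32bit",      # paged AdamW from bitsandbytes
    learning_rate=2e-4,
    fp16=True,                      # matches bnb_4bit_compute_dtype=torch.float16
    max_grad_norm=0.3,
    max_steps=120,                  # the log below reports 120 steps
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    logging_steps=10,
)
```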

```python
from trl import SFTTrainer

# Truncate each training example to at most 512 tokens
max_seq_length = 512

# SFTTrainer attaches the LoRA adapter described by peft_config
# and trains on the "text" column of the training split
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset_dict['train'],
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)
```

```python
%%time
trainer.train()
```

```
CPU times: user 1min 50s, sys: 8.26 s, total: 1min 58s
Wall time: 1min 59s

TrainOutput(global_step=120, training_loss=2.303033309181531, metrics={'train_runtime': 116.8025, 'train_samples_per_second': 16.438, 'train_steps_per_second': 1.027, 'total_flos': 769979938897920.0, 'train_loss': 2.303033309181531, 'epoch': 0.02})
```
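Two quick sanity checks can be read off this log. The training loss is a mean cross-entropy in nats, so its exponential gives a rough perplexity; and runtime times samples per second, divided by the step count, recovers how many training sequences went into each optimizer step. A small sketch using only the numbers in the log:

```python
import math

# Values copied from the TrainOutput above
train_loss = 2.303033309181531
train_runtime_s = 116.8025
samples_per_second = 16.438
global_step = 120


def perplexity(mean_nll_nats: float) -> float:
    """Perplexity from a mean cross-entropy measured in nats."""
    return math.exp(mean_nll_nats)


def samples_per_step(runtime_s: float, samples_per_s: float, steps: int) -> float:
    """Approximate number of training sequences consumed per optimizer step."""
    return runtime_s * samples_per_s / steps


if __name__ == "__main__":
    print(f"perplexity ~ {perplexity(train_loss):.2f}")
    n = samples_per_step(train_runtime_s, samples_per_second, global_step)
    print(f"samples per step ~ {n:.1f}")
```

A perplexity of about 10 means the model is, on average, about as uncertain as a choice among ten equally likely next tokens. With SFTTrainer's default collator the loss covers every token of each example, product name included, so it says little about brand accuracy on its own. The roughly 16 sequences per step indicate an effective batch size of 16, though how that splits between per-device batch size and gradient accumulation isn't shown.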

After Training

```
product_id  product_name                            Product Brand  Predicted Brand
33626       Free & Clear Stage 4 Overnight Diapers  amazon         amazon
18191       Ziti Bolognese Pasta Bowl               nobrand        amazon
38872       Organic Pumpkin Flax Granola            amazon         nobrand
48183       Bread Rolls                             amazon         amazon
22196       Lip Balm, Organic, Lemon Lime           amazon         amazon
```
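To score predictions like these, exact-match accuracy is enough. A sketch with two hypothetical helpers: `parse_brand` assumes the generated completion follows a `product name ->> brand` pattern (the separator is the `ESCTAG` defined earlier; the parsing rule itself is an assumption), and `exact_match_accuracy` compares the two brand columns of the table above.

```python
ESCTAG = " ->> "  # separator defined in the dataset-preparation step


def parse_brand(generated: str, sep: str = ESCTAG) -> str:
    """Extract the brand from a completion like 'Bread Rolls ->> amazon ...'.

    Takes the first word after the last separator, lowercased; returns ''
    if the separator is missing. This parsing rule is an assumption.
    """
    if sep not in generated:
        return ""
    tail = generated.rsplit(sep, 1)[1].split()
    return tail[0].lower() if tail else ""


def exact_match_accuracy(expected: list, predicted: list) -> float:
    """Fraction of positions where the predicted brand equals the expected one."""
    if len(expected) != len(predicted):
        raise ValueError("expected and predicted must have the same length")
    if not expected:
        return 0.0
    hits = sum(e == p for e, p in zip(expected, predicted))
    return hits / len(expected)


if __name__ == "__main__":
    # "Product Brand" and "Predicted Brand" columns from the table above
    expected = ["amazon", "nobrand", "amazon", "amazon", "amazon"]
    predicted = ["amazon", "amazon", "nobrand", "amazon", "amazon"]
    print(exact_match_accuracy(expected, predicted))  # 0.6
    print(parse_brand("Bread Rolls ->> amazon"))       # amazon
```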

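Note: for records 2 and 3, the predicted brand does not match the brand assigned to the product. This is not surprising: the brands were assigned at random, so the model can only learn them by memorizing each product, and the training log shows training stopped after 0.02 epochs. Retrieval augmentation could improve this further.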
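Retrieval augmentation is not implemented in this post. One minimal way to sketch the idea, assuming the product catalog (product name to brand) is available at inference time: retrieve the closest catalog entry for the user input and put its brand into the prompt, so the model can copy a known fact instead of recalling it. Fuzzy matching with Python's standard `difflib` stands in for a real retriever such as an embedding index, and the helper names and prompt template are hypothetical.

```python
import difflib
from typing import Dict, Optional, Tuple

ESCTAG = " ->> "  # separator defined in the dataset-preparation step


def retrieve_brand(query: str, catalog: Dict[str, str],
                   cutoff: float = 0.6) -> Optional[Tuple[str, str]]:
    """Return (catalog_name, brand) for the closest product name, or None."""
    lowered = {name.lower(): name for name in catalog}
    matches = difflib.get_close_matches(query.lower(), list(lowered), n=1, cutoff=cutoff)
    if not matches:
        return None
    name = lowered[matches[0]]
    return name, catalog[name]


def build_augmented_prompt(query: str, catalog: Dict[str, str]) -> str:
    """Prepend retrieved context (if any) to the query in the training format."""
    hit = retrieve_brand(query, catalog)
    context = f"Context: {hit[0]} is sold under the brand {hit[1]}.\n" if hit else ""
    return f"{context}{query}{ESCTAG}"


if __name__ == "__main__":
    # Hypothetical catalog rows; brands taken from the table above
    catalog = {
        "Organic Pumpkin Flax Granola": "amazon",
        "Ziti Bolognese Pasta Bowl": "nobrand",
        "Bread Rolls": "amazon",
    }
    print(build_augmented_prompt("Organic Pumpkin Flax Granola 9 oz.", catalog))
```

For the model to make good use of the context line, the training examples would likely need to include the same kind of context; adding it only at inference time changes the input format the model was fine-tuned on.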

Acknowledgments:

Sharded training is available in PyTorch Lightning thanks to the efforts of the Facebook AI FairScale team; it was inspired by Microsoft's ZeRO paper.


Written by Minyang Chen

Enthusiastic in AI, Cloud, Big Data and Software Engineering. Sharing insights from my own experiences.
