Improve Model Output Quality with Reinforcement Learning from Human Feedback (RLHF): A Classic Example

Minyang Chen
6 min read · Oct 11, 2023

Our goal is to productionize a reliable LLM that consistently produces results adhering to established business practices, rules, intentions, and security standards. However, implementing an LLM solution presents challenges, particularly when it comes to controlling the model so that it generates predictable output that requires specific domain knowledge.


While prompt engineering is a useful tool that requires minimal effort to adjust the model output, managing a growing number of prompt templates and tracking which wording leads to conflicting model results can become a tedious task.

Business Case: self-moderated Customer Response Email Assistant

Let’s consider the business case of developing a self-moderated Customer Response Email Assistant. Suppose we implement this solution using two LLMs chained together in a conversation, as follows:

> Model-1 (Writer): generates a response to the customer's email.
> Model-2 (Critique): critiques the response generated by Model-1 and improves it in a positive way.

The solution implementation looks like this:

User input: business intent and customer email

The AI Email Response Assistant instructs two LLMs to perform the generation tasks.

------------------------------------------------------------------------
| MODEL-1 (Writer) | >> Output >> | MODEL-2 (Critique) | >> Moderated Output
------------------------------------------------------------------------

Step 1. The user inputs the business response policy and rules to the AI assistant as the system prompt, and the customer request as the user prompt.
Step 2. Model-1 takes the input instruction and generates a customer response.
Step 3. Model-2 takes the output from Model-1 as input and runs a critique task to generate an improved version.
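The three steps can be sketched as one small function. This is a hypothetical illustration, not the author's implementation: `writer` and `critic` stand in for whatever LLM client the assistant uses (any callable that takes a system prompt and a user prompt), and the critique prompt wording is invented for the example.

```python
from typing import Callable, Dict

# Hypothetical LLM interface: (system_prompt, user_prompt) -> generated text
LLM = Callable[[str, str], str]


def run_email_assistant(policy: str, customer_email: str,
                        writer: LLM, critic: LLM) -> Dict[str, str]:
    """Chain Model-1 (writer) and Model-2 (critique) as in Steps 1-3."""
    # Steps 1-2: business policy as the system prompt, customer email as the user prompt
    draft = writer(policy, customer_email)
    # Step 3: the critic sees the same policy plus Model-1's draft and rewrites it
    critique_prompt = (
        "Critique the following customer response and return an improved, "
        "more positive version that follows the policy.\n\n" + draft
    )
    moderated = critic(policy, critique_prompt)
    # Keep both outputs so a later stage can compare them
    return {"draft": draft, "moderated": moderated}


if __name__ == "__main__":
    # Stub models so the sketch runs without any LLM backend
    def fake_writer(system: str, user: str) -> str:
        return "Thanks for reaching out. We will refund your order."

    def fake_critic(system: str, user: str) -> str:
        draft = user.split("\n\n", 1)[1]
        return draft + " We appreciate your patience!"

    result = run_email_assistant("Be polite; refunds within 30 days.",
                                 "My order arrived broken.",
                                 fake_writer, fake_critic)
    print(result["moderated"])
```

Returning both the draft and the moderated text matters for the problems discussed next: if Model-2's rewrite is worse, the pipeline still has Model-1's version to fall back on.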

So far, the solution design sounds great and the test results are good.

What could go wrong?

1. Model-2 could still fail to produce a good moderation.
2. Model-1's output could be better than Model-2's output.

How can we improve the solution?

Let’s look at one of the techniques used for ChatGPT to optimize its text generation: using human feedback on the generated text as a measure of performance. This idea is called reinforcement learning from human feedback (RLHF), and it can improve fact grounding and reduce hallucination. At a high level, RLHF as used for ChatGPT works like this:

Step 1: Collect Data and Train or Fine-Tune a Supervised Policy Model
Step 2: Collect Comparison Data and Train a Reward Model
Step 3: Optimize a Policy Against the Reward Model Using the PPO Reinforcement Learning Algorithm
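To make Step 2 concrete: comparison data are pairs of responses to the same prompt where a human preferred one over the other, and in InstructGPT-style RLHF the reward model is trained so that the preferred response gets the higher score, using the pairwise loss -log(sigmoid(r_chosen - r_rejected)). The plain-Python sketch below only illustrates that loss on made-up scores; it is not part of the author's notebook, which uses a pretrained sentiment classifier instead of training a reward model.

```python
import math
from typing import Iterable, Tuple


def pairwise_reward_loss(r_chosen: float, r_rejected: float) -> float:
    """-log(sigmoid(r_chosen - r_rejected)): small when the chosen response scores higher."""
    margin = r_chosen - r_rejected
    # Numerically stable softplus(-margin), which equals -log(sigmoid(margin))
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


def mean_pairwise_loss(pairs: Iterable[Tuple[float, float]]) -> float:
    """Average loss over (score of preferred, score of rejected) pairs."""
    losses = [pairwise_reward_loss(c, r) for c, r in pairs]
    return sum(losses) / len(losses)


# Made-up reward-model scores for two human-labeled comparisons
comparisons = [(2.0, -1.0), (0.5, 1.5)]
print(mean_pairwise_loss(comparisons))
```

The second pair, where the rejected response currently scores higher, contributes most of the loss, so gradient descent pushes the reward model toward the human ranking.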

Keep in mind that training a large-scale model like ChatGPT is a complex process that requires a large amount of data and computing resources. However, the idea works remarkably well for them, so let’s borrow the concept and apply it to our solution.

Solution Improvements

Improvement #1: More training, more predictable results. Apply additional RLHF training, with new training data and a reward model, to Model-1 and Model-2.

Improvement #2: Add a Model-3 (Reward Model) that judges the outputs of Model-1 and Model-2 and picks the best one, like this:

------------------------------------------------------------------------
[Model-1 Output, Model-2 Output] >>> [Reward Model] >>> Best Output
------------------------------------------------------------------------

Model-3 acts as a guardrail on the output and helps track the quality of Model-1 and Model-2 over time, so the models can be continuously improved as the business changes. Now that we have a plan to improve the solution, it’s time to do RLHF training for Model-1 and Model-2.
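A minimal sketch of Model-3 as a guardrail, assuming it can be reduced to "score every candidate with a reward function and return the highest-scoring one". The `RewardJudge` class and the toy reward are hypothetical; in practice `reward_fn` would wrap a trained reward model or a classifier like the sentiment pipeline used later. The win counter is one simple way to track, over time, how often each model's output is chosen.

```python
from collections import Counter
from typing import Callable, Dict


class RewardJudge:
    """Model-3 as a guardrail: score candidates, return the best, track winners."""

    def __init__(self, reward_fn: Callable[[str], float]):
        self.reward_fn = reward_fn
        self.wins: Counter = Counter()  # how often each model's output was picked

    def pick(self, candidates: Dict[str, str]) -> str:
        scores = {name: self.reward_fn(text) for name, text in candidates.items()}
        best = max(scores, key=scores.get)  # ties go to the first candidate
        self.wins[best] += 1
        return candidates[best]

    def win_rate(self, name: str) -> float:
        total = sum(self.wins.values())
        return self.wins[name] / total if total else 0.0


# Toy reward for illustration only: count polite phrases
def toy_reward(text: str) -> float:
    return float(sum(text.lower().count(w) for w in ("thank", "sorry", "happy to")))


judge = RewardJudge(toy_reward)
best = judge.pick({
    "model-1": "Your refund is being processed.",
    "model-2": "Thank you for your patience, your refund is being processed.",
})
print(best)
print(judge.win_rate("model-2"))
```

If Model-1's win rate climbs over time, that is a signal that Model-2's critique step needs retraining.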

How to do RLHF Training?

The easiest way to explain it is to build a classic RLHF example that demonstrates the concept with the tools available:

Training Objective: Fine-tune an LLM using RLHF to generate positive responses.

Human Feedback Dataset: A simple approach here is to use another model to automatically generate synthetic human feedback, which removes manual steps.

Training Dataset: imdb-50k

Training Models: To reduce training time and steps, we will leverage pretrained models.

Training Process:
1. Fine-tune the policy model
2. Train the reward model
3. Train the policy model with PPO

Notebook: https://github.com/minyang-chen/RLHF_example

Models Setup

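```python
# Classic (2023-era) TRL PPO API, as used in the notebook
import torch
from transformers import AutoTokenizer, pipeline
from trl import AutoModelForCausalLMWithValueHead, PPOConfig

# Assumption: model name and learning rate follow the TRL gpt2-sentiment example;
# check the notebook for the values actually used
config = PPOConfig(model_name="lvwerra/gpt2-imdb", learning_rate=1.41e-5)
device = 0 if torch.cuda.is_available() else "cpu"

## optimization model (policy plus a value head for PPO)
target_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
## frozen reference model used by PPO to calculate the KL divergence
reference_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
## reward model (sentiment classifier returning POSITIVE/NEGATIVE scores)
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
## tokenizer (GPT-2 has no pad token, so reuse EOS)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
```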

So the reward model used here is a sentiment classifier.

We use the sentiment result as a simple key metric for the quality of the output text, to mimic human feedback. See the example below:

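```python
# Return raw logits for both labels. Assumed kwargs, consistent with the
# unnormalized two-label output below (newer transformers versions prefer top_k=None)
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none"}

text = "this movie was really bad!!"
sentiment_pipe(text, **sent_kwargs)
# [[{'label': 'NEGATIVE', 'score': 2.3350486755371094},
#   {'label': 'POSITIVE', 'score': -2.726576566696167}]]

text = "this movie was really good!!"
sentiment_pipe(text, **sent_kwargs)
# [[{'label': 'NEGATIVE', 'score': -2.2947897911071777},
#   {'label': 'POSITIVE', 'score': 2.557039737701416}]]
```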
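The scores above are raw logits (they are not bounded by 0 and 1), which is what the training loop uses as the reward. To read them as a "human-like" judgment, a two-class softmax turns them into a probability of POSITIVE. The helper below is a hypothetical addition for interpretation only; it is not used in the notebook's training loop.

```python
import math
from typing import Dict, List


def positive_probability(pipe_output: List[Dict]) -> float:
    """Softmax over the two label logits, returned as P(POSITIVE)."""
    logits = {item["label"]: item["score"] for item in pipe_output}
    # With two classes, softmax reduces to a sigmoid of the logit difference
    diff = logits["POSITIVE"] - logits["NEGATIVE"]
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    z = math.exp(diff)  # stable branch for large negative differences
    return z / (1.0 + z)


# The two outputs shown above (first element of each returned list)
bad = [{"label": "NEGATIVE", "score": 2.3350486755371094},
       {"label": "POSITIVE", "score": -2.726576566696167}]
good = [{"label": "NEGATIVE", "score": -2.2947897911071777},
        {"label": "POSITIVE", "score": 2.557039737701416}]
print(round(positive_probability(bad), 3), round(positive_probability(good), 3))
```

For these two sentences this gives roughly 0.006 and 0.992, matching the intuition that the classifier is very confident in both cases.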

Next, run the training loop. The settings should be adjusted to the runtime environment; in my case, the following configuration works on a single consumer PC with an RTX 4060 Ti (16 GB VRAM).

```python
%%time
import torch
from tqdm import tqdm
from trl.core import LengthSampler

# Sample a random response length for each query
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Pure sampling: top_k=0 and top_p=1.0 disable top-k and nucleus filtering
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

## training loop (ppo_trainer is the trl PPOTrainer built in the notebook)
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    ### Gather responses from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        # generate() returns query + response; keep only the new tokens
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    ### Calculate sentiment scores; the POSITIVE logit is the reward
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    # Look up the label instead of indexing, in case the score order changes
    rewards = [
        torch.tensor(next(s["score"] for s in output if s["label"] == "POSITIVE"))
        for output in pipe_outputs
    ]

    ### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
```
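`ppo_trainer.step(...)` hides the actual policy update. As a rough sketch, the core of PPO is the clipped surrogate objective from the PPO paper, shown below for per-token log-probabilities and advantages. The function name and plain-Python form are illustrative, not TRL's API; TRL's trainer additionally estimates advantages with the value head (GAE), adds a value-function loss, and runs several mini-batch epochs per step.

```python
import math
from typing import Sequence


def ppo_clipped_objective(logp_new: Sequence[float], logp_old: Sequence[float],
                          advantages: Sequence[float], clip_eps: float = 0.2) -> float:
    """Mean PPO clipped surrogate objective (to be maximized)."""
    total = 0.0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)  # pi_new(token) / pi_old(token)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # The min makes the update pessimistic: pushing the ratio outside
        # [1 - eps, 1 + eps] cannot increase the objective any further
        total += min(ratio * adv, clipped_ratio * adv)
    return total / len(advantages)


# A token whose probability doubled with a positive advantage gets credit
# only up to ratio 1 + clip_eps
print(ppo_clipped_objective([math.log(2.0)], [0.0], [1.0]))
```

The clipping is what keeps each update small, which matters here because the sentiment reward is noisy on such short responses.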

So, in each call below, the trainer updates the target model using the rewards from the sentiment model:

```python
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
```
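This is where the reference model from the setup comes in. As I understand the classic TRL PPOTrainer, the sentiment score is not used as the reward on its own: every generated token receives a penalty proportional to how far the policy's log-probability drifts from the reference model's, and the score is added on the final token. A simplified, hypothetical sketch (the function name and `kl_coef=0.2` are illustrative; TRL can also adapt the coefficient during training):

```python
from typing import List, Sequence


def kl_shaped_rewards(score: float, logprobs: Sequence[float],
                      ref_logprobs: Sequence[float], kl_coef: float = 0.2) -> List[float]:
    """Per-token rewards: KL penalty on every token, sentiment score on the last one."""
    if not logprobs or len(logprobs) != len(ref_logprobs):
        raise ValueError("need equal-length, non-empty log-prob sequences")
    # log p - log p_ref is a per-token estimate of the KL divergence;
    # subtract it, scaled by kl_coef
    rewards = [-kl_coef * (lp - ref) for lp, ref in zip(logprobs, ref_logprobs)]
    rewards[-1] += score  # the reward model only scores the finished response
    return rewards


print(kl_shaped_rewards(1.5, [-1.0, -2.0], [-1.0, -2.5]))
```

Without this penalty, the policy could chase the sentiment reward with degenerate text that no longer resembles the original model's language.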

Results from Optimized Model

Comparing outputs before and after training:

A key observation is that the rewards can be positive or negative. The tone of the optimized model is generally better, but its outputs are still not always positive. See the sample below:

Before: positive out of Weight Wat  
Reward: -1.963960

After: wrong is to see
Reward: -0.484065

A possible cause is that the text is cut off into short chunks (responses of only 4 to 16 tokens), so the results lack context. Further tests with different text lengths are needed to confirm this behavior.
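One way to test the cut-off hypothesis is to record the sampled `gen_len` next to each reward during training and compare average rewards per length. The helper below is hypothetical; the original loop does not keep the `gen_len` values, so you would need to append them to a list alongside `response_tensors`.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, Iterable, List, Tuple


def mean_reward_by_length(records: Iterable[Tuple[int, float]]) -> Dict[int, float]:
    """Average reward for each response length, sorted by length."""
    buckets: Dict[int, List[float]] = defaultdict(list)
    for length, reward in records:
        buckets[length].append(reward)
    return {length: mean(values) for length, values in sorted(buckets.items())}


# Made-up (gen_len, reward) pairs, for illustration only
records = [(4, -1.2), (4, 0.3), (8, 0.9), (15, 1.7)]
print(mean_reward_by_length(records))
```

If the average reward rises steadily with length, that would support the lack-of-context explanation; a flat profile would point elsewhere.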

Conclusion

In the end, the RLHF concept works, and the method has been battle-tested by ChatGPT in real-world settings. Having quality human-feedback data and applying additional training methods is therefore a good practice for significantly improving the quality of LLM results. However, it is a lot more hard work; maybe we could delegate this task to an agent in the future.

Thanks for reading! I hope you learned something new, as I did.

Have a nice day!
