Implement Knowledge Distillation to compress large models into lightweight versions

Minyang Chen
5 min read · Oct 9, 2023


Developing large models improves state-of-the-art performance, but deploying such big models is neither straightforward nor always feasible, especially on edge devices, where resource limitations and cost make them impractical for production deployment.

As a result, we often end up with very large deep learning models that achieve excellent accuracy on validation datasets, but fail to meet the latency, memory footprint, and end-user performance requirements of the production environment at inference time.

Photo by Pierre Bamin on Unsplash

Knowledge distillation

Knowledge distillation helps overcome the challenges of deploying large models by “distilling” the knowledge of a complex model into a smaller model that is much easier to deploy, without much loss in metrics or performance on the validation data. Knowledge distillation is a compression technique that makes it possible to train a small model by transferring knowledge from a bigger, more complex model. The smaller model is referred to as the “student”, and the complex one as the “teacher”. The student learns to mimic the teacher by leveraging its knowledge, aiming to achieve similar accuracy.

Teacher and Student Model

The process transfers knowledge from the teacher model to the student model: the student mimics the output of the final layer of the teacher model to learn its predictions. This is achieved with a distillation loss that captures the difference between the logits of the student and the teacher.
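Concretely, the standard objective is a weighted combination of the usual cross-entropy loss on the ground-truth labels and a KL-divergence term between temperature-softened student and teacher distributions. Below is a minimal sketch of that combination using placeholder tensor names; the full Trainer subclass used in this post appears in the implementation section.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    ## cross-entropy against the ground-truth labels
    loss_ce = F.cross_entropy(student_logits, labels)
    ## KL divergence between softened student and teacher distributions;
    ## the temperature**2 factor keeps gradient magnitudes comparable across temperatures
    loss_kd = temperature ** 2 * F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean")
    ## alpha weights the hard-label loss, (1 - alpha) the distillation loss
    return alpha * loss_ce + (1 - alpha) * loss_kd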

A knowledge distillation setup has three main components: the knowledge, the distillation algorithm, and the teacher-student architecture. The knowledge itself (the teacher's learned weights and biases) can come from different sources: response-based, feature-based, and relation-based knowledge.

knowledge distillation

A classic example of model compression can be seen in the BERT family, where knowledge distillation is used to compress large models into lightweight versions such as DistilBERT. DistilBERT is a natural candidate to initialize the student with, since it has 40% fewer parameters than BERT and has been shown to achieve strong results on downstream tasks. Choosing a student that is smaller than the teacher reduces latency and memory footprint, and knowledge distillation tends to work best when the teacher and student are of the same model type.

Implementation notebook code: https://github.com/minyang-chen/Knowledge_Distillation_Training/blob/main/knowledge_distillation_training.ipynb

Prepare Dataset

We use the CLINC150 dataset, which consists of a query in the text column and its corresponding intent.

from datasets import load_dataset
from transformers import AutoTokenizer

## load the CLINC150 dataset
clinc = load_dataset("clinc_oos", "plus")
sample = clinc["train"][0]
print("question:", sample)

## map the intent ID of the sample to a human-readable label
intents = clinc["train"].features["intent"]
intent = intents.int2str(sample["intent"])
print("intent :", intent)

num_labels = intents.num_classes
print("labels :", num_labels)

## tokenize the data with the student's tokenizer
student_checkpoint = "distilbert-base-uncased"
student_tokenizer = AutoTokenizer.from_pretrained(student_checkpoint)

def tokenize_text(batch):
    return student_tokenizer(batch["text"], truncation=True)

clinc_tokenized = clinc.map(tokenize_text, batched=True, remove_columns=["text"])
clinc_tokenized = clinc_tokenized.rename_column("intent", "labels")
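After tokenization, the text column is removed and each example carries the token IDs, attention mask, and the renamed labels column. A quick sanity check (the exact column order may vary) could look like this:

## inspect the processed dataset
print(clinc_tokenized["train"].column_names)
## e.g. ['labels', 'input_ids', 'attention_mask']
print(clinc_tokenized["train"][0]["labels"])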

Training Preparation

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import TrainingArguments, Trainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## TrainingArguments subclass that carries the two distillation hyperparameters
class KnowledgeDistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature

## Trainer subclass that combines cross-entropy and distillation loss
class KnowledgeDistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, return_outputs=False):
        # Extract cross-entropy loss and logits from the student
        outputs_student = model(**inputs)
        loss_ce = outputs_student.loss
        logits_student = outputs_student.logits
        # Extract logits from the teacher
        outputs_teacher = self.teacher_model(**inputs)
        logits_teacher = outputs_teacher.logits
        # Compute the distillation loss on temperature-softened probabilities
        loss_fct = nn.KLDivLoss(reduction="batchmean")
        loss_kd = self.args.temperature ** 2 * loss_fct(
            F.log_softmax(logits_student / self.args.temperature, dim=-1),
            F.softmax(logits_teacher / self.args.temperature, dim=-1))
        # Weighted combination of the two losses
        loss = self.args.alpha * loss_ce + (1. - self.args.alpha) * loss_kd
        return (loss, outputs_student) if return_outputs else loss

Set metrics and training parameters

import numpy as np
from datasets import load_metric

accuracy_score = load_metric("accuracy")

def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)
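Note that datasets.load_metric is deprecated in newer releases of the datasets library. If it is unavailable in your environment, the evaluate package provides the same accuracy metric with the same compute() interface:

import evaluate

accuracy_score = evaluate.load("accuracy")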

batch_size = 48
finetuned_student_ckpt = "distilbert-base-uncased-finetuned-clinc-student"

## Training arguments for the KnowledgeDistillationTrainer
student_training_args = KnowledgeDistillationTrainingArguments(
    output_dir=finetuned_student_ckpt,
    evaluation_strategy="epoch",
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    alpha=1,
    weight_decay=0.01)
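One thing to note: with alpha=1, the combined loss in compute_loss reduces to the plain cross-entropy term, so the teacher's soft targets do not contribute to training. To actually blend in the distillation signal, you would lower alpha and set a temperature; the values below are illustrative, not the ones used in this notebook:

## example: weight hard labels and soft teacher targets equally (illustrative values)
distillation_training_args = KnowledgeDistillationTrainingArguments(
    output_dir=finetuned_student_ckpt,
    evaluation_strategy="epoch",
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    alpha=0.5,          # 50% cross-entropy, 50% distillation loss
    temperature=2.0,    # soften the teacher's probability distribution
    weight_decay=0.01)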

Set up the teacher model

from transformers import AutoModelForSequenceClassification

teacher_checkpoint = "transformersbook/bert-base-uncased-finetuned-clinc"
teacher_model = (AutoModelForSequenceClassification
                 .from_pretrained(teacher_checkpoint, num_labels=num_labels)
                 .to(device))

Set up the student model

from transformers import AutoConfig, pipeline

## load a fine-tuned BERT pipeline to reuse its intent/label-ID mappings
bert_ckpt = "transformersbook/bert-base-uncased-finetuned-clinc"
pipe = pipeline("text-classification", model=bert_ckpt)

## mappings between each intent and label ID
id2label = pipe.model.config.id2label
label2id = pipe.model.config.label2id

student_config = AutoConfig.from_pretrained(student_checkpoint,
                                            num_labels=num_labels,
                                            id2label=id2label,
                                            label2id=label2id)

def student_init():
    return (AutoModelForSequenceClassification
            .from_pretrained(student_checkpoint, config=student_config)
            .to(device))

Run training

%%time
distilbert_trainer = KnowledgeDistillationTrainer(
    model_init=student_init,
    teacher_model=teacher_model,
    args=student_training_args,
    train_dataset=clinc_tokenized["train"],
    eval_dataset=clinc_tokenized["validation"],
    compute_metrics=compute_metrics,
    tokenizer=student_tokenizer)

## run the training loop
distilbert_trainer.train()

## compare parameter counts of teacher and student
## (compute_parameters is defined further below; the *_model_id_or_path
##  variables point to the saved teacher and student model directories)
teacher_model_parameters = compute_parameters(model_path=teacher_model_id_or_path)
print("Teacher Model: ", teacher_model_parameters)

student_model_parameters = compute_parameters(model_path=student_model_id_or_path)
print("Student Model: ", student_model_parameters)

decrease = (student_model_parameters - teacher_model_parameters) / teacher_model_parameters
print("difference in parameters:", decrease * 100)

## output: difference in parameter counts
Teacher Model:  109598359
Student Model:  67069591
difference in parameters: -38.804201438818986

Training result:

TrainOutput(global_step=954, training_loss=3.1499722078911163, metrics={'train_runtime': 79.4697, 'train_samples_per_second': 575.691, 'train_steps_per_second': 12.005, 'total_flos': 247836315084876.0, 'train_loss': 3.1499722078911163, 'epoch': 3.0})

Epoch  Training Loss  Validation Loss  Accuracy
1      No log         3.420913         0.707742
2      3.894200       2.334888         0.816452
3      3.894200       2.009495         0.834194

Not bad: after 3 epochs of training, validation accuracy reaches about 83%.

Now it's time to compare the teacher and student models. First, a helper function to count model parameters:

## count the parameters of a saved model
def compute_parameters(model_path):
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    parameters = model.num_parameters()
    return parameters

Inference performance improvements

import time

def performance_test(model_id_or_path, model_type, tokenizer_id):
    print("performance_test: ", model_id_or_path)
    pipe = pipeline("text-classification", model=model_id_or_path, tokenizer=tokenizer_id)
    sample_input = clinc["train"]["text"][11]
    ## warm-up runs
    for _ in range(10):
        _ = pipe(sample_input)
    ## timed test: 100 requests
    start = time.time()
    for _ in range(100):
        _ = pipe(sample_input)
    total_time = time.time() - start
    print(f"Total time to process 100 requests for {model_type}: ", total_time)
    return total_time

# teacher model test
teacher_total_time = performance_test(teacher_model_id_or_path,model_type="Teacher Model",tokenizer_id='bert-base-uncased')

# student model test
student_total_time = performance_test(student_model_id_or_path,model_type="Student Model", tokenizer_id="distilbert-base-uncased")

# compute saving
changes_in_time = (teacher_total_time-student_total_time)/teacher_total_time
print("saving in inference time:",changes_in_time*100, "%")

The result shows a great improvement:

performance_test:  ./result/teacher_model
Total time to process 100 requests for Teacher Model: 3.7654707431793213
performance_test: ./result/student_model
Total time to process 100 requests for Student Model: 1.9501032829284668
saving in inference time: 48.2109033389561 %

There is also a nice reduction in model size on disk; see the result below:

!echo 'Teacher Model File Size'
!ls ./result/teacher_model -al --block-size=MB

-rw-rw-r-- 1 pop pop 1MB Oct 9 05:18 config.json
-rw-rw-r-- 1 pop pop 439MB Oct 9 05:18 pytorch_model.bin

!echo 'Student Model File Size'
!ls ./result/student_model -al --block-size=MB

-rw-rw-r-- 1 pop pop 1MB Oct 9 05:18 added_tokens.json
-rw-rw-r-- 1 pop pop 1MB Oct 9 05:18 config.json
-rw-rw-r-- 1 pop pop 269MB Oct 9 05:18 pytorch_model.bin
-rw-rw-r-- 1 pop pop 1MB Oct 9 05:18 special_tokens_map.json

Summary

Overall, the savings in model size (about 39% fewer parameters) and inference time (about 48% faster), with only a modest accuracy trade-off, make knowledge distillation a compelling technique for production deployment.

Thanks for reading. I hope you learned something new and enjoyed this fun experiment as much as I did.

Have a great day!

References:

Improved Knowledge Distillation via Teacher Assistant (Mirzadeh et al., AAAI 2020)

Distilling the Knowledge in a Neural Network (Hinton et al., 2015)
