Large Language Models are a fascinating technology that is becoming embedded in many applications and products. In my blog series about LLMs, I stated the goal of designing a closed-book question answering system with LLMs as the core or sole component. Following my overview of question answering system architectures, this is the second article, and its focus is to add question-answering skills to a Gen1 LLM.
This article shows how to fine-tune the GPT2 LLM with the SQUAD question answering dataset. You will learn about the SQUAD dataset: its origin, its structure, and how to preprocess it for training. You will then see how to fine-tune GPT2 with this dataset using the transformers library, and how to use the model with custom questions from any context.
The technical context of this article is Python v3.11 and transformers v4.37.2. All instructions should work with newer versions of the tools as well.
This article originally appeared at my blog admantium.com.
Question Answering Datasets
In classical NLP, question answering is considered an advanced task. Models trained for this task are provided with a context section and a question, and then need to find the span in the context that best matches the question. Earlier, non-LLM models were often trained to just identify the start and end position of the answer, and this is still the dominant form of available datasets.
SQUAD (Stanford Question Answering Dataset) is a well-known and often used dataset. Each example consists of a context, a question, and the expected answer, given as the answer text and its start position in the context. Here is an example of a question about paid holidays in the European Union.
Context: While the Treaties and Regulations will have direct effect (if clear, unconditional, and immediate), Directives do not generally give citizens (as opposed to the member state) standing to sue other citizens. In theory, this is because TFEU article 288 says Directives are addressed to the member states and usually "leave to the national authorities the choice of form and methods" to implement. In part this reflects that directives often create minimum standards, leaving member states to apply higher standards. For example, the Working Time Directive requires that every worker has at least 4 weeks paid holidays each year, but most member states require more than 28 days in national law. However, on the current position adopted by the Court of Justice, citizens have standing to make claims based on national laws that implement Directives, but not from Directives themselves. Directives do not have so called "horizontal" direct effect (i.e. between non-state parties). This view was instantly controversial, and in the early 1990s three Advocate Generals persuasively argued that Directives should create rights and duties for all citizens. The Court of Justice refused, but there are five large exceptions.
Question: How many paid holiday days does the Working Time directive require workers to have each year?
Answers:
start position: 594
start position: 595
text: 4 weeks
How the SQUAD dataset is used for fine-tuning should be put into a historical perspective. Essentially, training adds a new fully connected layer on top of the model that outputs, for each token, a start score and an end score, from which the start and end token positions of the answer are derived. This technique is specific to Gen1 models. Newer models do not merely identify relevant spans, but can also summarize and synthesize answers. Therefore, this specific dataset and training goal are applicable to Gen1 models but became obsolete with late Gen2 models.
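The span-prediction idea can be sketched in a few lines of plain Python. This is not the transformers implementation: it assumes hypothetical per-token start and end scores and picks the pair with the highest combined score, under the constraints that the end does not precede the start and that the answer stays within a maximum length. The names `best_span`, `start_scores`, `end_scores`, and `max_answer_len` are illustrative.

```python
def best_span(start_scores, end_scores, max_answer_len=15):
    """Return the (start, end) token indices with the highest combined score."""
    if len(start_scores) != len(end_scores):
        raise ValueError("start and end scores must have the same length")
    best = (0, 0)
    best_score = float("-inf")
    for start, start_score in enumerate(start_scores):
        # The end token may not precede the start token,
        # and the span may not exceed max_answer_len tokens
        last = min(start + max_answer_len, len(end_scores))
        for end in range(start, last):
            score = start_score + end_scores[end]
            if score > best_score:
                best_score = score
                best = (start, end)
    return best


# Hypothetical scores for a six-token context
start_scores = [0.1, 0.2, 3.0, 0.3, 0.1, 0.0]
end_scores = [0.0, 0.1, 0.5, 2.5, 0.2, 4.0]
print(best_span(start_scores, end_scores, max_answer_len=3))  # (2, 3)
```

Limiting the span length keeps the search from pairing a strong start score with a strong but distant end score, which would produce an overly long answer.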
Step 1: Goal Definition
This project fine-tunes a GPT2 LLM with the SQUAD dataset to add question-answering behavior: the LLM outputs the start and end positions of the tokens inside the context that are relevant to a given question.
Step 2: Data Selection & Exploration
The HuggingFace datasets library facilitates loading pre-processed datasets. You only need two lines of code to load the SQUAD dataset:
from datasets import load_dataset
data = load_dataset('squad')
# Downloading readme: 100%
# 7.83k/7.83k [00:00<00:00, 434kB/s]
# Downloading data: 100%
# 14.5M/14.5M [00:00<00:00, 20.1MB/s]
# Downloading data: 100%
# 1.82M/1.82M [00:00<00:00, 7.13MB/s]
# Generating train split: 100%
# 87599/87599 [00:00<00:00, 350495.77 examples/s]
# Generating validation split: 100%
# 10570/10570 [00:00<00:00, 421873.03 examples/s]
The first time a dataset is loaded, its source and configuration files are downloaded and stored on your computer in a default (or configurable) cache directory.
Inspect the dataset's data structure with a simple print.
print(data)
# DatasetDict({
# train: Dataset({
# features: ['id', 'title', 'context', 'question', 'answers'],
# num_rows: 87599
# })
# validation: Dataset({
# features: ['id', 'title', 'context', 'question', 'answers'],
# num_rows: 10570
# })
# })
Only a train and a validation split exist. Let's consider an individual example from the training split.
print(data['train'][42])
# {
# 'id': '5733ae924776f41900661016',
# 'title': 'University_of_Notre_Dame',
# 'context': 'Notre Dame is known for its competitive admissions, with the incoming class enrolling in fall 2015 admitting 3,577 from a pool of 18,156 (19.7%). The academic profile of the enrolled class continues to rate among the top 10 to 15 in the nation for national research universities. The university practices a non-restrictive early action policy that allows admitted students to consider admission to Notre Dame as well as any other colleges to which they were accepted. 1,400 of the 3,577 (39.1%) were admitted under the early action plan. Admitted students came from 1,311 high schools and the average student traveled more than 750 miles to Notre Dame, making it arguably the most representative university in the United States. While all entering students begin in the College of the First Year of Studies, 25% have indicated they plan to study in the liberal arts or social sciences, 24% in engineering, 24% in business, 24% in science, and 3% in architecture.',
# 'question': 'What percentage of students at Notre Dame participated in the Early Action program?',
# 'answers': {'text': ['39.1%'],
# 'answer_start': [488]}
# }
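Note that `answer_start` is a character offset into the context, not a token position. The following sketch shows how the answer span can be recovered from such a record. The record is a shortened, hypothetical one in the same shape as the example above, and `answer_span` is an illustrative helper, not part of the datasets library.

```python
def answer_span(answers, index=0):
    """Return (start_char, end_char) of one answer; end_char is exclusive."""
    start_char = answers["answer_start"][index]
    end_char = start_char + len(answers["text"][index])
    return start_char, end_char


# Shortened, hypothetical record in the same shape as a SQUAD example
context = "1,400 of the 3,577 (39.1%) were admitted under the early action plan."
answer_text = "39.1%"
answers = {"text": [answer_text], "answer_start": [context.find(answer_text)]}

start_char, end_char = answer_span(answers)
print(start_char, end_char)
print(context[start_char:end_char])
```

The end position is never stored in the dataset. It always has to be derived from the start offset and the length of the answer text.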
Step 3: Data Preprocessing
In the preprocessing step, the training data needs to be tokenized with the model-specific tokenizer. To simplify this task, the convenient AutoTokenizer wrapper is used. Here is how:
from transformers import AutoTokenizer
model_name = 'openai-community/gpt2'
tokenizer=AutoTokenizer.from_pretrained(model_name)
print(tokenizer.special_tokens_map)
# {'bos_token': '<|endoftext|>',
# 'eos_token': '<|endoftext|>',
# 'unk_token': '<|endoftext|>'}
In my previous article, these default tokens were used, and I suspect that it worsened the results. Instead, we will use the special tokens from the BERT model.
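from transformers import AutoTokenizer
bert_special_tokens = AutoTokenizer.from_pretrained('bert-base-uncased').special_tokens_map
# Assumption: BERT's special tokens plus [BEG] and [END] markers, matching the output below
tokens = {**bert_special_tokens, 'bos_token': '[BEG]', 'eos_token': '[END]'}
model_name = 'openai-community/gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(tokens)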
print(tokenizer.special_tokens_map)
# {'bos_token': '[BEG]',
# 'eos_token': '[END]',
# 'unk_token': '[UNK]',
# 'sep_token': '[SEP]',
# 'pad_token': '[PAD]',
# 'cls_token': '[CLS]',
# 'mask_token': '[MASK]'}
With this, we can define a tokenizer function and tokenize the datasets. This is a bit more complicated. The input data is a combination of the question and its context, which are passed to the tokenizer together and encoded as one sequence. The labels, the expected values for a question, are numerical values for the start and end token. Given that the SQUAD dataset contains multiple answers in some examples, only the very first one will be used.
The following code is a refined version of the preprocessing method in the HuggingFace question answering tutorial. I needed to add exception handling because the training data contains tokens unknown to the GPT2 tokenizer.
# Source: HuggingFace, Question Answering, https://huggingface.co/docs/transformers/tasks/question_answering#preprocess
def preprocess(examples):
    inputs = tokenizer(
        examples["question"],
        examples["context"],
        truncation=True,
        padding="max_length",
        return_offsets_mapping=True,
        max_length=512,
        stride=128,
    )
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []
    for i, offset in enumerate(offset_mapping):
        # Only the first answer of each example is used
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)
        # Find the first and last token that belong to the context (sequence id 1)
        idx = 0
        context_start = idx
        context_end = idx
        try:
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1
        except IndexError:
            # The end of the sequence was reached before a context boundary was found
            pass
        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
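The core of this function is the mapping from a character span to token positions with the help of the offset mapping. The following self-contained sketch isolates that step. It assumes a toy whitespace tokenizer (`toy_offsets`, hypothetical) instead of the GPT2 tokenizer, so the token boundaries differ from the real ones, but the two scanning loops follow the same logic.

```python
import re


def toy_offsets(text):
    """Whitespace tokenizer that returns (start_char, end_char) per token."""
    return [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]


def char_span_to_token_span(offsets, start_char, end_char):
    """Map a character span to (first_token, last_token); (0, 0) if not covered."""
    if not offsets or offsets[0][0] > start_char or offsets[-1][1] < end_char:
        return 0, 0
    # First token: the last one that begins at or before the answer start
    idx = 0
    while idx < len(offsets) and offsets[idx][0] <= start_char:
        idx += 1
    first = idx - 1
    # Last token: the leftmost one that still ends at or after the answer end
    idx = len(offsets) - 1
    while idx >= 0 and offsets[idx][1] >= end_char:
        idx -= 1
    last = idx + 1
    return first, last


context = "every worker has at least 4 weeks paid holidays each year"
answer = "4 weeks"
start_char = context.find(answer)
end_char = start_char + len(answer)

offsets = toy_offsets(context)
first, last = char_span_to_token_span(offsets, start_char, end_char)
tokens = [context[s:e] for s, e in offsets]
print(first, last, tokens[first:last + 1])
```

When the context is cut off before the answer, as happens with truncation, the helper returns (0, 0), which mirrors the fallback label in the function above.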
Let's apply this function and inspect the resulting input data and labels.
gpt2_squad = data.map(preprocess, batched=True, remove_columns=data["train"].column_names)
gpt2_squad
# DatasetDict({
# train: Dataset({
# features: ['input_ids', 'attention_mask', 'start_posit <...> ut_ids', 'attention_mask', 'start_positions', 'end_positions'],
# num_rows: 10570
# })
# })
gpt2_squad["train"][42]
# {'input_ids': [2061, 5873, 286, 2444, 1, ..],
# 'attention_mask': [1, 1, 1, ...],
# 'start_positions': 116,
# 'end_positions': 119}
Step 4: Training Parameter Definition
With the tokenization in place and using the start and end positions as the labels, we can continue with the training setup.
For the training arguments, default values are suitable. We don’t need to define a special metric.
from transformers import TrainingArguments
training_args = TrainingArguments(output_dir="gpt2_qa", logging_steps=1)
Step 5: Training Execution
The training uses a DataCollator object to batch the inputs and pad them to a common length. For the model type, we will use AutoModelForQuestionAnswering. The complete setup is as follows.
from transformers import AutoModelForQuestionAnswering, DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
# The tokenizer gained new special tokens, so the embedding matrix has to grow accordingly
model.resize_token_embeddings(len(tokenizer))
model.config.max_length = 512
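To make the collator's role more concrete, here is a minimal plain-Python sketch of dynamic padding, the idea behind DataCollatorWithPadding. It is not the library's implementation, and `pad_batch` as well as the token ids are hypothetical. Keep in mind that the preprocessing step in this article already pads every example to max_length, so the collator has little left to pad here.

```python
def pad_batch(batch, pad_id):
    """Pad token id lists to the longest one and build matching attention masks."""
    max_len = max(len(ids) for ids in batch)
    input_ids = []
    attention_mask = []
    for ids in batch:
        padding = max_len - len(ids)
        input_ids.append(ids + [pad_id] * padding)
        # 1 marks real tokens, 0 marks padding that the model should ignore
        attention_mask.append([1] * len(ids) + [0] * padding)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Hypothetical token ids; 0 stands in for the id of the [PAD] token
padded = pad_batch([[11, 12, 13], [21, 22]], pad_id=0)
print(padded["input_ids"])
print(padded["attention_mask"])
```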
Finally, all pieces needed to create the trainer object are in place, and the training can start.
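from transformers import Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=gpt2_squad["train"],
    eval_dataset=gpt2_squad["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer
)
trainer.train()
# Persist the final model and tokenizer in the output directory so they can be loaded later
trainer.save_model("gpt2_qa")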
Step 6: Model Usage
The trained model is saved in the configured output directory, and from there, it can be loaded.
Let’s use the model with an example from the validation dataset.
val1_id = "57269cc3dd62a815002e8b13"
example = data["validation"].filter(lambda x: x['id'] == val1_id)[0]
print(example)
# {'id': '57269cc3dd62a815002e8b13',
# 'title': 'European_Union_law',
# 'context': 'While the Treaties and Regulations will have direct effect (if clear, unconditional, and immediate), Directives do not generally give citizens (as opposed to the member state) standing to sue other citizens. In theory, this is because TFEU article 288 says Directives are addressed to the member states and usually "leave to the national authorities the choice of form and methods" to implement. In part this reflects that directives often create minimum standards, leaving member states to apply higher standards. For example, the Working Time Directive requires that every worker has at least 4 weeks paid holidays each year, but most member states require more than 28 days in national law. However, on the current position adopted by the Court of Justice, citizens have standing to make claims based on national laws that implement Directives, but not from Directives themselves. Directives do not have so called "horizontal" direct effect (i.e. between non-state parties). This view was instantly controversial, and in the early 1990s three Advocate Generals persuasively argued that Directives should create rights and duties for all citizens. The Court of Justice refused, but there are five large exceptions.',
# 'question': 'How many paid holiday days does the Working Time directive require workers to have each year?',
# 'answers': {'text': ['4 weeks',
# '4 weeks paid holidays each year',
# '4 weeks paid'],
# 'answer_start': [594, 594, 594]}}
To use the model, we create a question-answering pipeline and specify the tokenizer and the local model. Then we transform the example into a dictionary object with question and context keys.
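from transformers import AutoModelForQuestionAnswering, pipeline
# Directory that contains the fine-tuned model
local_model = 'gpt2_qa'
qa = pipeline(
    "question-answering",
    tokenizer=tokenizer,
    model=AutoModelForQuestionAnswering.from_pretrained(local_model)
)
d = qa({"question": example["question"], "context": example["context"]})
print(d)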
# {'score': 0.0003682523383758962,
# 'start': 662,
# 'end': 692,
# 'answer': ' than 28 days in national law.'}
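However, this answer is only partially correct: the returned span comes from the right sentence, but it misses the first part of that sentence, which contains the expected answer "4 weeks".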
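One simple way to quantify such a miss is to compare character spans. The sketch below uses the values from the pipeline output and the validation example above, with `char_overlap` as a hypothetical helper. It first checks that the `start` and `end` values behave like character offsets into the context, then counts how many characters the predicted and the expected span share.

```python
def char_overlap(pred_start, pred_end, gold_start, gold_text):
    """Number of characters shared by the predicted and the expected span."""
    gold_end = gold_start + len(gold_text)
    return max(0, min(pred_end, gold_end) - max(pred_start, gold_start))


# Values copied from the pipeline output and the validation example above
prediction = {"start": 662, "end": 692, "answer": " than 28 days in national law."}
gold = {"text": "4 weeks", "answer_start": 594}

# The span length equals the answer length, so start and end are character offsets
span_length = prediction["end"] - prediction["start"]
print(span_length == len(prediction["answer"]))

overlap = char_overlap(
    prediction["start"], prediction["end"], gold["answer_start"], gold["text"]
)
print(overlap)
```

An overlap of zero characters means the predicted span does not touch the expected answer at all, even though both lie in the same sentence.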
Question Answering with Fine-Tuned Model
Finally, let's apply the fine-tuned model to a context that it was not trained on: an excerpt from Wikipedia, and the question “What is NASA?”.
query = {
"question": "What is NASA?",
"context": '''
The National Aeronautics and Space Administration (NASA) is an independent
agency of the U.S. federal government responsible for the civil space
program, aeronautics research, and space research. Established in 1958, it
succeeded the National Advisory Committee for Aeronautics (NACA) to give
the U.S. space development effort a distinctly civilian orientation,
emphasizing peaceful applications in space science. It has since
led most American space exploration, including Project Mercury, Project
Gemini, the 1968–1972 Apollo Moon landing missions, the Skylab space
station, and the Space Shuttle. It currently supports the International
Space Station and oversees the development of the Orion spacecraft and the
Space Launch System for the crewed lunar Artemis program, the Commercial
Crew spacecraft, and the planned Lunar Gateway space station.
NASA's science is focused on better understanding Earth through the Earth
Observing System; advancing heliophysics through the efforts of the
Science Mission Directorate's Heliophysics Research Program; exploring
bodies throughout the Solar System with advanced robotic spacecraft such as
New Horizons and planetary rovers such as Perseverance; and researching
astrophysics topics, such as the Big Bang, through the James Webb Space
Telescope, the Great Observatories and associated programs. The Launch
Services Program oversees launch operations and countdown management for
its uncrewed launches.
'''
}
qa(query)
# {'score': 0.008794573135674,
# 'start': 305,
# 'end': 369,
# 'answer': 'S. space development effort a distinctly civilian orientation,\n\t'}
However, this answer is unsatisfying. Trying other questions like “When was NASA founded?” led to similar results. Although the model is fine-tuned, the results clearly show that mere span highlighting does not give precise answers, let alone reflective answers that reason over a given text. With this result, the limitations of this second approach to building a question-answering system become clear.
Conclusion
Early Gen1 LLMs are capable text generators but lack support for advanced NLP tasks. This article showed a practical approach to fine-tuning GPT2 with a question-answering dataset. You learned the dataset structure and saw the necessary preprocessing steps. You also saw how to define training hyperparameters and start the training process with the help of the transformers library. The fine-tuned model was then manually tested with questions about the Wikipedia article on NASA. However, the model's answers were inaccurate: mere span detection inside a context clearly does not produce a reflective answer. The next article explores domain embeddings with GPT3.