Minimal examples of HuggingFace LLM training

I’m sharing a minimal example of training an LLM using HuggingFace’s libraries (trl/transformers/evaluate/datasets, etc.). The example is mainly borrowed from https://wandb.ai/capecape/alpaca_ft/reports/How-to-Fine-tune-an-LLM-Part-3-The-HuggingFace-Trainer--Vmlldzo1OTEyNjMy and its GitHub repo https://github.com/tcapelle/llm_recipes/blob/main/scripts/train_hf.py.

Here is the full file:

import wandb
from datasets import load_dataset

# if you can't find libwebp library, use brew update && brew install webp
import evaluate
import numpy as np
import torch
from transformers import TrainingArguments
from trl import SFTTrainer
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, AutoConfig,
    LlamaConfig, LlamaModel,
)
from transformers import GenerationConfig
from transformers.integrations import WandbCallback
from tqdm import tqdm


def token_accuracy(eval_preds):
    token_acc_module = evaluate.load("accuracy")
    logits, labels = eval_preds
    # shape: batch_size x max_sequence_length
    predictions = np.argmax(logits, axis=-1)
    # accuracy only accepts 1d array. So if the batch contains > 1 datapoints,
    # the accuracy is based on flattened arrays
    # https://huggingface.co/spaces/evaluate-metric/accuracy
    return token_acc_module.compute(
        predictions=predictions.flatten().astype(np.int32),
        references=labels.flatten().astype(np.int32),
    )


def prompt_no_input(row):
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:\n{output}"
    ).format_map(row)


def prompt_input(row):
    return (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:\n{output}"
    ).format_map(row)


def create_alpaca_prompt(row):
    return prompt_no_input(row) if row["input"] == "" else prompt_input(row)


class LLMSampleCB(WandbCallback):
    def __init__(
            self, trainer, test_dataset, num_samples=10, max_new_tokens=256,
            log_model="checkpoint"
    ):
        super().__init__()
        self._log_model = log_model

        def create_prompt_no_answer(row):
            row["output"] = ""
            return {"text": create_alpaca_prompt(row)}

        self.sample_dataset = test_dataset.select(range(num_samples)).map(
            create_prompt_no_answer
        )
        self.model, self.tokenizer = trainer.model, trainer.tokenizer
        self.gen_config = GenerationConfig.from_pretrained(
            trainer.model.name_or_path,
            max_new_tokens=max_new_tokens
        )

    def generate(self, prompt):
        tokenized_prompt = self.tokenizer(prompt, return_tensors='pt')[
            'input_ids']
        with torch.inference_mode():
            output = self.model.generate(
                inputs=tokenized_prompt, generation_config=self.gen_config
            )
        return self.tokenizer.decode(
            output[0][len(tokenized_prompt[0]):], skip_special_tokens=True
        )

    def samples_table(self, examples):
        records_table = wandb.Table(
            columns=["prompt", "generation"] + list(
                self.gen_config.to_dict().keys()
            )
        )
        for example in tqdm(examples):
            prompt = example["text"]
            generation = self.generate(prompt=prompt)
            records_table.add_data(
                prompt, generation, *list(self.gen_config.to_dict().values())
            )
        return records_table

    def on_evaluate(self, args, state, control, **kwargs):
        super().on_evaluate(args, state, control, **kwargs)
        records_table = self.samples_table(self.sample_dataset)
        self._wandb.log({"sample_predictions": records_table})
        print("log once")


def param_count(m):
    params = sum([p.numel() for p in m.parameters()]) / 1_000_000
    trainable_params = sum(
        [p.numel() for p in m.parameters() if p.requires_grad]
    ) / 1_000_000
    print(f"Total params: {params:.2f}M, Trainable: {trainable_params:.2f}M")
    return params, trainable_params


def trl_train():
    wandb.login(key='replace_with_your_own')

    lr = 2e-5
    batch_size = 8
    max_steps = 4
    # evaluate every eval_steps. so if we set max_steps = 4 and
    # eval_steps = 2, we will evaluate twice during training
    eval_steps = 2
    num_eval_data = 5
    num_wandb_cb_eval_data = 7
    wandb_cb_max_new_tokens = 256
    num_train_epochs = 1
    max_seq_length = 1024
    gradient_accumulation_steps = 1
    gradient_checkpointing = False
    output_dir = "./output/"

    run = wandb.init(
        project="second_project",
        config={
            "lr": lr,
            "batch_size": batch_size,
            "max_steps": max_steps,
            "eval_steps": eval_steps,
            "num_eval_data": num_eval_data,
            "num_wandb_cb_eval_data": num_wandb_cb_eval_data,
        },
    )

    alpaca_ds = load_dataset("winglian/alpaca-gpt4-split")

    train_dataset = alpaca_ds["train"]
    eval_dataset = alpaca_ds["test"]

    model_id = 'meta-llama/Llama-2-7b-hf'
    # try different ways to initialize a llama model
    # method 1: construct LlamaModel from LlamaConfig
    # https://huggingface.co/docs/transformers/v4.37.2/en/model_doc/llama2#transformers.LlamaConfig
    # configuration = LlamaConfig(
    #     num_hidden_layers=2,
    #     hidden_size=32,
    #     intermediate_size=2,
    #     num_attention_heads=1,
    #     num_key_value_heads=1,
    # )
    # model = LlamaModel(configuration)
    # param_count(model)

    # method 2 & 3 need to wait for token approval
    # https://huggingface.co/meta-llama/Llama-2-7b-hf
    # method 2: load config first, shrink the model size, then initialize the actual LLM
    # https://discuss.huggingface.co/t/can-i-pretrain-llama-from-scratch/37821/8
    config = AutoConfig.from_pretrained(model_id)
    config.num_hidden_layers = 1
    config.hidden_size = 2
    config.intermediate_size = 2
    config.num_attention_heads = 1
    config.num_key_value_heads = 1
    model = AutoModelForCausalLM.from_config(config)
    param_count(model)

    # method 3: directly load pretrained llama model, which may encounter OOM
    # on a consumer cpu machine
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_id,
    #     device_map="auto",
    #     trust_remote_code=True,
    #     low_cpu_mem_usage=True,
    #     torch_dtype=torch.bfloat16,
    #     load_in_8bit=True,
    # )

    training_args = TrainingArguments(
        output_dir=output_dir,
        use_cpu=True,
        report_to="wandb",
        per_device_train_batch_size=batch_size,
        bf16=True,
        learning_rate=lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_steps=max_steps,
        eval_steps=eval_steps,
        num_train_epochs=num_train_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        evaluation_strategy="steps",
        # logging strategies
        logging_strategy="steps",
        logging_steps=1,
        save_strategy="no",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset.select(list(range(num_eval_data))),
        # this tells the trainer to pack sequences of `max_seq_length`
        # see illustration in https://wandb.ai/capecape/alpaca_ft/reports/How-to-Fine-tune-an-LLM-Part-3-The-HuggingFace-Trainer--Vmlldzo1OTEyNjMy
        packing=True,
        max_seq_length=max_seq_length,
        formatting_func=create_alpaca_prompt,
        compute_metrics=token_accuracy,  # only call at evaluation
    )
    wandb_callback = LLMSampleCB(
        trainer, eval_dataset,
        num_samples=num_wandb_cb_eval_data,
        max_new_tokens=wandb_cb_max_new_tokens,
    )
    trainer.add_callback(wandb_callback)
    trainer.train()
    wandb.finish()

    # other materials:
    # fine tune ppo vs dpo
    # trl stackllama tutorial:
    # https://huggingface.co/docs/trl/using_llama_models
    # trl readme: https://github.com/huggingface/trl/tree/main?tab=readme-ov
    # dpo - trl: https://huggingface.co/blog/dpo-trl


if __name__ == '__main__':
    trl_train()

Now let’s examine the code in more detail:

First, we initialize a Weights & Biases run (wandb.init(...)), which is used to log intermediate training/evaluation results. It is a very convenient tool for logging and visualization.
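
As a quick aside, here is a minimal, hypothetical sketch of the wandb API on its own (the project name and metric key are made up for illustration); the Trainer’s report_to="wandb" integration does this kind of logging for us:

import wandb

run = wandb.init(project="second_project", config={"lr": 2e-5})
for step in range(3):
    # each call appends one row to the run's history, keyed by step
    wandb.log({"train/loss": 1.0 / (step + 1)}, step=step)
wandb.finish()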

Then, we use load_dataset(...), an API from HuggingFace’s datasets library, to load a specific dataset. HuggingFace hosts many awesome datasets at https://huggingface.co/datasets.
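
A quick way to sanity-check what we loaded (a sketch, assuming the dataset follows the standard Alpaca instruction/input/output schema):

from datasets import load_dataset

alpaca_ds = load_dataset("winglian/alpaca-gpt4-split")
print(alpaca_ds)                         # shows the train/test splits and row counts
print(alpaca_ds["train"].column_names)   # expect instruction / input / output columns
print(alpaca_ds["train"][0])             # inspect one raw example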

Next, we initialize an actual LLM. Since this is a minimal example, I created a tiny LLM by modifying the config to use very few hidden layers and a tiny hidden size.
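
To see why the shrunken config is so small, here is a back-of-the-envelope sketch (assuming Llama-2’s vocabulary size of 32000; the exact count also includes the tiny attention/MLP/norm weights):

# with hidden_size=2 the parameter count is dominated by the token embedding
# and the LM head, each roughly vocab_size * hidden_size in size
vocab_size = 32000      # Llama-2's vocabulary size (assumption pulled from the config)
hidden_size = 2         # the shrunken hidden size set above
approx_params = 2 * vocab_size * hidden_size   # embedding + LM head
print(f"~{approx_params / 1e6:.2f}M parameters")  # well under 1M, vs ~7B for the real checkpoint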

Next, we initialize TrainingArguments. It helps to be familiar with a few of the concepts it exposes, such as gradient accumulation and gradient checkpointing (both appear in the arguments above).
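
To make gradient accumulation concrete, here is a simplified sketch of what the Trainer effectively does internally (ignoring schedulers, clipping, and mixed precision); model, optimizer, and dataloader are assumed to already exist:

accumulation_steps = 4  # effective batch size = per_device_train_batch_size * 4

optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    loss = model(**batch).loss              # forward pass on a small micro-batch
    (loss / accumulation_steps).backward()  # scale so accumulated gradients match one big-batch update
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                    # one optimizer update per accumulation window
        optimizer.zero_grad()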

We then initialize a tokenizer by simply calling HuggingFace’s AutoTokenizer.from_pretrained(...). Note that we also set the pad token to the EOS token, since the Llama tokenizer does not define a pad token by default.
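
A small sketch of why that line matters (assuming access to the gated meta-llama checkpoint; any Llama-style tokenizer without a pad token behaves the same way):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
print(tokenizer.pad_token)       # None out of the box
tokenizer.pad_token = tokenizer.eos_token

# padded batches can now be built without errors
batch = tokenizer(
    ["short prompt", "a slightly longer prompt"],
    padding=True, return_tensors="pt",
)
print(batch["input_ids"].shape)  # (2, length of the longest sequence)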

We then initialize SFTTrainer, the main class for training and evaluating the LLM. Setting packing=True means that we pack multiple individual sequences into fixed-length sequences so that we avoid excessive padding. Individual sequences are usually separated by an EOS token.
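
Conceptually, packing looks like the toy sketch below (a simplified illustration, not trl’s actual implementation): tokenize every formatted example, join them with EOS tokens, and slice the long stream into fixed-length blocks:

def pack_examples(texts, tokenizer, max_seq_length):
    """Toy packing: concatenate tokenized texts separated by EOS, then slice
    the stream into fixed-length blocks (the remainder is dropped)."""
    stream = []
    for text in texts:
        stream.extend(tokenizer(text)["input_ids"])
        stream.append(tokenizer.eos_token_id)
    return [
        stream[i:i + max_seq_length]
        for i in range(0, len(stream) - max_seq_length + 1, max_seq_length)
    ]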

We also initialize a callback (LLMSampleCB), which runs only at evaluation time. The callback first strips the output field from the evaluation examples so that the model must generate the response from the prompt alone, then logs the prompts and generations as a wandb table.
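
Concretely, for a single row it does something like the following (field values are made up; create_alpaca_prompt is the helper defined earlier in the script):

row = {"instruction": "Name three primary colors.", "input": "", "output": "Red, blue, yellow."}
row["output"] = ""                  # drop the gold answer
prompt = create_alpaca_prompt(row)
print(prompt)                       # ends with "### Response:\n" -- generation starts here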

 

We now look at the results logged in wandb (example https://wandb.ai/czxttkl/third_project/runs/hinck0h5):

  1. Since we specify max_steps=4 and eval_steps=2, we have 2 evaluations. The evaluation loss curve verifies that we indeed log 2 evaluation results.
  2. We have a table showing the results from the callback. We can verify that the prompts indeed have their outputs removed. We can also use runs.history.concat["sample_predictions"] instead of runs.summary["sample_predictions"] to check the evaluation results from all evaluation runs (exactly 2 in this case); see the reference at https://wandb.ai/morg/append-to-table/reports/Append-to-Table--Vmlldzo0MjY0MDIx. A programmatic alternative is sketched below.
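
If you prefer pulling the numbers programmatically rather than through the UI, the wandb public API can fetch the run history; a sketch (replace the run path with your own; the Trainer’s wandb integration typically logs evaluation loss under eval/loss):

import wandb

api = wandb.Api()
run = api.run("czxttkl/third_project/hinck0h5")  # entity/project/run_id from the run URL above
history = run.history(keys=["eval/loss"])        # pandas DataFrame, one row per logged evaluation
print(history)                                   # expect 2 rows given max_steps=4, eval_steps=2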

 

 
