🎓 How to Train a BERT Model for Anime Genre Classification (with an RTX 4090 from nicegpu.com)

In this blog post, we’ll walk through training a BERT-based model to predict anime genres based on synopses and metadata. This is a multi-label classification task, meaning each anime can belong to multiple genres.

We'll use the HuggingFace Transformers library and leverage GPU acceleration from nicegpu.com, specifically the RTX 4090, to significantly cut down training time.


🧾 Step 1: Load and Inspect the Data

We start with the MyAnimeList dataset on Kaggle, which includes information like anime names, synopses, producers, type, and genres.

Download it from Kaggle and load the CSV file:

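```python
import pandas as pd

# Load the Kaggle CSV (adjust the path to wherever you saved the file)
pre_merged_anime = pd.read_csv('anime-filtered.csv')
print(pre_merged_anime.shape)  # (rows, columns)
print(pre_merged_anime[['Name', 'Genres']].head())
```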
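Since every title can carry several genres, it helps to check how many labels a typical anime has and how unevenly the genres are distributed before training. A minimal sketch using only the standard library; genre_stats is a hypothetical helper, and the three genre strings are made-up sample data in the same comma-separated format as the Genres column (on the real data you would pass pre_merged_anime['Genres'] instead).

```python
from collections import Counter

def genre_stats(genre_strings):
    # Each entry looks like "Action, Comedy, Sci-Fi" (comma + space separated)
    per_title = [s.split(', ') for s in genre_strings]
    # How often each genre appears across all titles
    counts = Counter(g for genres in per_title for g in genres)
    # Average number of genres per title (> 1 confirms the multi-label setting)
    avg_labels = sum(len(genres) for genres in per_title) / len(per_title)
    return counts, avg_labels

# Hypothetical sample rows, not taken from the dataset
genre_strings = ["Action, Comedy", "Comedy, Romance, School", "Comedy"]
counts, avg_labels = genre_stats(genre_strings)
print(counts.most_common(1))  # [('Comedy', 3)]
print(avg_labels)             # 2.0
```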

🧼 Step 2: Data Cleaning and Text Generation

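We build one text description per anime from its name, type, synopsis, producers, studio, source, and premiere date, then clean it by dropping characters outside Python's string.printable set and collapsing repeated whitespace. (The dataset spells the synopsis column sypnopsis, so the code uses that name.)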

```python
import re
import string

def clean_txt(text):
    # Drop characters outside string.printable (this includes all non-ASCII),
    # then collapse runs of two or more whitespace characters into one space
    text = ''.join(filter(lambda x: x in string.printable, text))
    return re.sub(r'\s{2,}', ' ', text).strip()

def get_anime_description(row):
    type_str = "TV Show" if row["Type"] == "TV" else row["Type"]
    # Each fragment ends with a space so the sentences don't run together
    description = (
        f"{row['Name']} is {type_str}. "
        f"Synopsis: {row['sypnopsis']} "
        f"Produced by: {row['Producers']} from {row['Studios']} Studio. "
        f"Source: {row['Source']}. "
        f"Premiered in: {row['Premiered']}."
    )
    return clean_txt(description)

pre_merged_anime['generated_description'] = pre_merged_anime.apply(get_anime_description, axis=1)
```
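One side effect of the string.printable filter in clean_txt is that accented or typographic characters are removed outright, so "é" disappears instead of becoming "e". A quick check of that behavior, plus an optional variant (clean_txt_nfkd is a hypothetical name, not part of the original pipeline) that applies Unicode NFKD normalization first so accented letters keep their base letter:

```python
import re
import string
import unicodedata

def clean_txt(text):
    # Same logic as above: keep string.printable characters, collapse whitespace runs
    text = ''.join(filter(lambda x: x in string.printable, text))
    return re.sub(r'\s{2,}', ' ', text).strip()

def clean_txt_nfkd(text):
    # Hypothetical variant: NFKD splits "é" into "e" + a combining accent,
    # so the printable filter then drops only the accent, not the whole letter
    text = unicodedata.normalize('NFKD', text)
    return clean_txt(text)

sample = "Pokémon   Café\n\tstory"
print(clean_txt(sample))       # Pokmon Caf story
print(clean_txt_nfkd(sample))  # Pokemon Cafe story
```

Whether the variant is worth it depends on how many titles and synopses contain non-ASCII text; the bert-base-uncased tokenizer strips accents on its own, so the main effect is not losing letters before tokenization.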

🏷️ Step 3: Genre Extraction and Encoding

We extract all unique genres and encode them into IDs for multi-label classification.

```python
# Split each comma-separated genre string and collect the distinct genres
unique_labels = sorted({
    genre
    for genres in pre_merged_anime['Genres']
    for genre in genres.split(', ')
})

# Mappings between label IDs and genre names (stored in the model config later)
id2label = {idx: label for idx, label in enumerate(unique_labels)}
label2id = {label: idx for idx, label in enumerate(unique_labels)}
```
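The call to sorted() is what makes the ID assignment reproducible: the iteration order of a plain set of strings can change between Python runs because of hash randomization, and id2label ends up in the saved model's config, where it is used to interpret the outputs. A small sketch with a hypothetical vocabulary (build_label_maps is an illustrative helper, not from the original code) that checks the two maps are inverses:

```python
def build_label_maps(genre_strings):
    # Sorted distinct genres -> contiguous IDs; sorting makes the order reproducible
    unique_labels = sorted({g for s in genre_strings for g in s.split(', ')})
    id2label = {idx: label for idx, label in enumerate(unique_labels)}
    label2id = {label: idx for idx, label in enumerate(unique_labels)}
    return unique_labels, id2label, label2id

# Hypothetical genre strings
unique_labels, id2label, label2id = build_label_maps(["Drama, Action", "Comedy, Drama"])
print(unique_labels)      # ['Action', 'Comedy', 'Drama']
print(label2id['Drama'])  # 2

# The two maps are inverses of each other
assert all(id2label[label2id[name]] == name for name in unique_labels)
```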

🧠 Step 4: Tokenization and Data Preparation

We use BERT's tokenizer to process the generated descriptions and prepare the labels as multi-hot vectors: one 0/1 entry per genre, with a 1 for every genre the anime belongs to.

```python
from transformers import AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Used when moving the model to the GPU in Step 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def process_data(examples, text_col):
    # With batched=True, `examples` maps each column name to a list of values
    text = examples[text_col]
    labels = []
    for genres in examples['Genres']:
        g = genres.split(', ')
        # Multi-hot vector: 1.0 for every genre the anime has, 0.0 otherwise.
        # Plain float lists are stored by `datasets`; the Trainer moves batches to the GPU.
        labels.append([1.0 if label in g else 0.0 for label in unique_labels])

    encoding = tokenizer(text, truncation=True, max_length=256, padding='max_length')
    encoding["labels"] = labels
    return encoding
```
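With batched=True, datasets passes process_data a dictionary of columns, each holding a list of up to batch_size values, so the label loop runs once per anime in the batch. The label half of the function can be checked without downloading a tokenizer. In this sketch, multi_hot_labels is an illustrative helper mirroring the label loop in process_data (returning plain lists of floats), and the two-row batch and four-genre vocabulary are hypothetical:

```python
def multi_hot_labels(batch, unique_labels):
    # batch['Genres'] is a list of strings, one per anime in the batch
    labels = []
    for genres in batch['Genres']:
        g = genres.split(', ')
        # Floats, because the multi-label loss expects float targets
        labels.append([1.0 if label in g else 0.0 for label in unique_labels])
    return labels

unique_labels = ['Action', 'Comedy', 'Drama', 'Romance']  # hypothetical vocabulary
batch = {
    'generated_description': ["Show A is TV Show. ...", "Show B is Movie. ..."],
    'Genres': ["Action, Drama", "Comedy, Romance"],
}
print(multi_hot_labels(batch, unique_labels))
# [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
```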

We convert the pandas DataFrame to a HuggingFace Dataset and apply the transformation:

```python
from datasets import Dataset

# preserve_index=False keeps a stray "__index_level_0__" column out of the dataset
dataset = Dataset.from_pandas(
    pre_merged_anime[['sypnopsis', 'Genres', 'generated_description']],
    preserve_index=False,
)
dataset = dataset.train_test_split(test_size=0.2, seed=42)  # 80% train / 20% test

encoded_dataset = dataset.map(
    lambda x: process_data(x, 'generated_description'),
    batched=True,
    batch_size=128,
    remove_columns=['sypnopsis', 'Genres', 'generated_description']
)
```

🏋️ Step 5: Model Training Setup

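We use the BERT-base model from HuggingFace and configure it for multi-label classification. Setting problem_type='multi_label_classification' makes the model treat each genre as an independent yes/no decision (a sigmoid per label with binary cross-entropy loss) instead of a softmax that picks a single class.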

```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    problem_type='multi_label_classification',
    num_labels=len(unique_labels),
    id2label=id2label,
    label2id=label2id
).to(device)
```
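For BERT's sequence-classification head in Transformers, the multi-label setting uses PyTorch's BCEWithLogitsLoss: each logit goes through a sigmoid and is compared with its 0/1 target using binary cross-entropy, averaged over all genres and examples. A plain-Python sketch of that computation (sigmoid and bce_with_logits are illustrative helpers; the logits are hypothetical) shows why several genres can be "on" at once: each term only looks at its own genre.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logits, targets):
    # Mean over every (example, genre) pair, like BCEWithLogitsLoss's default
    # 'mean' reduction. This naive form is fine for small logits but is not the
    # numerically stable formulation PyTorch uses internally.
    total, n = 0.0, 0
    for row_logits, row_targets in zip(logits, targets):
        for z, y in zip(row_logits, row_targets):
            p = sigmoid(z)
            total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
            n += 1
    return total / n

# One example, two genres, both logits at 0 -> p = 0.5 for each genre
print(round(bce_with_logits([[0.0, 0.0]], [[1.0, 0.0]]), 4))  # 0.6931 (ln 2)
```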

⚙️ Step 6: Define Evaluation Metrics

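We define a function that turns the model's logits into per-genre probabilities with a sigmoid, thresholds them at 0.5, and computes micro-averaged F1, ROC AUC, subset accuracy, and micro-averaged Jaccard.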

```python
import torch
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, jaccard_score
from transformers import EvalPrediction

def multi_label_metrics(predictions, labels, threshold=0.5):
    # Logits -> independent per-genre probabilities -> 0/1 predictions
    probs = torch.sigmoid(torch.tensor(predictions))
    y_pred = (probs >= threshold).int().numpy()
    y_true = labels
    return {
        'f1': f1_score(y_true, y_pred, average='micro'),
        # Computed from thresholded 0/1 predictions; passing probs.numpy()
        # instead gives the usual threshold-free ROC AUC
        'roc_auc': roc_auc_score(y_true, y_pred, average='micro'),
        'accuracy': accuracy_score(y_true, y_pred),  # exact-match (subset) accuracy
        'jaccard': jaccard_score(y_true, y_pred, average='micro')
    }

def compute_metrics(p: EvalPrediction):
    # Some models return a tuple; the logits come first
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    return multi_label_metrics(preds, p.label_ids)
```
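The sigmoid is monotonic and sigmoid(0) = 0.5, so the default threshold of 0.5 is equivalent to predicting every genre whose logit is non-negative; raising the threshold keeps fewer genres. A standard-library sketch on hypothetical logits that also maps the prediction back to genre names through an id2label dictionary (predict_genres and the three-genre mapping are illustrative assumptions, not part of the original code):

```python
import math

def predict_genres(logits, id2label, threshold=0.5):
    # Apply the sigmoid per genre and keep every genre at or above the threshold
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [id2label[i] for i, p in enumerate(probs) if p >= threshold]

id2label = {0: 'Action', 1: 'Comedy', 2: 'Drama'}  # hypothetical mapping
logits = [2.1, -0.3, 0.0]                           # hypothetical model output

print(predict_genres(logits, id2label))       # ['Action', 'Drama']
print(predict_genres(logits, id2label, 0.8))  # ['Action']
```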

🚀 Step 7: Training Configuration

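We use HuggingFace's Trainer. A per-device batch size of 4 with 16 gradient-accumulation steps gives an effective batch size of 64 while keeping GPU memory use low. The model is evaluated and checkpointed every epoch, and the best checkpoint (lowest validation loss, the default criterion) is reloaded at the end.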

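```python
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir='genre-prediction-bert',
    eval_strategy='epoch',           # named evaluation_strategy in transformers < 4.41
    save_strategy='epoch',
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,  # effective batch size: 4 x 16 = 64
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    logging_steps=50,
    load_best_model_at_end=True,     # "best" = lowest eval loss unless metric_for_best_model is set
    remove_unused_columns=False,
    # On an RTX 4090, adding fp16=True (or bf16=True) enables mixed-precision training
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    processing_class=tokenizer,      # use tokenizer=tokenizer in transformers < 4.46
    compute_metrics=compute_metrics
)
```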
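With gradient accumulation, the optimizer updates the weights only after 16 forward/backward passes of 4 examples each, so every update sees 4 × 16 = 64 examples while the GPU only has to hold a batch of 4 at a time. A rough sketch of the resulting number of optimizer steps; optimizer_steps is a hypothetical helper, n_train is a hypothetical training-set size, and the formula only approximates how Trainer counts steps (exact counting can differ slightly between transformers versions):

```python
import math

def optimizer_steps(n_train, per_device_batch=4, grad_accum=16, epochs=3, n_devices=1):
    # Examples contributing to one optimizer update
    effective_batch = per_device_batch * grad_accum * n_devices
    # Forward/backward micro-batches per epoch, then full updates per epoch
    micro_batches = math.ceil(n_train / (per_device_batch * n_devices))
    steps_per_epoch = max(micro_batches // grad_accum, 1)
    return effective_batch, steps_per_epoch * epochs

# Hypothetical 10,000 training examples: 2500 micro-batches // 16 = 156 updates per epoch
print(optimizer_steps(10_000))  # (64, 468)
```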

🏁 Step 8: Train the Model

Now it's time to train!

```python
trainer.train()
trainer.save_model()  # writes the (best) model and tokenizer to output_dir
```

📊 Final Results & Metric Explanations

After training, here's a sample result:

| Metric   | Score |
|----------|-------|
| F1 Score | 0.65  |
| ROC AUC  | 0.79  |
| Accuracy | 0.25  |
| Jaccard  | 0.49  |

The model demonstrates solid performance with an F1 score of 0.655, ROC AUC of 0.796, and a Jaccard index of 0.487, showing good ability to predict multiple genres per anime. Training loss decreased consistently, while validation loss began increasing slightly after epoch 5–6, indicating mild overfitting in later epochs.

📐 What Do These Metrics Mean?

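  • F1 Score (Micro-Average)
    Pools true positives, false positives, and false negatives across all genres, then takes the harmonic mean of the resulting precision and recall. Frequent genres therefore carry more weight than rare ones.

  • ROC AUC (Micro)
    Evaluates how well the model separates relevant from irrelevant genre labels; 0.5 is chance level and 1.0 is perfect. The 0.796 reported above means the model is good at telling relevant from irrelevant labels. Because the metric code computes it from thresholded 0/1 predictions rather than probabilities, it reflects a single operating point rather than the model's full ranking of genres.

  • Accuracy (Subset Accuracy)
    The strictest metric: an anime counts as correct only if its predicted genre set exactly matches the true set, with no missing and no extra genres. Useful, but harsh in multi-label settings, which explains the low 0.25.

  • Jaccard Index
    Measures how similar the predicted genres are to the actual genres as intersection over union of the predicted and true label sets. With micro averaging, intersections and unions are summed over all anime before dividing.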
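To make the three set-based metrics concrete, here is a hand computation on two hypothetical anime in plain Python, so each count is visible. It follows the same definitions scikit-learn uses for f1_score(average='micro'), jaccard_score(average='micro'), and accuracy_score on multi-label 0/1 arrays; the helper names are illustrative.

```python
def multilabel_counts(y_true, y_pred):
    # Pool true positives, false positives, false negatives over all anime and genres
    tp = fp = fn = 0
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            tp += int(t == 1 and p == 1)
            fp += int(t == 0 and p == 1)
            fn += int(t == 1 and p == 0)
    return tp, fp, fn

def micro_f1(y_true, y_pred):
    tp, fp, fn = multilabel_counts(y_true, y_pred)
    return 2 * tp / (2 * tp + fp + fn)  # harmonic mean of pooled precision and recall

def micro_jaccard(y_true, y_pred):
    tp, fp, fn = multilabel_counts(y_true, y_pred)
    return tp / (tp + fp + fn)  # pooled intersection over pooled union

def subset_accuracy(y_true, y_pred):
    # An anime counts only if its whole predicted genre set is exactly right
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

y_true = [[1, 1, 0, 0], [0, 1, 1, 0]]  # hypothetical ground truth
y_pred = [[1, 0, 0, 0], [0, 1, 1, 0]]  # first anime misses one genre
print(round(micro_f1(y_true, y_pred), 3))  # 0.857
print(micro_jaccard(y_true, y_pred))       # 0.75
print(subset_accuracy(y_true, y_pred))     # 0.5
```

The first anime loses all of its exact-match credit for one missed genre, which is why subset accuracy usually sits well below F1 and Jaccard in multi-label settings.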


⚡ Performance Tip: Use a GPU from nicegpu.com

Training this BERT model on a CPU can be painfully slow (several hours). But using a GPU like the RTX 4090, rented from nicegpu.com, makes a huge difference:

  • 🚀 5× to 10× faster training speed
  • 🧪 Support for mixed precision (FP16)
  • 📦 Ability to handle larger batch sizes
  • 🏋️ Great for fine-tuning large models

With an RTX 4090, this model trained in under 23 minutes, a massive upgrade over CPU!

For comparison, I tried the same training on my M1 and it took forever: around 28 hours, roughly 73 times longer :)
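The 73× figure follows directly from the two timings:

$$\frac{28\ \text{h} \times 60\ \text{min/h}}{23\ \text{min}} = \frac{1680\ \text{min}}{23\ \text{min}} \approx 73$$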


🧠 Conclusion

Training BERT for anime genre classification involves a combination of:

  • Cleaning and formatting the data
  • Multi-hot encoding the genres
  • Tokenizing input descriptions
  • Running multi-label classification using HuggingFace's Trainer

By using a powerful GPU like the RTX 4090 from nicegpu.com, you can turn hours of work into just minutes.


💡 Ready to accelerate your NLP tasks? Try out nicegpu.com and run large models like BERT with ease.