BERT Traditional Chinese Classifier (Mainland vs. Taiwan)

A BERT-based classifier that distinguishes Mainland-style Traditional Chinese from Taiwan-style Traditional Chinese usage.

Model overview

  • Base model: ckiplab/bert-base-chinese
  • Task: Traditional Chinese text classification (Mainland vs Taiwan)
  • Reported accuracy: 87.71% (validation)
  • Training samples: ~156,824

Highlights

  • Handles long texts via sliding window (max length 384 tokens)
  • Optional Focal Loss for class imbalance
  • Optional multi-sample dropout and progressive unfreezing (multi-sample dropout is sketched after this list)
  • Layer-wise learning rate decay and AdamW optimizer
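
As a rough illustration of the multi-sample dropout trick, the hypothetical head below averages logits over several dropout masks applied to the pooled BERT output. The class name, num_samples, and dropout rate are assumptions for illustration, not the exact training code.

import torch
import torch.nn as nn

class MultiSampleClassifier(nn.Module):
    """Hypothetical head: average logits over several dropout masks."""
    def __init__(self, hidden_size=768, num_labels=2, num_samples=4, p=0.3):
        super().__init__()
        self.dropouts = nn.ModuleList([nn.Dropout(p) for _ in range(num_samples)])
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled):  # pooled: (batch, hidden_size)
        # Each dropout mask yields a slightly different view of the features;
        # averaging the resulting logits regularizes the head.
        return torch.stack([self.fc(d(pooled)) for d in self.dropouts]).mean(0)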

Optimization and generalization

  • Layer-wise LR decay (≈0.95), higher LR for the classification head (see the sketch after this list)
  • AdamW (weight_decay=0.01, betas≈(0.9, 0.98))
  • Optional OneCycleLR for fast warmup and stable convergence
  • Early stopping and gradient accumulation
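
A minimal sketch of the layer-wise LR decay setup, assuming the standard Hugging Face BertForSequenceClassification module names (bert.encoder.layer.{i}, classifier) and the model loaded in the Usage section below; the exact grouping used in training may differ.

from torch.optim import AdamW

def layerwise_param_groups(model, base_lr=2e-5, head_lr=4e-5, decay=0.95):
    # One parameter group per tensor: earlier layers get geometrically smaller LRs,
    # while the classification head gets a higher LR.
    num_layers = model.config.num_hidden_layers  # 12 for bert-base
    groups = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith("classifier") or "pooler" in name:
            lr = head_lr
        elif "encoder.layer." in name:
            layer_id = int(name.split("encoder.layer.")[1].split(".")[0])
            lr = base_lr * decay ** (num_layers - 1 - layer_id)
        else:  # embeddings sit below the first encoder layer
            lr = base_lr * decay ** num_layers
        groups.append({"params": [param], "lr": lr})
    return groups

optimizer = AdamW(layerwise_param_groups(model), weight_decay=0.01, betas=(0.9, 0.98))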

Usage

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# --- Basic config ---
REPO_ID = "renhehuang/bert-traditional-chinese-classifier"
LABELS = {0: "Mainland Traditional", 1: "Taiwan Traditional"}
MAX_LEN, STRIDE = 384, 128

# --- Device ---
device = (
    "mps" if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

# --- Load model & tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, cache_dir=".cache")
model = AutoModelForSequenceClassification.from_pretrained(REPO_ID, cache_dir=".cache")
model.to(device).eval()

# --- Long-text chunking ---
def chunk_encode(text, max_len=MAX_LEN, stride=STRIDE):
    # Tokenize without special tokens first to measure the true length.
    ids = tokenizer(text, add_special_tokens=False, return_attention_mask=False)["input_ids"]
    if len(ids) <= max_len - 2:  # leave room for [CLS] and [SEP]
        enc = tokenizer(text, truncation=True, max_length=max_len,
                        return_attention_mask=True, return_tensors="pt")
        return [enc]
    # Sliding window: overlapping chunks of max_len with `stride` tokens of overlap.
    enc = tokenizer(text, truncation=True, max_length=max_len, stride=stride,
                    return_overflowing_tokens=True, return_attention_mask=True,
                    return_tensors="pt")
    return [{"input_ids": enc["input_ids"][i:i+1],
             "attention_mask": enc["attention_mask"][i:i+1]}
            for i in range(len(enc["input_ids"]))]

# --- Single-text inference ---
@torch.inference_mode()
def predict(text: str):
    chunks = chunk_encode(text)
    probs_all = []
    for ch in chunks:
        logits = model(
            input_ids=ch["input_ids"].to(device),
            attention_mask=ch["attention_mask"].to(device)
        ).logits
        probs_all.append(F.softmax(logits, dim=-1).cpu())
    avg = torch.cat(probs_all, 0).mean(0)
    label_id = int(avg.argmax())
    return {
        "text_preview": (text[:100] + "...") if len(text) > 100 else text,
        "predicted_id": label_id,
        "predicted_name": LABELS[label_id],
        "confidence": float(avg[label_id]),
        "probabilities": {LABELS[0]: float(avg[0]), LABELS[1]: float(avg[1])},
        "num_chunks": len(chunks),
        "device": device,
    }

# --- Quick test ---
if __name__ == "__main__":
    tests = [
        "這個軟件的界面設計得很好。",    # "This software's interface is well designed." (Mainland terms: 軟件/界面)
        "這個軟體的介面設計得很好。",    # Same sentence with Taiwan terms: 軟體/介面
        "我需要下載這個程序到計算機上。",  # "I need to download this program to the computer." (Mainland terms: 程序/計算機)
        "我需要下載這個程式到電腦上。",   # Same sentence with Taiwan terms: 程式/電腦
    ]
    for t in tests:
        r = predict(t)
        print(f"{r['predicted_name']} | conf={r['confidence']:.2%} | {r['text_preview']}")

Long-text and robust inference (optional MC Dropout voting)

  • Long texts are chunked with a sliding window (MAX_LEN=384, STRIDE=128) and the per-chunk probabilities are averaged.
  • For uncertainty estimation, run multiple stochastic passes with dropout enabled and vote on the labels; confidence is the mean probability of the voted class.

from collections import Counter

@torch.inference_mode()
def predict_runs(text: str, n_runs: int = 3, enable_dropout: bool = True):
    # Pre-chunk
    chunks = chunk_encode(text)

    prev_training = model.training
    run_prob_list = []
    try:
        if enable_dropout:
            model.train()  # train mode keeps dropout active (MC Dropout)
        else:
            model.eval()
        for _ in range(n_runs):
            probs_all = []
            for ch in chunks:
                logits = model(
                    input_ids=ch["input_ids"].to(device),
                    attention_mask=ch["attention_mask"].to(device)
                ).logits
                probs_all.append(F.softmax(logits, dim=-1).cpu())
            run_prob_list.append(torch.cat(probs_all, 0).mean(0))
    finally:
        model.train(prev_training)  # restore the previous train/eval mode

    probs_stack = torch.stack(run_prob_list, 0)
    per_run_ids = probs_stack.argmax(-1).tolist()
    vote_counts = Counter(per_run_ids)
    mean_probs = probs_stack.mean(0)

    # Majority vote + mean probability as a tie-breaker
    voted_id = max(vote_counts.items(), key=lambda kv: (kv[1], mean_probs[kv[0]].item()))[0]
    return LABELS[voted_id], float(mean_probs[voted_id]), dict(vote_counts)
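
A quick usage sketch for predict_runs (the printed values are illustrative, not actual model output):

label, conf, votes = predict_runs("這個軟體的介面設計得很好。", n_runs=5)
print(f"{label} | conf={conf:.2%} | votes={votes}")
# e.g. Taiwan Traditional | conf=97.30% | votes={1: 5}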

Labels

  • 0 → Mainland Traditional
  • 1 → Taiwan Traditional

Data and training (summary)

  • Aggregated Traditional Chinese corpora from multiple sources, spanning a range of text lengths; balanced and quality-controlled.
  • Tokenization: BERT WordPiece; long-text chunking with a 384/128 sliding window.
  • Loss: Focal Loss (gamma=2.0); optional light label smoothing (~0.05). See the sketch after this list.
  • Optimization: AdamW + layer-wise LR decay (~0.95); optional OneCycleLR (see the scheduler sketch after the training configuration).
  • Regularization: Multi-sample dropout, progressive unfreezing, early stopping, gradient accumulation.
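
For reference, a minimal Focal Loss (gamma=2.0) in PyTorch; this is the standard formulation, not necessarily the exact training implementation, and label smoothing is omitted here.

import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    # Down-weight easy examples by (1 - p_t)^gamma so hard examples dominate the loss.
    log_probs = F.log_softmax(logits, dim=-1)
    pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).exp()  # prob of the true class
    ce = F.nll_loss(log_probs, targets, reduction="none")            # per-example cross-entropy
    return ((1 - pt) ** gamma * ce).mean()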

Training configuration

  • Batch Size: 16
  • Learning Rate: 2e-05 (base), 4e-05 (head)
  • Epochs: 4
  • Max Length: 384
  • Loss Function: Focal Loss (gamma=2.0)
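
To reproduce a similar schedule, a OneCycleLR setup under these settings might look like the sketch below; steps_per_epoch is derived from the reported ~156,824 samples and batch size 16, an assumption rather than a logged value.

from torch.optim.lr_scheduler import OneCycleLR

EPOCHS, STEPS_PER_EPOCH = 4, 9802  # ~156,824 samples / batch size 16 (illustrative)

scheduler = OneCycleLR(
    optimizer,                                          # e.g. the AdamW optimizer sketched above
    max_lr=[g["lr"] for g in optimizer.param_groups],   # respect per-group peak LRs
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
)
# Call scheduler.step() once after every optimizer.step().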

Evaluation metrics

  • Overall: Accuracy / F1 ≈ 0.8771 (validation)
  • Length-wise: Stable on very long texts after chunk averaging
  • Typical confusions: mixed regional vocabulary and domain-specific jargon

For full learning curves and diagnostic plots, see repository outputs.

Intended use and limitations

Intended uses: identifying the regional style (Mainland vs. Taiwan) of Traditional Chinese text, data cleaning, annotation assistance, pre-normalization, and hybrid use alongside rules or other models.

Limitations:

  • Mixed regional usage, translated text, and heavy emoji, code, or foreign-language content reduce confidence.
  • Domain-specific jargon may bias results toward one region.
  • Very short or heavily colloquial snippets are harder to classify reliably.

Fairness, safety, and risk

  • The model reflects the training data distribution; biases may exist for topics/domains.
  • Do not use as a single source of truth; combine with human review or model ensembles.
  • Follow local laws and platform policies.

Deployment tips

  • For critical paths, consider multiple stochastic runs (3–10) with a confidence threshold (e.g., ≥ 0.85); see the sketch after this list.
  • Route low-confidence samples to human review.
  • Monitor domain shift and periodically fine-tune with new feedback data.
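
A small routing sketch that combines multiple runs with a confidence threshold, building on the predict_runs helper above; the threshold and run count are the example values from this list, not tuned recommendations.

CONF_THRESHOLD = 0.85

def classify_or_escalate(text: str):
    label, conf, votes = predict_runs(text, n_runs=5)
    if conf >= CONF_THRESHOLD:
        return {"label": label, "confidence": conf, "route": "auto"}
    # Below the threshold: defer to human review.
    return {"label": label, "confidence": conf, "route": "human_review", "votes": votes}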

Citation

If you use this model, please cite:

@misc{bert-traditional-chinese-classifier,
  author = {renhehuang},
  title = {BERT Traditional Chinese Classifier},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/renhehuang/bert-traditional-chinese-classifier}}
}

License

Apache-2.0

Contact

Please open an issue on the Hugging Face model page or GitHub repository.
