๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
AI

ํ…์ŠคํŠธ ๋ถ„๋ฅ˜ - Hard Example Mining

by kaizen_bh 2025. 10. 28.


ํ…์ŠคํŠธ ๋ถ„๋ฅ˜๋ฅผ ์ง„ํ–‰ํ•˜๋˜ ์ค‘, ์นœ๊ตฌ์—๊ฒŒ ์ข‹์€ ์•„์ด๋””์–ด๋ฅผ ํ•˜๋‚˜ ๋ฐ›์•˜๋‹ค

๋จผ์ € ๊ต์ฐจ๊ฒ€์ฆ์„ ๋Œ๋ฆฌ๋ฉด์„œ ๊ฐ ํด๋“œ๋‹น ๊ฒ€์ฆ์…‹์— ๋Œ€ํ•ด ์˜ˆ์ธกํ•œ ๋ผ๋ฒจ๊ณผ ํ™•๋ฅ ๊ฐ’์„ ์ €์žฅํ•˜๋ฉด ๋ฒ ์ด์Šค๋ผ์ธ์ด ์›๋ณธ ์ „์ฒด ๋ฐ์ดํ„ฐ์…‹์— ๋Œ€ํ•ด ๋ฌธ์žฅ๋งˆ๋‹ค ์–ด๋–ค ๋‹ต์„ ์–ผ๋งˆ์˜ ํ™•๋ฅ ๋กœ ์˜ˆ์ธกํ–ˆ๋Š”์ง€๋ฅผ ์•Œ ์ˆ˜ ์žˆ๊ฒŒ ๋œ๋‹ค

 

๊ทธ๋ฆฌ๊ณ  ๋ชป๋งž์ถ˜ ๋ฌธ์žฅ๋“ค์— ๋Œ€ํ•ด์„œ ์ด ํ™•๋ฅ ๊ฐ’ p๋ฅผ 1-p ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ด์ฃผ๋ฉด ์–ผ๋งˆ๋‚˜ ์ •๋‹ต์„ ํ™•์‹ ํ•˜์ง€ ๋ชปํ–ˆ๋Š”์ง€, ์ฆ‰ ์˜ค๋‹ต์— ๋Œ€ํ•ด ์–ผ๋งˆ๋‚˜ ํ™•์‹ ํ–ˆ๋Š”์ง€๋ฅผ ๋‚˜ํƒ€๋‚ธ๋‹ค

์ด๋ฅผ ํ†ตํ•ด ์–ป์„ ์ˆ˜ ์žˆ๋Š” ๊ธฐ๋Œ€ํšจ๊ณผ๋กœ๋Š” ์ž๊ธฐ ์ •๋‹ต ํด๋ž˜์Šค์— ์• ๋งคํ•œ ํ™•๋ฅ ๋กœ ํŒ๋‹จํ•˜๊ฑฐ๋‚˜ ์•„์˜ˆ ๋‚ฎ์€ ํ™•๋ฅ ๋กœ ์™„๋ฒฝํ•˜๊ฒŒ ์ •๋‹ต์„ ํ‹€๋ฆฌ๋Š” ์ผ€์ด์Šค๋“ค์„ ์ˆ˜์ง‘ํ•˜์—ฌ ์—ฌ๊ธฐ์— ๋Œ€ํ•ด ํ•ธ๋“ค๋งํ•˜๋Š”, ๋ชจ๋ธ์ด ๋ฐ์ดํ„ฐ๋ฅผ ๋ณผ ๋•Œ ์•ฝ์ ์„ ํŒŒ์•…ํ•˜์—ฌ ์ด๋ฅผ ์ง์ ‘์ ์œผ๋กœ ๊ฐœ์„ ํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์ด์ง€ ์•Š์„๊นŒ ๋ผ๋Š” ์ƒ๊ฐ์ด ๋“ค์—ˆ๋‹ค

 

์ด๋Ÿฌํ•œ ๊ฐœ๋…์„ ์ฐพ์•„๋ณด๋‹ˆ๊นŒ Hard Nagative Mining, ๋ชจ๋ธ์ด ์˜ค๋‹ต์„ ๋‚ธ, ์˜ˆ์ธก์„ ์ž˜๋ชปํ•œ ์–ด๋ ค์šด ์ƒ˜ํ”Œ์„ ์ถ”์ถœํ•˜๋Š” ๋ฐฉ๋ฒ•์ด๋ผ ํ•œ๋‹ค

์ฃผ๋กœ ๋น„์ „์ชฝ์— ์‚ฌ์šฉ๋˜๋Š” ๊ฐœ๋…์ด๊ธด ํ•˜์ง€๋งŒ ์ด๋Ÿฌํ•œ ๊ฐœ๋… ์ž์ฒด๋Š” ํ…์ŠคํŠธ, ๋จธ์‹  ๋Ÿฌ๋‹ ๋“ฑ์—์„œ๋„ ์œ ์‚ฌํ•˜๊ฒŒ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ์–ด๋ณด์ธ๋‹ค

 

์‚ฌ์šฉํ•˜๋Š” ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ์…‹์— ๋ถˆ๊ท ํ˜•์ด ์กด์žฌํ•˜๊ธฐ์— ๋‹จ์ˆœํžˆ ํ…์ŠคํŠธ ์ฆ๊ฐ•๋งŒ์„ ์ ์šฉํ•˜๋ ค ํ–ˆ์œผ๋‚˜ ์—ฌ๊ธฐ์— ์ ‘๋ชฉ์‹œ์ผœ ๋ชจ๋ธ์ด ์ž˜ ์˜ˆ์ธกํ•˜์ง€ ๋ชปํ•˜๋Š” ์ทจ์•ฝํ•œ ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด ์ฆ๊ฐ•์„ ํ•ด์ฃผ๋Š” ๊ฒƒ, Targeted Augmentation ์œผ๋กœ ์œ ์˜๋ฏธํ•œ ์„ฑ๋Šฅ ํ–ฅ์ƒ์ด ์žˆ๋Š”์ง€ ์‹คํ—˜ํ•ด๋ณด๋ ค ํ•œ๋‹ค

 

 

https://stydy-sturdy.tistory.com/27

 

[๊ฐ์ฒด ํƒ์ง€] Hard Negative Mining ์ด๋ž€?

Object Detection task์—์„œ bounding box๋ฅผ ๋ฝ‘์œผ๋ฉด ์ˆ˜์ฒœ ๊ฐœ๋ฅผ ๋ฝ‘๊ฒŒ ๋œ๋‹ค. ์ˆ˜์ฒœ ๊ฐœ์˜ Bounding Box ์•ˆ์— ์šฐ๋ฆฌ๊ฐ€ ์ฐพ๊ณ ์ž ํ•˜๋Š” ๋ฌผ์ฒด ํ˜น์€ ๊ฐ์ฒด๊ฐ€ ์žˆ๋Š” ๋ฐ•์Šค๊ฐ€ ํ‰๊ท  ์ˆ˜์‹ญ ๊ฐœ ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜๋ฉด ๋‚˜๋จธ์ง€ ๋ฐ•์Šค๋“ค ์ฆ‰, ๋ฌผ

stydy-sturdy.tistory.com

 

 

https://velog.io/@sa090180/OHEM-Training-Region-based-Object-Detectors-with-Online-Hard-Example-Mining

 

OHEM: Training Region-based Object Detectors with Online Hard Example Mining (velog.io)


1. ํ•ต์‹ฌ ๊ฐœ๋…

  1. ๋ชจ๋ธ์ด ์•ฝํ•œ ๋ฐ์ดํ„ฐ (=hard sample) ๋ฅผ ์ฐพ์•„๋‚ธ๋‹ค.
    → ์ •๋‹ต ํด๋ž˜์Šค ํ™•๋ฅ (p_true)์ด ๋‚ฎ์€ ์ƒ˜ํ”Œ = ๋ชจ๋ธ์ด ๋ถˆํ™•์‹คํ•˜๊ฑฐ๋‚˜ ํ‹€๋ฆฐ ์ƒ˜ํ”Œ
  2. ๊ทธ ์ƒ˜ํ”Œ๋“ค์„ ์ง‘์ค‘ ์ฆ๊ฐ•ํ•œ๋‹ค.
    → Synonym ๊ต์ฒด, ๋ฒˆ์—ญ ํ›„ ๋ณต์›, LLM์„ ์ด์šฉํ•œ ๋ฌธ์žฅ ์žฌ์ž‘์„ฑ ๋“ฑ์œผ๋กœ ๋‹ค์–‘ํ™”
  3. ๋‹ค์‹œ ํ•™์Šต์‹œ์ผœ ์„ฑ๋Šฅ ๋น„๊ต
    → ์ „์ฒด ์ฆ๊ฐ•๋ณด๋‹ค ํšจ์œจ์ ์ด๊ณ , ๋ชจ๋ธ ์•ฝ์ ์„ ๋ณด์™„ํ•จ


2. ๊ตฌํ˜„ ์ ˆ์ฐจ

Step 1. ๊ต์ฐจ ๊ฒ€์ฆ ์˜ˆ์ธก ๊ฒฐ๊ณผ ์–ป๊ธฐ (Out-of-Fold Prediction)

  • ๋ชจ๋ธ ํ•™์Šต ํ›„, ๊ต์ฐจ๊ฒ€์ฆ์„ ํ†ตํ•ด ๊ฐ ์ƒ˜ํ”Œ์˜ ์˜ˆ์ธก ํ™•๋ฅ ์„ ์–ป๋Š”๋‹ค.
  • ์ด๋ฅผ ํ†ตํ•ด “ํ•ด๋‹น ์ƒ˜ํ”Œ์„ ํ•™์Šต์— ์‚ฌ์šฉํ•˜์ง€ ์•Š์€ fold”์—์„œ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ํ™•๋ณด → ๊ณผ์ ํ•ฉ ๋ฐฉ์ง€

์•„๋ž˜๋Š” ํ—ˆ๊น…ํŽ˜์ด์Šค์˜ ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ €๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํด๋“œ๋งˆ๋‹ค ๋ชจ๋ธ์„ ์ƒˆ๋กœ ๋ถˆ๋Ÿฌ์™€์„œ ํ•™์Šต์‹œํ‚ค๊ณ  ๊ฒ€์ฆ์…‹์— ๋Œ€ํ•œ ์˜ˆ์ธก ๋ผ๋ฒจ๊ณผ ํ™•๋ฅ ์„ ์ €์žฅํ•œ๋‹ค

๊ต์ฐจ๊ฒ€์ฆ์—๋Š” ์‹œ๊ฐ„์ด ์˜ค๋ž˜ ๊ฑธ๋ฆฌ๊ธฐ ๋•Œ๋ฌธ์— ๋งค๋ฒˆ ๋Œ๋ฆด ์ˆ˜ ์—†์œผ๋ฏ€๋กœ csv๋กœ ์ €์žฅํ•ด๋‘๋Š” ๊ฒƒ์ด ์ข‹๋‹ค

# ==========================================================
# 0. Imports (df, X, y, test_texts, test_ids, ReviewDataset,
#    and compute_metrics are assumed defined in earlier cells)
# ==========================================================
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)

# ==========================================================
# 3. Model and tokenizer setup
# ==========================================================
MODEL_NAME = "kykim/bert-kor-base"  # one of the models allowed in the competition
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# ==========================================================
# 4. K-Fold ์„ค์ •
# ==========================================================
NUM_FOLDS = 10
RANDOM_STATE = 42
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=RANDOM_STATE)

oof_preds = np.zeros(len(df), dtype=int)
oof_probs = np.zeros((len(df), 4))  # 4 sentiment classes
fold_accuracies = []

# ==========================================================
# 5. Fold๋ณ„ ํ•™์Šต ๋ฐ OOF ์˜ˆ์ธก
# ==========================================================
for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    print(f"\n===== Fold {fold+1} / {NUM_FOLDS} =====")
    X_train, X_val = [X[i] for i in train_idx], [X[i] for i in val_idx]
    y_train, y_val = [y[i] for i in train_idx], [y[i] for i in val_idx]

    train_dataset = ReviewDataset(X_train, y_train, tokenizer)
    val_dataset = ReviewDataset(X_val, y_val, tokenizer)
    test_dataset = ReviewDataset(test_texts, labels=None, tokenizer=tokenizer)

    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4)

    training_args = TrainingArguments(
        output_dir=f"./results/fold_{fold+1}",
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        save_total_limit=1, 
        per_device_train_batch_size=128,
        per_device_eval_batch_size=256,
        num_train_epochs=3,
        learning_rate=5e-5,
        warmup_steps=500,
        weight_decay=0.05,
        seed=RANDOM_STATE,
        fp16=torch.cuda.is_available(),
        logging_strategy="epoch",          # how often to log
        report_to="none",                  # don't report to wandb/tensorboard
    )



    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # ๊ฒ€์ฆ ๋ฐ์ดํ„ฐ ์˜ˆ์ธก
    preds_output = trainer.predict(val_dataset)
    preds = np.argmax(preds_output.predictions, axis=1)
    probs = torch.nn.functional.softmax(torch.tensor(preds_output.predictions), dim=1).numpy()

    # OOF ๊ฒฐ๊ณผ ์ €์žฅ
    oof_preds[val_idx] = preds
    oof_probs[val_idx] = probs

    acc = accuracy_score(y_val, preds)
    fold_accuracies.append(acc)
    print(f"Fold {fold+1} Accuracy: {acc:.4f}")

    # ํ…Œ์ŠคํŠธ์…‹ ์˜ˆ์ธก
    print(f"→ Fold {fold+1} Test inference ์ค‘...")
    test_output = trainer.predict(test_dataset)
    test_probs = torch.nn.functional.softmax(torch.tensor(test_output.predictions), dim=1).numpy()
    test_preds = np.argmax(test_probs, axis=1)
    
    # CSV๋กœ ์ €์žฅ
    test_pred_df = pd.DataFrame({
        "ID": test_ids,
        "pred": test_preds
    })
    test_prob_df = pd.DataFrame(test_probs, columns=[f"class_{i}_prob" for i in range(4)])
    test_prob_df.insert(0, "ID", test_ids)
    
    test_pred_df.to_csv(f"test_fold_{fold+1}_predictions.csv", index=False)
    test_prob_df.to_csv(f"test_fold_{fold+1}_probabilities.csv", index=False)
    print(f"โœ… Fold {fold+1} ํ…Œ์ŠคํŠธ ์˜ˆ์ธก ๊ฒฐ๊ณผ ์ €์žฅ ์™„๋ฃŒ.")

# ==========================================================
# 6. OOF ์„ฑ๋Šฅ ํ™•์ธ
# ==========================================================
# ์›๋ณธ ๋ฐ์ดํ„ฐ์™€ ์˜ˆ์ธก์„ ๋น„๊ตํ•˜์—ฌ ํด๋“œ์˜ ์ „์ฒด์ ์ธ ์„ฑ๋Šฅ์„ ์ฒดํฌํ•จ
oof_acc = accuracy_score(y, oof_preds)
print(f"\nโœ… Overall OOF Accuracy: {oof_acc:.4f}")
print(f"Fold Accuracies: {fold_accuracies}")
โœ… Overall OOF Accuracy: 0.8605
Fold Accuracies: [0.8553166763509588, 0.860945671121441, 0.8615630447414294, 0.8647951772225451, 0.8609819872167345, 0.860691458454387, 0.8604009296920395, 0.8593789722171782, 0.858979480660977, 0.8622480479389868]
โœ… OOF ๊ฒฐ๊ณผ ํฌํ•จํ•œ CSV ์ €์žฅ ์™„๋ฃŒ

์ด 38404๊ฐœ์˜ hard case ํƒ์ง€๋จ

 

  • ์•„๋ž˜ ์›๋ณธ ๋ฐ์ดํ„ฐ์™€ ์˜ˆ์ธก์„ ๋น„๊ตํ•ด์„œ Accuracy๋ฅผ ํ™•์ธํ•˜๋Š” ๋ถ€๋ถ„๋„ ์žˆ๋‹ค. ์ด๋ฒˆ 10ํด๋“œ ์ง„ํ–‰ํ–ˆ์„ ๋•Œ Accuracy๋Š” 0.86 ์ •๋„
  • ์ด๊ฑด ๋ฒ ์ด์Šค๋ผ์ธ์ด ์›๋ณธ ๋ฐ์ดํ„ฐ ์ „์ฒด์— ๋Œ€ํ•ด ์–ด๋А ์ •๋„์˜ ์ •ํ™•๋„๋กœ ์˜ˆ์ธกํ–ˆ๋Š”์ง€ ์ฐธ๊ณ ํ•˜๋Š”๋ฐ ์‚ฌ์šฉํ•˜๋ฉด ์ข‹์„ ๊ฒƒ ๊ฐ™๋‹ค


Step 2. Identifying and Analyzing Hard Samples

 

์ž๊ธฐ ์ •๋‹ต ํด๋ž˜์Šค์˜ ํ™•๋ฅ (p_true)์ด ๋‚ฎ์„์ˆ˜๋ก ๋ชจ๋ธ์ด ๋ถˆํ™•์‹คํ•œ ์ƒ˜ํ”Œ์ด๋‹ค.
์ด๋•Œ 1 - p_true ๊ฐ’์ด ํด์ˆ˜๋ก “์–ด๋ ค์šด ์ƒ˜ํ”Œ”.


# ==========================================================
# 7. Hard Case ํƒ์ƒ‰
# ==========================================================
df["oof_pred"] = oof_preds
df["is_correct"] = (df["label"] == df["oof_pred"])
df["oof_confidence"] = oof_probs.max(axis=1)


df.to_csv("./oof_results_with_confidence.csv", index=False)
print("โœ… OOF ๊ฒฐ๊ณผ ํฌํ•จํ•œ CSV ์ €์žฅ ์™„๋ฃŒ")

# consistently ํ‹€๋ฆฐ ์ƒ˜ํ”Œ๋งŒ ํ•„ํ„ฐ๋ง
hard_cases = df[~df["is_correct"]].sort_values("oof_confidence")

print(f"\n์ด {len(hard_cases)}๊ฐœ์˜ hard case ํƒ์ง€๋จ")
hard_cases[["ID", "review_normalized", "label", "oof_pred", "oof_confidence"]].head(10)
  • Hard Case ํƒ์ƒ‰์˜ ๊ฒฝ์šฐ ํ™•๋ฅ ์ด ๋†’๊ณ  ๋‚ฎ์Œ๊ณผ ์ƒ๊ด€์—†์ด ๋‹ต์„ ํ‹€๋ฆฐ ๊ฒƒ์— ๋Œ€ํ•ด์„œ๋งŒ ์ ์šฉํ•˜์—ฌ ์ „์ฒด ์ค‘ ๋ช‡ ๊ฐœ๋ฅผ ํ‹€๋ ธ๋Š”์ง€๋ฅผ ์ฒดํฌํ•œ๋‹ค

 

p_true = oof_probs[np.arange(len(df)), df["label"].values]
df["p_true"] = p_true
df["hardness"] = 1 - df["p_true"]

 

 

p_true = oof_probs[np.arange(len(df)), df["label"].values]
  • ์ด ํ•œ ์ค„์€ ๊ฐ ์ƒ˜ํ”Œ์— ๋Œ€ํ•ด, ์ •๋‹ต ํด๋ž˜์Šค์— ํ•ด๋‹นํ•˜๋Š” ํ™•๋ฅ ๊ฐ’์„ ๋ฝ‘์•„๋‚ธ ๊ฒƒ
  • oof_probs ์—๋Š” ๊ต์ฐจ ๊ฒ€์ฆ์„ ํ•˜๋ฉด์„œ ๊ฒ€์ฆ์…‹ ๋ผ๋ฒจ๋งˆ๋‹ค์˜ ์˜ˆ์ธก๊ฐ’์ด ๋‹ด๊ฒจ์žˆ์Œ
  • oof_probs[i] → ์ƒ˜ํ”Œ i์— ๋Œ€ํ•œ ๋ชจ๋ธ์˜ softmax ํ™•๋ฅ  ์˜ˆ์ธก ๋ฒกํ„ฐ (์˜ˆ: [0.1, 0.7, 0.1, 0.1])
  • df["label"].values[i] → ์ƒ˜ํ”Œ i์˜ ์‹ค์ œ ์ •๋‹ต ํด๋ž˜์Šค (์˜ˆ: 1)
  • ๋”ฐ๋ผ์„œ oof_probs[i, df["label"][i]] → ๋ชจ๋ธ์ด ์ •๋‹ต ํด๋ž˜์Šค์— ํ• ๋‹นํ•œ ํ™•๋ฅ  (์—ฌ๊ธฐ์„œ๋Š” 0.7)
  • ๋งŒ์•ฝ ๋ชจ๋ธ์ด ์ž˜๋ชป ์˜ˆ์ธกํ–ˆ๋‹ค๋ฉด ์ด ํ™•๋ฅ ์ด ๋งค์šฐ ๋‚ฎ๊ฒŒ ๋‚˜์˜ค๊ธฐ๋„ ํ•จ


p_true

  • ๋ชจ๋ธ์ด ๊ฐ ์ƒ˜ํ”Œ์˜ “์ง„์งœ ์ •๋‹ต ํด๋ž˜์Šค”์— ์–ผ๋งˆ๋‚˜ ํ™•์‹ ์„ ๊ฐ€์กŒ๋Š”์ง€๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š” ๊ฐ’
  • ๊ฐ’์˜ ๋ฒ”์œ„๋Š” 0~1์ด๋ฉฐ, 1์— ๊ฐ€๊นŒ์šธ์ˆ˜๋ก ๋ชจ๋ธ์ด ์ •๋‹ต์„ ๊ฐ•ํ•˜๊ฒŒ ํ™•์‹ ํ–ˆ๋‹ค๋Š” ๋œป

 

hardness 

  • ๋ชจ๋ธ์ด ์ •๋‹ต์— ๋Œ€ํ•ด ์–ผ๋งˆ๋‚˜ “์–ด๋ ค์›Œํ–ˆ๋Š”๊ฐ€”
  • ๋ชจ๋ธ์ด ์ •๋‹ต ํ™•๋ฅ  (p_true) ์„ ๋†’๊ฒŒ ์˜ˆ์ธกํ•˜๋ฉด → p_true ↑ → hardness ↓ (์‰ฌ์šด ์ƒ˜ํ”Œ)
  • ๋ชจ๋ธ์ด ์ •๋‹ต ํ™•๋ฅ  (p_true) ์„ ๋‚ฎ๊ฒŒ ์˜ˆ์ธกํ•˜๋ฉด → p_true ↓ → hardness ↑ (์–ด๋ ค์šด ์ƒ˜ํ”Œ)
    • 0.9 → ๋งค์šฐ ์–ด๋ ค์šด ์ƒ˜ํ”Œ (๊ฑฐ์˜ ํ‹€๋ฆด ๋ป”ํ–ˆ๊ฑฐ๋‚˜ ์‹ค์ œ๋กœ ํ‹€๋ ธ์„ ํ™•๋ฅ ์ด ๋†’์Œ)
    • 0.5 → ์ค‘๊ฐ„ ์ •๋„๋กœ ๋ถˆํ™•์‹ค
    • 0.1 → ๋งค์šฐ ์‰ฌ์šด ์ƒ˜ํ”Œ (๊ฑฐ์˜ ํ™•์‹คํ•˜๊ฒŒ ๋งž์ถค)

 

  • Compute p_true and 1 - p_true, then add them to the dataframe
  • The final dataset ends up with the predicted label, correctness, probability values, and prediction difficulty added alongside the original columns

 

  • ์ด๋•Œ, ์˜ˆ๋ฅผ ๋“ค์–ด p_true < 0.4 ๋˜๋Š” ์ƒ์œ„ 20% ์ •๋„ ๋“ฑ ์–ด๋А ์ •๋„๊นŒ์ง€ “hard” ์ƒ˜ํ”Œ๋กœ ๊ฐ„์ฃผํ•  ๊ฒƒ์ธ์ง€ ๊ธฐ์ค€์„ ์ •ํ•ด์•ผํ•œ๋‹ค
  • ์ด Hard Sample์„ ๋‹จ์ˆœ ์ˆ˜์ง‘ํ•˜๋Š” ์„ ์—์„œ ๋๋‚˜์„  ์•ˆ๋œ๋‹ค
  • ์–ด๋–ค ํด๋ž˜์Šค ๋ถ„ํฌ๋ฅผ ๋„๊ณ  ์žˆ๋Š”์ง€, ์–ด๋–ค ๋ฌธ์žฅ, ๋‹จ์–ด๋“ค๋กœ ์ด๋ฃจ์–ด์ ธ ์žˆ๋Š”์ง€ ๋ถ„์„์„ ํ•ด์ฃผ๋Š” ๊ฒƒ์ด ํ•„์š”ํ•˜๋‹ค

 

# ์ƒ์œ„ 20% hard samples
threshold = df["hardness"].quantile(0.8)
hard_top = df[df["hardness"] >= threshold].sort_values("hardness", ascending=False)
hard_top.head(20)


threshold = df["hardness"].quantile(0.8)
  • hardness ๋ถ„ํฌ์—์„œ ์ƒ์œ„ 20% ๊ฒฝ๊ณ„๊ฐ’ (80๋ฒˆ์งธ ๋ฐฑ๋ถ„์œ„์ˆ˜) ๋ฅผ ์˜๋ฏธ
  • ์ฆ‰, hardness ๊ฐ’์ด ์ „์ฒด ์ค‘ ์ƒ์œ„ 20% ์•ˆ์— ๋“ค์–ด๊ฐ€๋Š” ๊ธฐ์ค€์„ ์„ ์ฐพ๋Š” ๊ฒƒ


  • ํ‹€๋ฆฌ๊ฒŒ๋งŒ ์˜ˆ์ธกํ•œ ํ•˜๋“œ์ผ€์ด์Šค๋Š” ์ด 38404๊ฐœ, hardness๋ฅผ ๋†’์€ ์ˆœ์œผ๋กœ ์ƒ์œ„ 20%๋Š” ์ด 55072๊ฐœ์— ๋‹ฌํ•œ๋‹ค
  • ์ „์ฒด ๋ฐ์ดํ„ฐ์…‹์ด ๋Œ€๋žต 27๋งŒ๊ฐœ์ž„์„ ์ƒ๊ฐํ•˜๋ฉด ๊ฝค ์ ์ง€์•Š์€ ์ƒ˜ํ”Œ๋“ค์„ ์ œ๋Œ€๋กœ ๋งž์ถ”์ง€ ๋ชปํ•œ๋‹ค๊ณ  ๋ณผ ์ˆ˜ ์žˆ๋‹ค
  • ๊ทธ ์ค‘ ๊ฐ€์žฅ hardness๊ฐ€ ๋†’์€ ์ˆœ๋ถ€ํ„ฐ 20๊ฐœ๋ฅผ ๋ถˆ๋Ÿฌ์™”๋‹ค
  • ์ฒซ๋ฒˆ์งธ ๋ฌธ์žฅ์€ ์ •๋‹ต์ด 0๋ฒˆ์ธ๋ฐ 2๋ฒˆ์œผ๋กœ ์ž˜๋ชป ์˜ˆ์ธกํ•˜์˜€๋‹ค
  • ๊ทธ๋ฆฌ๊ณ  2๋ฒˆ์œผ๋กœ ์ž˜๋ชป ์˜ˆ์ธกํ•œ ํ™•๋ฅ ์€ ๋ฌด๋ ค 0.99... ์ •๋‹ต์„ ์˜ˆ์ธกํ•œ ํ™•๋ฅ ์€ 0.000036์œผ๋กœ ๋งค์šฐ๋งค์šฐ๋งค์šฐ ๋‚ฎ๊ฒŒ ์˜ˆ์ธกํ–ˆ๊ธฐ์— ๋ชจ๋ธ์ด ์–ด๋ ต๊ฒŒ ๋ณด์•˜์Œ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค

 

๊ทธ๋Ÿผ ๋ชป ๋งž์ถ”๋Š” ์ƒ์œ„ 20% ํ•˜๋“œ ์ƒ˜ํ”Œ์—์„œ ํด๋ž˜์Šค๋ณ„ ๊ฐœ์ˆ˜, ๋ถ„ํฌ๋Š”??


label  count
0      12341
1      12460
2      17771
3      12500


Step 3. Hard Sample ๊ธฐ๋ฐ˜ ์ฆ๊ฐ•

ํ…์ŠคํŠธ ์ฆ๊ฐ•์—๋Š” ์—ฌ๋Ÿฌ ๊ฐ€์ง€ ๋ฐฉ๋ฒ•๋“ค์ด ์กด์žฌํ•œ๋‹ค

๊ทธ ์ค‘ ์ผ๋ถ€ ์‚ฌ์šฉํ•œ ๋ฐฉ๋ฒ•๋“ค๋งŒ ๊ฐ„๋‹จํ•˜๊ฒŒ ์ •๋ฆฌํ•˜์˜€๋‹ค

 

(A) ๋ฒˆ์—ญ ๊ธฐ๋ฐ˜ ์ฆ๊ฐ• (Back-translation)

 

  • ํ•œ๊ตญ์–ด → ์˜์–ด → ํ•œ๊ตญ์–ด ๋ฒˆ์—ญ
  • ๋ฌธ์žฅ ๊ตฌ์กฐ์™€ ์ผ๋ถ€ ๋‹จ์–ด ๋ณ€๊ฒฝ → ์˜๋ฏธ ์œ ์ง€
  • ํ—ˆ๊น…ํŽ˜์ด์Šค์— ์žˆ๋Š” ๋‹ค์–‘ํ•œ ๋ฒˆ์—ญ ๋ชจ๋ธ๋“ค ํ™œ์šฉ ๊ฐ€๋Šฅ
  • ์ƒ˜ํ”Œ์˜ ๊ธธ์ด๊ฐ€ ๋„ˆ๋ฌด ์งง์€ ์ผ€์ด์Šค๋Š” ๋ฒˆ์—ญ์ด ์•ˆ๋˜๊ณ  ๋ฐ์ดํ„ฐ ํ’ˆ์งˆ์ด ๋” ์ €ํ•˜๋  ์ˆ˜ ์žˆ๊ธฐ์— ํ•„ํ„ฐ๋งํ•ด์ฃผ๋Š” ๊ฒƒ์ด ์ข‹์„ ๊ฒƒ์œผ๋กœ ๋ณด์ธ๋‹ค
    • ex) 2.3 / ์•ˆ๊ณผ ๊ฒ‰ / ์—ํœด ๋“ฑ๋“ฑ..

 

ํ•„ํ„ฐ๋ง

 

(1) ์ƒ˜ํ”Œ ๊ธธ์ด ๊ธฐ์ค€

  • ์งง์€ ๋ฌธ์žฅ, ํ•œ๋‘ ๋‹จ์–ด๋งŒ ์žˆ๋Š” ๋ฆฌ๋ทฐ๋Š” Back-Translation ํšจ๊ณผ ๊ฑฐ์˜ ์—†์Œ → ์˜๋ฏธ ํ›ผ์† ๊ฐ€๋Šฅ
  • ์˜ˆ: ๊ธธ์ด 5~10์ž ์ดํ•˜(ํ† ํฐ ๊ธฐ์ค€) ์ƒ˜ํ”Œ์€ ์ฆ๊ฐ• ์ œ์™ธ
min_len = 10  # ํ† ํฐ ์ˆ˜ ๊ธฐ์ค€
hard_top_filtered = hard_top[hard_top['review_normalized'].str.len() >= min_len]
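If a genuine token-count threshold is wanted instead of characters, one variant could reuse the classifier's tokenizer (my own addition, not part of the original filter):

# Token-based length filter (sketch): count subword tokens instead of characters
token_lens = hard_top["review_normalized"].apply(
    lambda t: len(tokenizer.tokenize(str(t)))
)
hard_top_filtered = hard_top[token_lens >= min_len]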


(2) ๋ชจ๋ธ ํ™•์‹ ๋„(confidence) ๊ธฐ์ค€

  • ์ด๋ฏธ ๋งž์ถ˜ ์ƒ˜ํ”Œ์ธ๋ฐ confidence๊ฐ€ ๋†’๋‹ค๋ฉด ์ฆ๊ฐ• ํ•„์š” ์—†์Œ
  • ์˜ˆ: confidence < 0.7 → ์ฆ๊ฐ• ํ›„๋ณด
hard_top_filtered = hard_top_filtered[hard_top_filtered['oof_confidence'] < 0.7]

 

(3) ๋ชจ๋ธ ์˜ˆ์ธก ์–ด๋ ค์›€ ๊ธฐ์ค€

  • ์˜ˆ์ธก์ด ์–ด๋ ค์šด ๊ธฐ์ค€๊ฐ’, hardness๊ฐ€ ์ผ์ • ๊ธฐ์ค€๋ณด๋‹ค ๋†’์„ ๋•Œ ๋” ํ•™์Šต์„ ์ž˜ ํ•  ์ˆ˜ ์žˆ๊ฒŒ๋” ์œ ์‚ฌํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์ถ”๊ฐ€ํ•ด์ฃผ๋ฉด ์œ ํšจํ•  ๊ฒƒ์œผ๋กœ ํŒ๋‹จ
  • ์˜ˆ: hardness > 0.6 → ์ฆ๊ฐ• ํ›„๋ณด
hard_top_filtered = hard_top_filtered[hard_top_filtered['hardness'] > 0.6]

 

(4) ํด๋ž˜์Šค ๋น„์œจ ๊ธฐ์ค€

  • ์†Œ์ˆ˜ ํด๋ž˜์Šค ์šฐ์„  ์ฆ๊ฐ•: 1(์•ฝํ•œ ๋ถ€์ •)๊ณผ 3(๊ฐ•ํ•œ ๊ธ์ •)
hard_top_filtered = hard_top_filtered[hard_top_filtered['label'].isin([1, 3])]


์•„๋ž˜๋Š” ํ•œ๊ตญ์–ด -> ์˜์–ด -> ํ•œ๊ตญ์–ด๋กœ ๋ฒˆ์—ญ์„ ๊ฑฐ์ณ ์ฆ๊ฐ•์‹œํ‚ค๋Š” ์ƒ˜ํ”Œ ์ฝ”๋“œ์ด๋‹ค

 

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
mbart_model_name = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(mbart_model_name)
tokenizer = MBart50TokenizerFast.from_pretrained(mbart_model_name)

import torch

def mbart_translate(texts, src_lang="ko_KR", tgt_lang="en_XX"):
    translated_texts = []
    for text in texts:
        tokenizer.src_lang = src_lang
        encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
        translated_texts.append(tokenizer.decode(generated_tokens[0], skip_special_tokens=True))
    return translated_texts


# ์˜ˆ์‹œ ํ•˜๋“œ ์ƒ˜ํ”Œ
hard_samples = [
    "์ด ๋ฌธ์žฅ์€ ๊ฐ์ • ๋ถ„์„ ๋ชจ๋ธ์„ ํ…Œ์ŠคํŠธํ•˜๊ธฐ ์œ„ํ•œ ์˜ˆ์‹œ์ž…๋‹ˆ๋‹ค.",
    "๋ชจ๋ธ์ด ์ž˜ ๋งž์ถ”์ง€ ๋ชปํ•˜๋Š” ํ•˜๋“œ ์ผ€์ด์Šค๋ฅผ ๋ถ„์„ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค."
]

# ํ•œ๊ตญ์–ด → ์˜์–ด
ko2en = mbart_translate(hard_samples, src_lang="ko_KR", tgt_lang="en_XX")
print("ํ•œ๊ตญ์–ด → ์˜์–ด:", ko2en)

# ์˜์–ด → ํ•œ๊ตญ์–ด (Back-translation)
back_translated = mbart_translate(ko2en, src_lang="en_XX", tgt_lang="ko_KR")
print("Back-translation (ํ•œ๊ตญ์–ด → ์˜์–ด → ํ•œ๊ตญ์–ด):", back_translated)


ํ•œ๊ตญ์–ด → ์˜์–ด: 
['This sentence is an example of testing an emotion-analysis model.', 
 "You have to analyze the hard case where the model doesn't fit."]


Back-translation (ํ•œ๊ตญ์–ด → ์˜์–ด → ํ•œ๊ตญ์–ด): 
['์ด ๋ฌธ์žฅ์€ ๊ฐ์ • ๋ถ„์„ ๋ชจ๋ธ์„ ์‹œํ—˜ํ•˜๋Š” ์˜ˆ์ž…๋‹ˆ๋‹ค.', 
 '๋ชจ๋ธ์ด ๋งž์ง€ ์•Š๋Š” ์–ด๋ ค์šด ๊ฒฝ์šฐ๋ฅผ ๋ถ„์„ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.']
  • ๋ณด๋‹ค์‹œํ”ผ ์œ ์‚ฌํ•œ ๋ฌธ์žฅ์ด ์ถœ๋ ฅ์œผ๋กœ ๋‚˜์˜จ ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค
  • ์œ„์˜ ์˜ˆ์ œ ์ฝ”๋“œ๋Š” ์–ผ๋งˆ ์•ˆ๋˜๋Š” ๋ฌธ์žฅ๋“ค์ด๋ผ์„œ ๊ฐ„๋‹จํ•˜๊ฒŒ for๋ฌธ ๋Œ๋ฆฌ์ง€๋งŒ ์‹ค์ œ๋กœ๋Š” ์ฒœ์—์„œ ๋งŒ๋‹จ์œ„์˜ ๋ฌธ์žฅ๋“ค์„ ๋Œ๋ ค์•ผํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ €๋ ‡๊ฒŒ ์“ฐ๋ฉด ์•ˆ๋œ๋‹ค.. 

 

ํ•„ํ„ฐ๋ง ๋ฐ Back-translation ์ ์šฉ ์ฝ”๋“œ

๋”๋ณด๊ธฐ
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd

# ==========================================================
# 1. ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
# ==========================================================
mbart_model_name = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(mbart_model_name, torch_dtype=torch.float16)
tokenizer = MBart50Tokenizer.from_pretrained(mbart_model_name)
model = torch.compile(model)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


# ==========================================================
# 2. Dataset ์ •์˜
# ==========================================================
class TextDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]


# ==========================================================
# 3. Collate ํ•จ์ˆ˜ (๋ฐฐ์น˜ ๋‹จ์œ„ ํ† ํฌ๋‚˜์ด์ง•)
# ==========================================================
def collate_fn(batch_texts, src_lang, tokenizer, device):
    tokenizer.src_lang = src_lang
    encoded = tokenizer(
        batch_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128
    )
    encoded = {k: v.to(device) for k, v in encoded.items()}
    return encoded


# ==========================================================
# 4. DataLoader ๊ธฐ๋ฐ˜ ๋ฒˆ์—ญ ํ•จ์ˆ˜
# ==========================================================
def mbart_translate_dataloader(texts, src_lang="ko_KR", tgt_lang="en_XX", batch_size=32, num_workers=2):
    dataset = TextDataset(texts)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    translated_texts = []

    for batch_texts in tqdm(dataloader, desc=f"Translating {src_lang} → {tgt_lang}"):
        encoded = collate_fn(batch_texts, src_lang, tokenizer, device)
        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
                max_length=256
            )
        decoded = [tokenizer.decode(t, skip_special_tokens=True) for t in generated_tokens]
        translated_texts.extend(decoded)

    return translated_texts


# ==========================================================
# 5. ํ•„ํ„ฐ๋ง + ์ฆ๊ฐ• ํ•จ์ˆ˜ (DataLoader ๊ธฐ๋ฐ˜)
# ==========================================================
def augment_hard_samples_filtered_dataloader(df_hard, n_aug=1,
                                             min_len=10, max_confidence=0.7,
                                             target_labels=[1,3], top_hardness=0.7,
                                             batch_size=32, num_workers=2):
    """
    df_hard: ํ•˜๋“œ ์ƒ˜ํ”Œ ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„ (review_normalized, label, oof_confidence, hardness ํฌํ•จ)
    n_aug: ๊ฐ ์ƒ˜ํ”Œ๋‹น ์ƒ์„ฑํ•  ์ฆ๊ฐ• ์ˆ˜
    """
    # ํ•„ํ„ฐ๋ง
    df_filtered = df_hard[
        (df_hard['label'].isin(target_labels)) &
        (df_hard['oof_confidence'] <= max_confidence) &
        (df_hard['review_normalized'].str.len() >= min_len)
    ]
    threshold = df_filtered['hardness'].quantile(top_hardness)
    df_filtered = df_filtered[df_filtered['hardness'] >= threshold]

    print(f"์ฆ๊ฐ• ๋Œ€์ƒ ์ƒ˜ํ”Œ ์ˆ˜: {len(df_filtered)}")

    aug_texts = []
    aug_labels = []

    for _ in tqdm(range(n_aug), desc="Augmentation rounds"):
        ko_texts = df_filtered['review_normalized'].tolist()
        # ํ•œ๊ตญ์–ด → ์˜์–ด → ํ•œ๊ตญ์–ด
        en_texts = mbart_translate_dataloader(ko_texts, src_lang="ko_KR", tgt_lang="en_XX",
                                              batch_size=batch_size, num_workers=num_workers)
        back_texts = mbart_translate_dataloader(en_texts, src_lang="en_XX", tgt_lang="ko_KR",
                                                batch_size=batch_size, num_workers=num_workers)
        aug_texts.extend(back_texts)
        aug_labels.extend(df_filtered['label'].tolist())

    return aug_texts, aug_labels


# ==========================================================
# 6. ์˜ˆ์‹œ ์‚ฌ์šฉ
# ==========================================================
# hard_top_filtered: ๊ธฐ์กด ํ•„ํ„ฐ๋ง ํ›„ ํ•˜๋“œ ์ƒ˜ํ”Œ ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„
aug_texts, aug_labels = augment_hard_samples_filtered_dataloader(
    df_hard=hard_top,
    n_aug=1,
    min_len=10,
    max_confidence=0.7,
    target_labels=[1,3],
    top_hardness=0.7,
    batch_size=128,
    num_workers=8   # 4-8 recommended if you have spare CPU cores
)

print(f"์ฆ๊ฐ• ์™„๋ฃŒ ์ƒ˜ํ”Œ ์ˆ˜: {len(aug_texts)}")
print("์˜ˆ์‹œ ์ฆ๊ฐ• ๋ฌธ์žฅ:", aug_texts[:5])


์ฆ๊ฐ• ๋Œ€์ƒ ์ƒ˜ํ”Œ ์ˆ˜: 3889
Translating ko_KR → en_XX: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 31/31 [15:53<00:00, 30.76s/it]
Translating en_XX → ko_KR: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 31/31 [08:43<00:00, 16.87s/it]
Augmentation rounds: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 1/1 [24:36<00:00, 1476.75s/it]


์ฆ๊ฐ• ์™„๋ฃŒ ์ƒ˜ํ”Œ ์ˆ˜: 3889
์˜ˆ์‹œ ์ฆ๊ฐ• ๋ฌธ์žฅ: 
['์•„์‹œ๋‚˜์š”?  plot ์„ ๋ณด์‹ค ์ˆ˜ ์žˆ์ฃ . ๊ทธ๋Š” ๋ฉ์ฒญ์ด๊ฐ€ ์•„๋‹ˆ๋ผ, ์•„๋“ค์ด ์žˆ๊ณ , ์กฐ์นด๊ฐ€ ์žˆ๊ณ , ์•„๊ธฐ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.', 
'๋ชจ๋“  ๊ฒƒ์ด ์ข‹์•˜๊ณ , ๋๋„ ์ข‹์•˜์Šต๋‹ˆ๋‹ค.', 
'๋‹ค๋ฅธ ๋”์ฐํ•œ ๋ธ”๋Ÿญ๋ฒ„์Šคํ„ฐ๋ณด๋‹ค ๋” ๋‚ซ๊ณ , ๋” ์ƒ๊ฐํ•˜๊ฒŒ ํ•ด์ฃผ๊ณ , ์ œ ํ˜•์ด 7๋…„ ๋™์•ˆ Phuket์—์„œ ์‚ด์•˜๋˜ ๊ฒƒ์„ ๋– ์˜ฌ๋ฆฌ๊ฒŒ ํ•ด์ฃผ๊ณ , ํ‘ธ์—๋ฅดํ† ๋ฆฌ์ฝ”๋กœ ๊ฐ€๋Š” ๊ฒƒ์„ ๋– ์˜ฌ๋ฆฌ๊ฒŒ ํ•ด์ฃผ์ฃ .', 
'๊ทธ๋ ‡์ง€ ์•Š๋‹ค๋ฉด, ์ €๋Š” ๋ฉ‹์ง„ ์ Š์€ ๋ธŒ๋ž˜๋“œ ํ”ผํŠธ๋ฅผ ๋ฌด์žฅ๋ณต์„ ์ž…๊ณ  ๋งค๋ ฅ์ ์ด๊ณ  ์™ธํ–ฅ์ ์ธ ์บ๋ฆญํ„ฐ๋กœ ์ž…์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค.', 
'์œ ์ผํ•œ ๋‹จ์ ์€ ์ด ์˜ํ™”๊ฐ€ ๋„ˆ๋ฌด ์งง๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.']

 

  • ์œ„์˜ ์ฝ”๋“œ๋ฅผ ํ†ตํ•ด ํ•„ํ„ฐ๋ง๋œ Hard Case๋Š” 3889๊ฑด
  • ํ•œ๊ตญ์–ด โžก๏ธ ์˜์–ด โžก๏ธ ํ•œ๊ตญ์–ด๋กœ 2๋ฒˆ์˜ ๋ฒˆ์—ญ์„ ๊ฑฐ์น˜๋Š”๋ฐ 24๋ถ„ ์ •๋„๊ฐ€ ๊ฑธ๋ ธ๋‹ค
  • ์ฆ๊ฐ• ๊ฒฐ๊ณผ์˜ ์ผ๋ถ€๋ฅผ ํ™•์ธํ•ด๋ณด๋ฉด ๋ฒˆ์—ญํ–ˆ์„ ๋•Œ ํŠน์œ ์˜ ๋ฌธ์ฒด์™€ ๋А๋‚Œ์ด ์žˆ์–ด์„œ ๋„ˆ๋ฌด ๋งŽ์ด ์ฆ๊ฐ•ํ•  ๊ฒฝ์šฐ ๋ฒˆ์—ญ์ฒด์— ๋Œ€ํ•ด ํŽธํ–ฅ์ด ์ปค์ ธ์„œ ๊ตฌ์–ด์ฒด์™€ ๋น„๋ฌธ์— ๋Œ€ํ•œ ์˜ˆ์ธก์ด ๋–จ์–ด์งˆ ํ™•๋ฅ ์ด ์ปค์ง€๋ฏ€๋กœ ์ด ๋ถ€๋ถ„ ์ฃผ์˜ํ•ด์•ผ ํ•œ๋‹ค


(B) LLM ๊ธฐ๋ฐ˜ ์ฆ๊ฐ• (ํ•ฉ์„ฑ ์ƒ์„ฑ)

from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
import torch, pandas as pd
from sentence_transformers import SentenceTransformer, util
import torch.nn.functional as F
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# =========================
# ์„ค์ •
# =========================
# The prompt stays in Korean since the data is Korean. Gloss: "Rewrite the
# following sentence naturally in 2 different ways with the same meaning.
# Use colloquial, everyday expressions; slightly broken grammar is fine
# as long as it sounds natural."
PROMPT_TEMPLATE = """다음 문장을 자연스럽게 의미가 같은 다른 표현으로 2가지 써줘.
구어체나 일상적인 표현을 사용하고, 문법적으로 약간 틀려도 자연스러우면 좋아.

원문: {sentence}
"""

LLM_MODEL = "spow12/Ko-Qwen2-7B-Instruct"
CLS_MODEL = "./best_model"
# CLS_MODEL = "klue/roberta-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE_LLM = 8       # LLM batch size
BATCH_SIZE_SBERT = 16    # batch size for SBERT and the sentiment classifier

# =========================
# ๋ชจ๋ธ ๋กœ๋“œ
# =========================
tokenizer_llm = AutoTokenizer.from_pretrained(LLM_MODEL)
if tokenizer_llm.pad_token is None:
    tokenizer_llm.pad_token = tokenizer_llm.eos_token  # ensure padding works for batched generate
model_llm = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder="./offload"  # temporarily offload to CPU if VRAM is short
)

sbert = SentenceTransformer("jhgan/ko-sroberta-multitask").to(DEVICE)

tokenizer_cls = AutoTokenizer.from_pretrained(CLS_MODEL)
model_cls = AutoModelForSequenceClassification.from_pretrained(CLS_MODEL).to(DEVICE)

# =========================
# ํ•จ์ˆ˜ ์ •์˜
# =========================
def generate_paraphrases_batch(sentences, num_return=3):
    prompts = [PROMPT_TEMPLATE.format(sentence=s) for s in sentences]
    inputs = tokenizer_llm(prompts, padding=True, truncation=True, return_tensors="pt").to(model_llm.device)
    outputs = model_llm.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.9,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=num_return,
        eos_token_id=tokenizer_llm.eos_token_id
    )
    decoded = tokenizer_llm.batch_decode(outputs, skip_special_tokens=True)

    # ๊ฒฐ๊ณผ๋ฅผ 2์ฐจ์› ๋ฆฌ์ŠคํŠธ๋กœ ๋ณ€ํ™˜: [ [๋ฌธ์žฅ1, ๋ฌธ์žฅ2, ...], [๋ฌธ์žฅ1, ...], ... ]
    result_batch = []
    for i in range(0, len(decoded), num_return):
        result_batch.append([r.split("์›๋ฌธ:")[-1].strip() for r in decoded[i:i+num_return]])
    return result_batch

def is_similar_batch(originals, candidates, threshold=0.80):
    # originals: list of str
    # candidates: list of str
    orig_emb = sbert.encode(originals, batch_size=BATCH_SIZE_SBERT, convert_to_tensor=True)
    cand_emb = sbert.encode(candidates, batch_size=BATCH_SIZE_SBERT, convert_to_tensor=True)
    sim_matrix = util.cos_sim(orig_emb, cand_emb)
    return sim_matrix.diagonal() >= threshold  # similarity between each candidate and its source

def is_same_emotion_batch(originals, candidates, labels):
    inputs = tokenizer_cls(candidates, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        probs = F.softmax(model_cls(**inputs).logits, dim=-1)
    pred_labels = torch.argmax(probs, dim=-1).cpu().tolist()
    return [pred_labels[i] == labels[i] for i in range(len(labels))]

def remove_duplicates(texts, threshold=0.95):
    if len(texts) <= 1:
        return texts
    vectorizer = TfidfVectorizer().fit_transform(texts)
    sim_matrix = cosine_similarity(vectorizer)
    keep = []
    for i in range(len(texts)):
        if not any(sim_matrix[i][j] > threshold for j in keep):
            keep.append(i)
    return [texts[i] for i in keep]

# =========================
# ์ฆ๊ฐ• ๋ฃจํ”„ (๋ฐฐ์น˜ ์ตœ์ ํ™”)
# =========================
augmented_data = []

# minority_df: dataframe of minority-class hard samples (assumed prepared earlier)
minority_texts = minority_df["review"].tolist()
minority_labels = minority_df["label"].tolist()

for i in tqdm(range(0, len(minority_texts), BATCH_SIZE_LLM)):
    batch_texts = minority_texts[i:i+BATCH_SIZE_LLM]
    batch_labels = minority_labels[i:i+BATCH_SIZE_LLM]

    # 1. LLM์œผ๋กœ paraphrase ์ƒ์„ฑ
    batch_paraphrases = generate_paraphrases_batch(batch_texts, num_return=2)

    for j, orig in enumerate(batch_texts):
        candidates = batch_paraphrases[j]
        labels = [batch_labels[j]] * len(candidates)

        # 2. ์˜๋ฏธ ์œ ์‚ฌ๋„ ๊ฒ€์‚ฌ
        sim_mask = is_similar_batch([orig]*len(candidates), candidates)

        # 3. ๊ฐ์ • ์ผ๊ด€์„ฑ ๊ฒ€์‚ฌ
        emotion_mask = is_same_emotion_batch([orig]*len(candidates), candidates, labels)

        # 4. ํ•„ํ„ฐ๋ง
        filtered = [c for k, c in enumerate(candidates) if sim_mask[k] and emotion_mask[k]]

        # 5. ์ค‘๋ณต ์ œ๊ฑฐ
        filtered = remove_duplicates(filtered)

        # 6. ์ €์žฅ
        for t in filtered:
            augmented_data.append({"review": t, "label": batch_labels[j]})

aug_df = pd.DataFrame(augmented_data)
aug_df.to_csv("augmented_data_batch.csv", index=False)


  • ์œ ์˜ํ•  ์ ์œผ๋กœ๋Š” Back-Translation, LLM ์ฆ๊ฐ• ๋ชจ๋‘ ๋ชจ๋ธ์„ ํ†ตํ•ด ์•„์›ƒํ’‹์„ ๋ฝ‘์•„๋‚ด๋Š” ๊ฒƒ์ด๊ธฐ ๋•Œ๋ฌธ์— ๊ทœ๋ชจ๊ฐ€ ์žˆ๋Š” ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๋ฉฐ, ๋ช‡ ๋งŒ ๊ฑด ์ด์ƒ์œผ๋กœ ์ฆ๊ฐ•์„ ํ•˜๊ณ ์ž ํ•œ๋‹ค๋ฉด ๊ฝค ๋งŽ์€ ๋ฆฌ์†Œ์Šค(GPU)๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ๋น ๋ฅด๊ฒŒ ์ฆ๊ฐ•์„ ํ•  ์ˆ˜ ์žˆ๋‹ค
  • ์ปดํ“จํŒ… ๋ฆฌ์†Œ์Šค๊ฐ€ ๋ฐ›์ณ์ฃผ์ง€ ์•Š๋Š”๋‹ค๋ฉด ์ •๋ง ํ•˜๋ฃจ์ข…์ผ ์ฆ๊ฐ•ํ•˜๋Š”๋ฐ๋งŒ ์„œ๋ฒ„๋ฅผ ๋Œ๋ ค์•ผ... 
  • 16,000๊ฑด Back-Translation์„ ์ ์šฉํ•˜๋Š”๋ฐ 6์‹œ๊ฐ„ ์ •๋„ ๊ฑธ๋ ธ๋‹ค


๊ฐ ํ…์ŠคํŠธ๋งˆ๋‹ค ๋ผ๋ฒจ์— ๋Œ€ํ•œ ํ™•๋ฅ ๊ฐ’์„ ์ด์šฉํ•ด ๋ชจ๋ธ์ด ๋ฐ์ดํ„ฐ์…‹์— ๋Œ€ํ•ด ์–ผ๋งˆ๋‚˜ ์˜ˆ์ธก์„ ์ž˜ํ•˜๊ณ  ๋ชปํ•˜๋Š”์ง€๋ฅผ ๊ตฌ๋ณ„ํ•ด๋‚ผ ์ˆ˜ ์žˆ์—ˆ๋‹ค

ํ™•๋ฅ ๊ฐ’์ด๋ผ๋Š” ์ˆ˜์น˜๋ฅผ ํ†ตํ•ด ์ด๋ ‡๊ฒŒ ํŒŒ์ƒ์ ์œผ๋กœ ๋ถ„์„๊ณผ ์—ฌ๋Ÿฌ ์•ก์…˜๋“ค์„ ์ทจํ•ด๋ณผ ์ˆ˜ ์žˆ๋‹ค๋‹ˆ ๊ฝค ํฅ๋ฏธ๋กœ์› ๋‹ค

Hard Case์— ๋Œ€ํ•ด์„œ ๋” ์‹ฌ๋„์žˆ๋Š” ๋ถ„์„์„ ์ง„ํ–‰ํ•  ์ˆ˜ ์žˆ์—ˆ์œผ๋ฉด ์ข‹์•˜์„ํ…๋ฐ ์ด ๋ถ€๋ถ„์—์„œ ๋ง‰ํ˜”๋‹ค

๋ถ„์„์„ ํ†ตํ•ด ์ธ์‚ฌ์ดํŠธ๋ฅผ ๋ฝ‘์•„๋‚ด์•ผ ๋ญ”๊ฐ€๋ฅผ ๋” ํ•ด๋ณผ ๊ฑธ ์ฐพ์„ํ…๋ฐ ํ…์ŠคํŠธ ๋ถ„์„์— ๋Œ€ํ•ด์„œ๋Š” ์ •๋ง ๊ธฐ์ดˆ์ ์ธ ๋ฐฉ๋ฒ•๋“ค ๋ฐ–์— ๋– ์˜ค๋ฅด์ง€ ์•Š์•„ ์ฐธ ์•„์‰ฌ์› ๋‹ค


์ง€๊ธˆ์€ ํƒœ์Šคํฌ๊ฐ€ ๋น„๊ต์  ๊ฐ„๋‹จํ•œ ๋ถ„๋ฅ˜๋ผ์„œ ์ด๋ ‡๊ฒŒ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•˜์ง€๋งŒ ์‹ค์ œ LLM์„ ์ด์šฉํ•ด ๋ฌธ์žฅ์„ ์ƒ์„ฑํ•˜๋Š” ๊ฒฝ์šฐ์—๋Š” ์›ํ•˜๋Š” ๋ชฉ์ ๊ณผ ๋‹ค๋ฅธ๋ฐ ์œ ์‚ฌํ•ด๋ณด์ด๋Š” ๋ฌธ์žฅ๋“ค์„ ํ•ธ๋“ค๋งํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค๊ณ  ํ•˜๋‹ค

์•„๋ž˜ ๊ธ€๋„ ๋‚˜์ค‘์— ์ฐธ๊ณ ํ•  ๊ฒธ ์ €์žฅ

 

https://sjkoding.tistory.com/102

 

[LLM] Key Points of Hard Negative Mining Methodology for Fine-tuning Text Embedding Models (sjkoding.tistory.com)