[AI][CatBoost] Predicting Wine Quality with CatBoost

5hr1rnp 2025. 2. 14. 19:11

Wine Quality Dataset → CatBoost

2025.01.23 - [Development Code/Artificial Intelligence A.I.] - [Python][AI] Exploratory Data Analysis (EDA) - Wine Quality Dataset - 1

2025.01.24 - [Development Code/Artificial Intelligence A.I.] - [Python][AI] Exploratory Data Analysis (EDA) - Wine Quality Dataset - 2

2025.02.04 - [Development Code/Artificial Intelligence A.I.] - [Python][AI] Exploratory Data Analysis (EDA) - Wine Quality Dataset - 3

2025.02.04 - [Development Code/Artificial Intelligence A.I.] - [Python][AI] Exploratory Data Analysis (EDA) - Wine Quality Dataset - 4

1. Overview


Predicting wine quality is one of the classic machine learning problems. The Wine Quality Dataset is used to predict a quality score from a wine's chemical properties, which makes it a regression problem. In this post we build a model that predicts wine quality with CatBoost.


2. Introducing the Wine Quality Dataset


The Wine Quality Dataset is a public dataset provided by the UCI Machine Learning Repository. It contains quality information for both red and white wines.

  • Features: 11 chemical properties (e.g. pH, alcohol, sulphates)
  • Label: a quality score from 0 to 10 (regression problem)

The dataset can be downloaded from the UCI Machine Learning Repository.
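Because the label is an integer score rather than a continuous measurement, it can be useful to look at how the scores are distributed before modeling. The following is a minimal sketch: `summarize_quality` is a hypothetical helper that does not appear in the original post, and the rows in `sample` are made-up values used only to keep the snippet self-contained.

```python
import pandas as pd


def summarize_quality(df: pd.DataFrame, target: str = "quality") -> pd.DataFrame:
    """Return the count and share of each score in the target column."""
    counts = df[target].value_counts().sort_index()
    return pd.DataFrame({"count": counts, "share": counts / counts.sum()})


# Made-up rows for illustration only (not taken from the real dataset)
sample = pd.DataFrame(
    {
        "alcohol": [9.4, 9.8, 9.8, 11.2, 10.5, 12.0],
        "pH": [3.51, 3.20, 3.26, 3.16, 3.30, 3.39],
        "quality": [5, 5, 5, 6, 6, 7],
    }
)

summary = summarize_quality(sample)
print(summary)
```

On the real data, the same helper could be called on the loaded DataFrame, e.g. `summarize_quality(df)`.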


3. Installing CatBoost


CatBoost can be installed with the following command. The commented-out lines show the conda alternative.


# conda config --add channels conda-forge
# conda install catboost
pip install catboost
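After installing, a quick way to confirm that the packages used in this post are importable is to look them up without actually importing them. This is only a sketch: `is_installed` is a hypothetical helper name, built on the standard library's `importlib.util.find_spec`.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if a top-level package can be found on the import path."""
    return importlib.util.find_spec(package_name) is not None


# Packages used later in this post (sklearn is the import name of scikit-learn)
for name in ("catboost", "pandas", "sklearn"):
    status = "installed" if is_installed(name) else "missing"
    print(f"{name}: {status}")
```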

4. Loading and Preprocessing the Data


Load the dataset with Python, separate the features from the target, and split the rows into training and test sets.


import pandas as pd
from sklearn.model_selection import train_test_split
from catboost import CatBoostRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Load the data
# (or read the CSV from a local path instead)
df = pd.read_csv("https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv", sep=';')

# Separate the inputs (X) from the target (y)
X = df.drop(columns=['quality'])
y = df['quality']

# Split into training and test data (80% / 20%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(f"Train data shape: {X_train.shape}")
print(f"Test data shape: {X_test.shape}")

# Train data shape: (1279, 11)
# Test data shape: (320, 11)
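The split above is purely random. When the target is an integer score with uneven class sizes, one option is to stratify the split so that train and test keep similar score proportions. The sketch below illustrates this with `train_test_split(..., stratify=...)` on synthetic stand-in data; `X_demo`, `y_demo`, and `class_share` are hypothetical names, and this is not the split used for the results reported in this post.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)

# Synthetic stand-in: 11 numeric features and an imbalanced integer target
X_demo = rng.normal(size=(500, 11))
y_demo = rng.choice([4, 5, 6, 7], size=500, p=[0.1, 0.4, 0.4, 0.1])

# stratify=y_demo keeps the share of each score similar in both subsets
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo
)


def class_share(y: np.ndarray, label: int) -> float:
    """Fraction of samples in y that carry the given label."""
    return float(np.mean(y == label))


for label in (4, 5, 6, 7):
    print(label, round(class_share(y_tr, label), 3), round(class_share(y_te, label), 3))
```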


5. Training the CatBoost Model


Train a regression model that predicts wine quality with CatBoost.

# Create and train the CatBoost model
model = CatBoostRegressor(iterations=1000, depth=6, learning_rate=0.1, loss_function='MAE', verbose=200)

model.fit(X_train, y_train)

# Predict on the test set
y_pred = model.predict(X_test)

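Parameter description

  • iterations=1000: run 1000 boosting iterations (one tree is added per iteration)
  • depth=6: tree depth (larger values give a more complex model)
  • learning_rate=0.1: learning rate, i.e. how much each new tree contributes
  • loss_function='MAE': use Mean Absolute Error (MAE) as the loss function
  • verbose=200: print the training log every 200 iterations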

6. Analyzing the Results


# Model evaluation function
def evaluate_model(y_true, y_pred, model_name="Model"):
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    print(f"{model_name} Performance:")
    print(f"Mean Absolute Error (MAE): {mae:.4f}")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"R² Score: {r2:.4f}\n")

evaluate_model(y_test, y_pred, "CatBoost")

# CatBoost Performance:
# Mean Absolute Error (MAE): 0.4268
# Mean Squared Error (MSE): 0.3330
# R² Score: 0.4904
 

The MAE came out at about 0.43. In other words, the predicted quality is off by roughly 0.43 points on average.
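To make the three metrics concrete, the sketch below recomputes them by hand with NumPy on a tiny made-up example. `regression_metrics` is a hypothetical helper, not part of the original post; it also adds RMSE, which is the square root of MSE and is expressed in the same units as the quality score. Applied to the MSE of 0.3330 reported above, that gives an RMSE of roughly 0.58.

```python
import numpy as np


def regression_metrics(y_true, y_pred) -> dict:
    """Compute MAE, MSE, RMSE and R² from their definitions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred

    mae = float(np.mean(np.abs(residuals)))
    mse = float(np.mean(residuals ** 2))
    ss_res = float(np.sum(residuals ** 2))                  # residual sum of squares
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))   # total sum of squares
    r2 = 1.0 - ss_res / ss_tot
    return {"MAE": mae, "MSE": mse, "RMSE": float(np.sqrt(mse)), "R2": r2}


# Made-up scores for illustration only
metrics = regression_metrics([5, 6, 7, 5], [5.5, 6.0, 6.0, 5.0])
print(metrics)

# RMSE implied by the MSE reported in this post
reported_rmse = float(np.sqrt(0.3330))
print(f"RMSE from reported MSE: {reported_rmse:.4f}")
```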


7. CatBoost vs. XGBoost


Compared with the results from the earlier experiments, the numbers are as follows.

Model                        MAE     MSE     R² Score
CatBoostRegressor            0.4268  0.3330  0.4904
XGBoost (Baseline)           0.4175  0.3513  0.4625
XGBoost (Scaled)             0.4175  0.3513  0.4625
XGBoost (Outliers Removed)   0.4383  0.3492  0.4656
XGBoost (Tuned)              0.4549  0.3506  0.4635


The CatBoost model did not have the lowest MAE (XGBoost Baseline and Scaled reached 0.4175 versus 0.4268), but it had the lowest MSE (0.3330) and the highest R² score (0.4904).
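The comparison can be checked mechanically by loading the reported numbers into a DataFrame and picking the best model per metric. This is a small sketch using only the values listed above; the names `results` and `best` are hypothetical. Lower is better for MAE and MSE, higher is better for R².

```python
import pandas as pd

# Metric values copied from the comparison in this post
results = pd.DataFrame(
    {
        "MAE": [0.4268, 0.4175, 0.4175, 0.4383, 0.4549],
        "MSE": [0.3330, 0.3513, 0.3513, 0.3492, 0.3506],
        "R2": [0.4904, 0.4625, 0.4625, 0.4656, 0.4635],
    },
    index=[
        "CatBoostRegressor",
        "XGBoost (Baseline)",
        "XGBoost (Scaled)",
        "XGBoost (Outliers Removed)",
        "XGBoost (Tuned)",
    ],
)

# idxmin/idxmax return the first row label on ties
best = {
    "MAE": results["MAE"].idxmin(),
    "MSE": results["MSE"].idxmin(),
    "R2": results["R2"].idxmax(),
}
print(best)
```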


8. Conclusion


CatBoost is a powerful machine learning library that delivers strong performance with minimal data preprocessing. In this experiment we trained a regression model on the Wine Quality Dataset and confirmed that even a baseline model, without any tuning, reached the best MSE and R² score among the models compared.

Summary of CatBoost's advantages

  • Works without separately converting categorical data
  • Fast training and prediction
  • Good performance even with default hyperparameters