[AI][CatBoost] Predicting Wine Quality with CatBoost

5hr1rnp 2025. 2. 14. 19:11
๋ฐ˜์‘ํ˜•

Wine Quality Dataset ☞ CatBoost

2025.01.23 - [Dev Code/Artificial Intelligence A.I.] - [Python][AI] Exploratory Data Analysis (EDA) - Wine Quality Dataset - 1

2025.01.24 - [Dev Code/Artificial Intelligence A.I.] - [Python][AI] Exploratory Data Analysis (EDA) - Wine Quality Dataset - 2

2025.02.04 - [Dev Code/Artificial Intelligence A.I.] - [Python][AI] Exploratory Data Analysis (EDA) - Wine Quality Dataset - 3

2025.02.04 - [Dev Code/Artificial Intelligence A.I.] - [Python][AI] Exploratory Data Analysis (EDA) - Wine Quality Dataset - 4

1. Overview


Predicting wine quality is a common machine learning exercise. The Wine Quality Dataset is used to predict a quality score from a wine's chemical properties, which makes it a regression problem. In this post we build a model that predicts wine quality with CatBoost.


2. About the Wine Quality Dataset


The Wine Quality Dataset is a public dataset from the UCI Machine Learning Repository and contains quality information for red and white wines.

  • Features: 11 chemical measurements (e.g. pH, alcohol, sulphates)
  • Label: a quality score from 0 to 10 (regression problem)

The dataset can be downloaded from the UCI Machine Learning Repository.


3. Installing CatBoost


CatBoost can be installed with the following command.

 

# conda config --add channels conda-forge
# conda install catboost
pip install catboost

4. Loading and Preprocessing the Data


We load the dataset with Python, separate the features from the target, and split the data into training and test sets.

 

import pandas as pd
from sklearn.model_selection import train_test_split
from catboost import CatBoostRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Load the data
# (or read it from a local path that matches your directory layout)
df = pd.read_csv("https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv", sep=';')

# Separate the inputs (X) from the target (y)
X = df.drop(columns=['quality'])
y = df['quality']

# Split into training and test data (80% / 20%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(f"Train data size: {X_train.shape}")
print(f"Test data size: {X_test.shape}")

# Train data size: (1279, 11)
# Test data size: (320, 11)
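The printed shapes imply 1279 + 320 = 1599 rows in total. As a rough sketch of where the two numbers come from, the helper below (`split_sizes` is a hypothetical name, not part of scikit-learn) assumes the test count is rounded up and the remainder goes to the training set, which reproduces the shapes shown above.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the test count is rounded up and the rest is used for training.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test


# 1599 is the total row count implied by the printed shapes (1279 + 320).
n_train, n_test = split_sizes(1599, 0.2)
print(n_train, n_test)
```

Changing `test_size` to, say, 0.25 is a quick way to see how many rows each split would get before touching the real data.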


5. Training the CatBoost Model


We train a regression model that predicts wine quality with CatBoost.

# Create and train the CatBoost model
model = CatBoostRegressor(iterations=1000, depth=6, learning_rate=0.1, loss_function='MAE', verbose=200)

model.fit(X_train, y_train)

# Predict on the test set
y_pred = model.predict(X_test)

ํŒŒ๋ผ๋ฏธํ„ฐ ์„ค๋ช…

  • iterations=1000: 1000๋ฒˆ ๋ฐ˜๋ณต ํ•™์Šต
  • depth=6: ํŠธ๋ฆฌ ๊นŠ์ด ์„ค์ • (๊ฐ’์ด ํด์ˆ˜๋ก ๋ณต์žกํ•œ ๋ชจ๋ธ)
  • learning_rate=0.1: ํ•™์Šต๋ฅ  ์„ค์ •
  • loss_function='MAE': Mean Absolute Error(MAE)๋ฅผ ์†์‹ค ํ•จ์ˆ˜๋กœ ์‚ฌ์šฉ
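To make the roles of iterations, learning_rate and the MAE loss more concrete, here is a toy gradient-boosting loop in plain Python. It is only an illustrative sketch and not CatBoost's actual algorithm: it uses one-split stumps on a single feature (so depth is fixed at 1), and the alcohol values and quality scores are made up. `fit_stump` and `boost_mae` are hypothetical helper names.

```python
def fit_stump(x, targets):
    """Fit a one-split tree (stump) on a single feature by least squares."""
    best = None
    for thr in sorted(set(x))[:-1]:  # skip the largest value: right side would be empty
        left = [t for xi, t in zip(x, targets) if xi <= thr]
        right = [t for xi, t in zip(x, targets) if xi > thr]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        err = (sum((t - left_mean) ** 2 for t in left)
               + sum((t - right_mean) ** 2 for t in right))
        if best is None or err < best[0]:
            best = (err, thr, left_mean, right_mean)
    _, thr, left_mean, right_mean = best
    return lambda xi: left_mean if xi <= thr else right_mean


def boost_mae(x, y, iterations, learning_rate):
    """Gradient boosting on the MAE loss with stumps as weak learners."""
    base = sorted(y)[len(y) // 2]  # a median of y: a constant that minimises MAE
    stumps = []
    preds = [float(base)] * len(y)
    for _ in range(iterations):
        # The negative gradient of |y - pred| with respect to pred is sign(y - pred)
        signs = [(yi > pi) - (yi < pi) for yi, pi in zip(y, preds)]
        stump = fit_stump(x, signs)
        stumps.append(stump)
        # Each tree's contribution is shrunk by the learning rate
        preds = [pi + learning_rate * stump(xi) for pi, xi in zip(preds, x)]
    return lambda xi: base + learning_rate * sum(s(xi) for s in stumps)


def mae(y_true, y_hat):
    return sum(abs(t - p) for t, p in zip(y_true, y_hat)) / len(y_true)


# Made-up data: alcohol content and quality score
x = [9.0, 9.5, 10.0, 11.0, 12.0, 13.0]
y = [5, 5, 5, 6, 7, 7]

weak = boost_mae(x, y, iterations=1, learning_rate=0.1)
strong = boost_mae(x, y, iterations=200, learning_rate=0.1)
print("MAE after 1 iteration:   ", mae(y, [weak(v) for v in x]))
print("MAE after 200 iterations:", mae(y, [strong(v) for v in x]))
```

With a small learning rate each tree only nudges the prediction, so more iterations are needed to fit the data. That trade-off is the reason the two parameters are usually tuned together.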

6. Analyzing the Results


# Model evaluation function
def evaluate_model(y_true, y_pred, model_name="Model"):
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    print(f"{model_name} Performance:")
    print(f"Mean Absolute Error (MAE): {mae:.4f}")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"R² Score: {r2:.4f}\n")

evaluate_model(y_test, y_pred, "CatBoost")

# CatBoost Performance:
# Mean Absolute Error (MAE): 0.4268
# Mean Squared Error (MSE): 0.3330
# R² Score: 0.4904
 

The MAE comes out at about 0.43, meaning the predicted quality is off by roughly 0.43 points on average.
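For readers who want to see what the three reported metrics actually measure, the sketch below computes MAE, MSE and R² by hand. The four quality scores and predictions are made up for illustration and are not taken from the experiment.

```python
def mae(y_true, y_pred):
    """Mean absolute error: average size of the prediction error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mse(y_true, y_pred):
    """Mean squared error: like MAE, but large errors are penalised more."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2(y_true, y_pred):
    """R²: share of the target's variance explained by the predictions."""
    mean_t = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_t) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


# Made-up quality scores and predictions
y_true = [5, 6, 7, 5]
y_pred = [5.5, 6.0, 6.0, 5.5]

print(f"MAE: {mae(y_true, y_pred):.4f}")  # 0.5000
print(f"MSE: {mse(y_true, y_pred):.4f}")  # 0.3750
print(f"R2:  {r2(y_true, y_pred):.4f}")   # 0.4545
```

Because MSE squares each error, the single 1-point miss contributes most of the MSE here, while in the MAE it carries only twice the weight of a 0.5-point miss.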


7. CatBoost vs. XGBoost


Compared with the results from the earlier XGBoost experiments, the numbers look like this.

Model                        MAE     MSE     R² Score
CatBoostRegressor            0.4268  0.3330  0.4904
XGBoost (Baseline)           0.4175  0.3513  0.4625
XGBoost (Scaled)             0.4175  0.3513  0.4625
XGBoost (Outliers Removed)   0.4383  0.3492  0.4656
XGBoost (Tuned)              0.4549  0.3506  0.4635

 

The CatBoost model does not have the lowest MAE (the XGBoost baseline is slightly lower at 0.4175), but it has the lowest MSE (0.3330) and the highest R² Score (0.4904) of all the models compared.
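The ranking can be double-checked with a few lines of Python. The values are copied from the comparison table; the dictionary layout and variable names are just one convenient way to hold them.

```python
# (MAE, MSE, R²) per model, copied from the comparison table
results = {
    "CatBoostRegressor": (0.4268, 0.3330, 0.4904),
    "XGBoost (Baseline)": (0.4175, 0.3513, 0.4625),
    "XGBoost (Scaled)": (0.4175, 0.3513, 0.4625),
    "XGBoost (Outliers Removed)": (0.4383, 0.3492, 0.4656),
    "XGBoost (Tuned)": (0.4549, 0.3506, 0.4635),
}

# Lower is better for MAE and MSE, higher is better for R²
best_mae = min(results, key=lambda name: results[name][0])
best_mse = min(results, key=lambda name: results[name][1])
best_r2 = max(results, key=lambda name: results[name][2])

print("Lowest MAE:", best_mae)
print("Lowest MSE:", best_mse)
print("Highest R²:", best_r2)
```

Note that XGBoost (Baseline) and XGBoost (Scaled) tie on MAE; `min` simply reports the first one it encounters.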


8. Conclusion


CatBoost is a capable machine learning library that delivers solid results with minimal data preprocessing. In this experiment we trained a regression model on the Wine Quality Dataset, and even an untuned baseline reached the lowest MSE (0.3330) and the highest R² Score (0.4904) among the models compared.

Summary of CatBoost's strengths

  • Works with categorical data without a separate encoding step
  • Fast training and prediction
  • Good performance even with default hyperparameters
๋ฐ˜์‘ํ˜•