Decision Tree in Machine Learning
🟦 What is a Decision Tree?
A Decision Tree is a Supervised Machine Learning algorithm used for both:
- ✅ Classification (Predict Categories)
- ✅ Regression (Predict Numerical Values)
It works like a flowchart, where every question divides the dataset into smaller groups until a final decision is reached.
🌟 Definition
A Decision Tree is a supervised machine learning algorithm that predicts an output by asking a sequence of questions about the input features. Each question splits the data into smaller subsets until a final prediction (a leaf node) is reached.
🌳 Real-Life Example
🎓 Student Scholarship Prediction
A university wants to decide whether a student is eligible for a scholarship.
Conditions
- CGPA
- Attendance
Decision Tree
              CGPA ≥ 8?
             /         \
           Yes          No
           /              \
   Attendance ≥ 85?    Not Eligible
      /        \
    Yes         No
     |           |
  Eligible   Not Eligible
Suppose a student has:
- CGPA = 8.5
- Attendance = 90%
Decision:
CGPA ≥ 8 → Yes
Attendance ≥ 85 → Yes
➡ Scholarship Eligible
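The scholarship tree is just a chain of if/else questions. As a minimal sketch, it can be hand-coded like this (the function name `scholarship_decision` is hypothetical, and the thresholds are the ones from the diagram, not learned from data):

```python
def scholarship_decision(cgpa: float, attendance: float) -> str:
    """Hypothetical helper: follow the scholarship tree from the root to a leaf."""
    # Root node: CGPA >= 8?
    if cgpa >= 8:
        # Internal node: Attendance >= 85?
        if attendance >= 85:
            return "Eligible"       # leaf
        return "Not Eligible"       # leaf
    return "Not Eligible"           # leaf


print(scholarship_decision(8.5, 90))  # Eligible (the student from the example)
print(scholarship_decision(8.5, 70))  # Not Eligible (fails the attendance question)
print(scholarship_decision(7.0, 95))  # Not Eligible (fails the CGPA question)
```

A trained Decision Tree produces the same kind of nested rules, except that the questions and thresholds are learned from data instead of written by hand.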
🌳 Step-by-Step Working of Decision Tree
🟩 Step 1: Import Required Libraries
First, import the required libraries.
from sklearn.tree import DecisionTreeClassifier
Explanation
- sklearn.tree contains the Decision Tree algorithms.
- DecisionTreeClassifier() is used for classification problems.
🟩 Step 2: Prepare the Dataset
Suppose we have the following training data.
| Age | Loan Approved |
|---|---|
| 22 | No |
| 25 | No |
| 35 | Yes |
| 40 | Yes |
| 28 | No |
| 45 | Yes |
Python Code
X = [
[22],
[25],
[35],
[40],
[28],
[45]
]
y = [
"No",
"No",
"Yes",
"Yes",
"No",
"Yes"
]
Explanation
- X = input feature (Age)
- y = output label (Loan Approved)
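scikit-learn expects X to be two-dimensional, with one row per sample and one column per feature, even when there is only one feature. That is why each age sits in its own inner list. A quick check, assuming NumPy is installed (scikit-learn depends on it):

```python
import numpy as np

X = [[22], [25], [35], [40], [28], [45]]
y = ["No", "No", "Yes", "Yes", "No", "Yes"]

X_arr = np.array(X)
print(X_arr.shape)  # (6, 1): 6 samples, 1 feature (Age)
print(len(y))       # 6: one label per sample
```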
🟩 Step 3: Create the Model
model = DecisionTreeClassifier()
Explanation
This creates an empty Decision Tree model.
Nothing is learned yet.
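A minimal sketch of what "empty" means here: until fit() is called, the model only holds its settings. The explicit settings in `small_model` are illustrative values, not recommendations.

```python
from sklearn.tree import DecisionTreeClassifier

# Default model, exactly as in Step 3
model = DecisionTreeClassifier()
params = model.get_params()
print(params["criterion"])         # gini (default split-quality measure)
print(params["max_depth"])         # None (no depth limit)
print(hasattr(model, "tree_"))     # False: no tree has been learned yet

# Same class with a few settings made explicit (illustrative values)
small_model = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=0)
print(small_model.get_params()["max_depth"])  # 3
```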
🟩 Step 4: Train the Model
model.fit(X, y)
Explanation
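The fit() method trains the model using historical (training) data.
During training, the Decision Tree learns a pattern that separates the two classes:
- Younger applicants (22, 25, 28) → "No"
- Older applicants (35, 40, 45) → "Yes"
The model builds the tree automatically by choosing the age cut-off that best separates these two groups.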
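After fit(), the learned tree can be inspected. A short sketch using scikit-learn's get_depth(), get_n_leaves() and the fitted tree_ object; with this data the tree should need only one split, placed midway between 28 and 35:

```python
from sklearn.tree import DecisionTreeClassifier

X = [[22], [25], [35], [40], [28], [45]]
y = ["No", "No", "Yes", "Yes", "No", "Yes"]

model = DecisionTreeClassifier()
model.fit(X, y)

# The learned structure is stored in model.tree_ after training
print(model.get_depth())         # 1: a single question
print(model.get_n_leaves())      # 2: one leaf per answer
print(model.tree_.threshold[0])  # 31.5: the root's cut-off, midway between 28 and 35
```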
🟩 Step 5: Predict New Data
Suppose a new applicant is 30 years old.
prediction = model.predict([[30]])
Explanation
The model follows the decision rules learned during training and predicts the class.
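Besides a label, a fitted classifier can report class probabilities with predict_proba(). A small sketch on the same data; the probability columns follow the order in model.classes_:

```python
from sklearn.tree import DecisionTreeClassifier

X = [[22], [25], [35], [40], [28], [45]]
y = ["No", "No", "Yes", "Yes", "No", "Yes"]
model = DecisionTreeClassifier().fit(X, y)

# predict() also needs a 2-D input: one row per applicant
prediction = model.predict([[30]])
print(prediction)                   # ['No']

# Class order used by predict_proba() (sorted labels)
print(model.classes_)               # ['No' 'Yes']
print(model.predict_proba([[30]]))  # [[1. 0.]]: the leaf contains only "No" samples
```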
🟩 Step 6: Display the Result
print("Loan Approval =", prediction[0])
Sample Output
Loan Approval = No
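predict() accepts several rows at once. A sketch with a few extra applicant ages, chosen only for illustration:

```python
from sklearn.tree import DecisionTreeClassifier

X = [[22], [25], [35], [40], [28], [45]]
y = ["No", "No", "Yes", "Yes", "No", "Yes"]
model = DecisionTreeClassifier().fit(X, y)

# Several new applicants in one call (illustrative ages)
new_applicants = [[24], [30], [38], [50]]
for row, label in zip(new_applicants, model.predict(new_applicants)):
    print("Age", row[0], "-> Loan Approval =", label)
```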
🟩 Complete Program
from sklearn.tree import DecisionTreeClassifier
# Training Data
X = [
[22],
[25],
[35],
[40],
[28],
[45]
]
y = [
"No",
"No",
"Yes",
"Yes",
"No",
"Yes"
]
# Create Model
model = DecisionTreeClassifier()
# Train Model
model.fit(X, y)
# Test Data
prediction = model.predict([[30]])
# Output
print("Loan Approval =", prediction[0])
🌳 What Happens Internally?
Training Data
| Age | Loan Approved |
|---|---|
| 22 | No |
| 25 | No |
| 28 | No |
| 35 | Yes |
| 40 | Yes |
| 45 | Yes |
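To see how a "best" split can be chosen, here is a from-scratch sketch that scores every candidate cut-off on the training data above by weighted Gini impurity. It illustrates the idea behind scikit-learn's default criterion="gini" but is not scikit-learn's actual implementation; `gini` and `best_split` are hypothetical helper names.

```python
def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))


def best_split(ages, labels):
    """Try a cut-off midway between each pair of neighboring ages and
    return (threshold, weighted impurity) for the lowest-impurity split."""
    pairs = sorted(zip(ages, labels))
    n = len(pairs)
    best = None
    for i in range(n - 1):
        a1, a2 = pairs[i][0], pairs[i + 1][0]
        if a1 == a2:
            continue  # no cut-off fits between equal ages
        threshold = (a1 + a2) / 2
        left = [lab for age, lab in pairs if age <= threshold]
        right = [lab for age, lab in pairs if age > threshold]
        score = len(left) / n * gini(left) + len(right) / n * gini(right)
        if best is None or score < best[1]:
            best = (threshold, score)
    return best


ages = [22, 25, 35, 40, 28, 45]
labels = ["No", "No", "Yes", "Yes", "No", "Yes"]
threshold, impurity = best_split(ages, labels)
print(threshold, impurity)  # 31.5 0.0: both groups are pure
```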
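The algorithm searches for the best splitting point. With one feature (Age), scikit-learn tries cut-offs midway between neighboring ages (23.5, 26.5, 31.5, 37.5, 42.5) and keeps the one that makes the resulting groups purest (lowest Gini impurity, the default criterion).
Best split:
Age ≤ 31.5 ?
If Yes
22 → No
25 → No
28 → No
If No
35 → Yes
40 → Yes
45 → Yes
Both groups are pure, so no further split is needed. (A cut at 30 would also separate the training rows, but the midpoint rule puts the threshold at 31.5, so a 30-year-old falls on the "No" side.)
The tree becomes
        Age ≤ 31.5?
         /      \
       Yes       No
        |         |
       No        Yes
🌳 Decision Process
Suppose the input is
Age = 30
Is Age ≤ 31.5?
Yes
↓
Loan Approved = No
This matches the program output above (Loan Approval = No).
Suppose
Age = 25
Is Age ≤ 31.5?
Yes
↓
Loan Approved = No
🌳 Visual Representation
Root Node
        Age ≤ 31.5?
         /      \
       Yes       No
        |         |
       No        Yes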
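The tree can also be printed and its decision path followed directly from a fitted model. A sketch, assuming scikit-learn's export_text() for the text view and the fitted tree_ arrays (children_left, children_right, feature, threshold, value) for a hand-written walk; `trace` is a hypothetical helper, and it labels every question "Age" only because this data has that single feature:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

X = [[22], [25], [35], [40], [28], [45]]
y = ["No", "No", "Yes", "Yes", "No", "Yes"]
model = DecisionTreeClassifier().fit(X, y)

# 1) Text drawing of the learned tree (feature_names labels the single column)
rules = export_text(model, feature_names=["Age"])
print(rules)


# 2) Follow the path for one applicant by hand, using the fitted tree_ arrays
def trace(model, sample):
    """Hypothetical helper: walk from the root to a leaf, printing each question."""
    tree = model.tree_
    node = 0  # node 0 is the root
    while tree.children_left[node] != -1:  # -1 means "no child", i.e. a leaf
        feature = tree.feature[node]
        threshold = tree.threshold[node]
        goes_left = sample[feature] <= threshold  # "<=" goes to the left child
        print(f"Is Age <= {threshold}?", "Yes" if goes_left else "No")
        node = tree.children_left[node] if goes_left else tree.children_right[node]
    # tree_.value holds the class distribution at the leaf; take the majority class
    return model.classes_[np.argmax(tree.value[node][0])]


print("Loan Approved =", trace(model, [30]))
print("Loan Approved =", trace(model, [40]))
```

For a graphical version, sklearn.tree.plot_tree draws the same tree with matplotlib.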
🌳 Important Functions
Create Model
model = DecisionTreeClassifier()
Creates a Decision Tree model.
Train Model
model.fit(X,y)
Learns patterns from data.
Predict
model.predict([[30]])
Predicts output for new data.
Accuracy
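model.score(X, y)
Returns the mean accuracy of the model's predictions on the given X and y. Scored on the training data, as in the example below, it measures how well the tree fits those rows, not how well it predicts new applicants.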
Example
accuracy = model.score(X,y)
print("Accuracy =", accuracy)
Output
Accuracy = 1.0
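For a classifier, score() is the same as accuracy_score() applied to predict(X). A small sketch confirming this on the loan data:

```python
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

X = [[22], [25], [35], [40], [28], [45]]
y = ["No", "No", "Yes", "Yes", "No", "Yes"]
model = DecisionTreeClassifier().fit(X, y)

# score() = fraction of correct predictions on the data passed in
manual = accuracy_score(y, model.predict(X))
print(model.score(X, y), manual)  # 1.0 1.0
```

To estimate accuracy on unseen applicants, the usual approach is to hold back part of the data (for example with sklearn.model_selection.train_test_split) and score the model on that part only.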
🌳 Decision Tree Workflow
Training Data
│
▼
Create DecisionTreeClassifier
│
▼
Train Model using fit()
│
▼
Decision Tree is Built
│
▼
New Input Data
│
▼
predict()
│
▼
Final Prediction
🌳 Advantages
✔ Easy to understand
✔ Easy to visualize
✔ No feature scaling required
✔ Handles numerical and categorical data (in scikit-learn, categorical columns are typically encoded as numbers first)
✔ Works for Classification and Regression
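The "no feature scaling" point can be checked directly: standardizing Age changes the numbers but not their order, so the tree makes the equivalent split. A sketch using scikit-learn's StandardScaler; the new ages are illustrative:

```python
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X = [[22], [25], [35], [40], [28], [45]]
y = ["No", "No", "Yes", "Yes", "No", "Yes"]
new_ages = [[24], [30], [38], [50]]

# Tree trained on raw ages
raw_model = DecisionTreeClassifier().fit(X, y)

# Tree trained on standardized ages (mean 0, standard deviation 1)
scaler = StandardScaler().fit(X)
scaled_model = DecisionTreeClassifier().fit(scaler.transform(X), y)

print(raw_model.predict(new_ages))                       # ['No' 'No' 'Yes' 'Yes']
print(scaled_model.predict(scaler.transform(new_ages)))  # same labels
```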
🌳 Limitations
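❌ Can overfit: a fully grown tree may memorize noise in the training data
❌ Sensitive to noisy data: small changes in the data can produce a very different tree
❌ Large trees become complex and hard to interpret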
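Overfitting can be seen by adding one noisy row. The extra row (age 24 labelled "Yes") is hypothetical, added only for illustration. A fully grown tree memorizes it, while max_depth=1 keeps the simple age rule and accepts one training error:

```python
from sklearn.tree import DecisionTreeClassifier

# The loan data plus ONE hypothetical noisy row: age 24 labelled "Yes"
X = [[22], [25], [35], [40], [28], [45], [24]]
y = ["No", "No", "Yes", "Yes", "No", "Yes", "Yes"]

full_tree = DecisionTreeClassifier(random_state=0).fit(X, y)
small_tree = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, y)

# Fully grown: extra splits carve out the noisy row, training accuracy 1.0
print(full_tree.get_depth(), full_tree.score(X, y))
# Depth limited to 1: one age cut-off, the noisy row is misclassified (6 of 7 correct)
print(small_tree.get_depth(), small_tree.score(X, y))
```

Limits such as max_depth or min_samples_leaf trade a little training accuracy for a simpler tree that is less likely to memorize noise.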
🌳 Applications
🏦 Loan Approval
🏥 Disease Prediction
📧 Spam Detection
🎓 Student Performance Prediction
🌾 Crop Classification
🚗 Insurance Risk Prediction
⭐ Interview / Viva Questions
Q1. What is a Decision Tree?
A supervised machine learning algorithm that predicts outputs by splitting data into smaller subsets using decision rules.
Q2. Why is it called a Decision Tree?
Because it resembles a tree structure where each node represents a decision, each branch represents an outcome, and each leaf node represents the final prediction.
Q3. What does fit() do?
It trains the Decision Tree using the training dataset.
Q4. What does predict() do?
It predicts the output for new, unseen data based on the trained model.
Q5. Can Decision Trees solve both Classification and Regression problems?
Yes.
- DecisionTreeClassifier → Classification
- DecisionTreeRegressor → Regression
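A minimal DecisionTreeRegressor sketch. The data (years of experience → salary in thousands) is hypothetical and chosen so the arithmetic is easy: with max_depth=1, each leaf predicts the mean of its training targets.

```python
from sklearn.tree import DecisionTreeRegressor

# Hypothetical data: years of experience -> salary (in thousands)
X = [[1], [2], [3], [10], [11], [12]]
y = [10.0, 12.0, 14.0, 30.0, 32.0, 34.0]

reg = DecisionTreeRegressor(max_depth=1).fit(X, y)

# One split separates the two groups; each leaf predicts its group's mean
print(reg.predict([[2], [11]]))  # [12. 32.]
```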
⭐ One-Line Revision
Decision Tree = Training Data → Split into Decision Rules → Build Tree → Predict Output