PyTorch는 직관적인 신경망 구현이 가능한 프레임워크입니다. 특히 `nn.Module`을 활용하면 커스텀 신경망 구조를 자유롭게 구성할 수 있어, 실무나 연구 환경에서 매우 유용하게 쓰입니다. 이번 글에서는 PyTorch의 핵심 클래스인 `nn.Module`을 이용해 나만의 딥러닝 모델을 만드는 방법을 소개하겠습니다.
● nn.Module이란?
`nn.Module`은 PyTorch에서 신경망의 모든 레이어와 연산을 담는 기반 클래스입니다. 이 클래스를 상속받아 모델을 정의하면, 학습에 필요한 파라미터를 자동으로 추적하고 GPU로 손쉽게 이동시키는 등 많은 기능을 누릴 수 있습니다.
기본적인 구조는 다음과 같습니다.
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 레이어 정의
self.fc1 = nn.Linear(100, 64)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(64, 10)
def forward(self, x):
# 순전파 정의
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
이처럼 `__init__()`에서는 사용할 레이어들을 정의하고, `forward()`에서는 실제 데이터를 어떻게 변환할지 구성합니다.
● 모델 인스턴스화 및 학습 준비
정의한 모델 클래스를 불러와 인스턴스를 생성한 후, 옵티마이저와 손실 함수를 정의하면 학습을 시작할 수 있습니다.
model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
`model.parameters()`는 모델 내부의 학습 가능한 파라미터들을 모두 반환해 줍니다. 이를 옵티마이저에 넘겨야 가중치가 업데이트됩니다.
● forward 메서드의 역할
`forward()`는 모델이 호출될 때 자동으로 실행되는 메서드로, 데이터가 신경망을 통과하는 과정을 정의합니다. 예를 들어 입력 텐서를 넣으면 다음과 같이 처리됩니다.
inputs = torch.randn(32, 100) # 배치 크기 32, 입력 차원 100
outputs = model(inputs)
이때 `model(inputs)`를 호출하면 내부적으로 `forward()`가 실행되며 결과를 반환합니다. 즉, `model.forward(inputs)`를 직접 호출할 필요는 없습니다.
● 커스텀 모델 구성 팁
`nn.Module`을 사용할 때 다음 사항들을 고려하면 모델 구현이 더 안정적이고 확장성이 좋아집니다.
- 반드시 `super().__init__()` 호출로 부모 클래스 초기화
- 모든 레이어는 `__init__`에서 정의, 연산은 `forward`에서 수행
- 정의한 레이어는 self로 등록해야 파라미터로 추적됨
- 모델을 `.to("cuda")` 또는 `.cuda()`로 쉽게 GPU에 올릴 수 있음
● 실제 예시: 이미지 분류 모델
다음은 28x28 크기의 흑백 이미지를 분류하는 간단한 신경망입니다.
class ImageClassifier(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(28*28, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
이 모델은 MNIST와 같은 데이터셋에서 간단한 분류를 수행할 수 있으며, 손실 함수와 옵티마이저 설정 후 학습 루프를 돌리면 됩니다.
'AI > 기초_[머신러닝][딥러닝]' 카테고리의 다른 글
[인공지능][기초] 074. Fashion-MNIST 실습: CNN 구조 적용 (0) | 2025.07.24 |
---|---|
[인공지능][기초] 073. 손글씨 인식(MNIST) 모델 만들기 (TF & Torch 비교) (0) | 2025.07.24 |
[인공지능][기초] 071. PyTorch 설치와 기본 문법 익히기 (0) | 2025.07.24 |
[인공지능][기초] 070. Keras Sequential API로 모델 만들기 (0) | 2025.07.24 |
[인공지능][기초] 069. TensorFlow 설치와 기본 구조 이해하기 (0) | 2025.07.24 |