본문 바로가기
AI/기초_[머신러닝][딥러닝]

[인공지능][기초] 072. nn.Module로 커스텀 모델 만들기 (PyTorch)

by about_IT 2025. 7. 24.
728x90

PyTorch는 직관적인 신경망 구현이 가능한 프레임워크입니다. 특히 `nn.Module`을 활용하면 커스텀 신경망 구조를 자유롭게 구성할 수 있어, 실무나 연구 환경에서 매우 유용하게 쓰입니다. 이번 글에서는 PyTorch의 핵심 클래스인 `nn.Module`을 이용해 나만의 딥러닝 모델을 만드는 방법을 소개하겠습니다.


● nn.Module이란?

`nn.Module`은 PyTorch에서 신경망의 모든 레이어와 연산을 담는 기반 클래스입니다. 이 클래스를 상속받아 모델을 정의하면, 학습에 필요한 파라미터를 자동으로 추적하고 GPU로 손쉽게 이동시키는 등 많은 기능을 누릴 수 있습니다.

기본적인 구조는 다음과 같습니다.

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 레이어 정의
        self.fc1 = nn.Linear(100, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)
    
    def forward(self, x):
        # 순전파 정의
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

이처럼 `__init__()`에서는 사용할 레이어들을 정의하고, `forward()`에서는 실제 데이터를 어떻게 변환할지 구성합니다.


● 모델 인스턴스화 및 학습 준비

정의한 모델 클래스를 불러와 인스턴스를 생성한 후, 옵티마이저와 손실 함수를 정의하면 학습을 시작할 수 있습니다.

model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

`model.parameters()`는 모델 내부의 학습 가능한 파라미터들을 모두 반환해 줍니다. 이를 옵티마이저에 넘겨야 가중치가 업데이트됩니다.


● forward 메서드의 역할

`forward()`는 모델이 호출될 때 자동으로 실행되는 메서드로, 데이터가 신경망을 통과하는 과정을 정의합니다. 예를 들어 입력 텐서를 넣으면 다음과 같이 처리됩니다.

inputs = torch.randn(32, 100)  # 배치 크기 32, 입력 차원 100
outputs = model(inputs)

이때 `model(inputs)`를 호출하면 내부적으로 `forward()`가 실행되며 결과를 반환합니다. 즉, `model.forward(inputs)`를 직접 호출할 필요는 없습니다.


● 커스텀 모델 구성 팁

`nn.Module`을 사용할 때 다음 사항들을 고려하면 모델 구현이 더 안정적이고 확장성이 좋아집니다.

  • 반드시 `super().__init__()` 호출로 부모 클래스 초기화
  • 모든 레이어는 `__init__`에서 정의, 연산은 `forward`에서 수행
  • 정의한 레이어는 self로 등록해야 파라미터로 추적됨
  • 모델을 `.to("cuda")` 또는 `.cuda()`로 쉽게 GPU에 올릴 수 있음

● 실제 예시: 이미지 분류 모델

다음은 28x28 크기의 흑백 이미지를 분류하는 간단한 신경망입니다.

class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

이 모델은 MNIST와 같은 데이터셋에서 간단한 분류를 수행할 수 있으며, 손실 함수와 옵티마이저 설정 후 학습 루프를 돌리면 됩니다.



728x90