Hello Computer Vision

nn.ModuleList에 대해 알아보기 본문

딥러닝/파이토치

nn.ModuleList에 대해 알아보기

지웅쓰 2022. 12. 31. 00:47

기존에 제가 쓴 코드들을 다시 보는데 nn.ModuleList를 이용해 짠 코드들이 있더라고요.

뭔가 nn.Sequential처럼 묶어주는건 기억이 나는데 정확히 기억이 나지 않아 다시 한번 공부해보려고 합니다.

 

nn.ModuleList란

list 형태로 layer들을 묶어줄 수 있습니다. list형태로 묶여있기 때문에 layer들의 iterator를 생성하는 것이라 할 수 있다.

하지만 nn.Sequential과는 달리  forward method가 없기 때문에 바로 사용할 수가 없다.

이게 무슨 말이냐하면 이러한 에러가 발생한다.

modulelist = nn.ModuleList(
    [nn.Linear(10, 10), nn.Linear(10, 10)]
)

input =torch.rand((10))

modulelist(input)

NotImplementedError: Module [ModuleList] is missing the required "forward" function

 nn.Sequential로 묶었을 때는 바로 결과를 알 수 있다.

sequential = nn.Sequential(
    nn.Linear(10, 10), nn.Linear(10, 10)
)

input =torch.rand((10))

sequential(input)

tensor([ 0.2838,  0.2909,  0.0314,  0.0563,  0.2876,  0.1190,  0.6015,  0.0277,
        -0.0411, -0.3799], grad_fn=<AddBackward0>)

 

어떨 때, 왜 써야하는 것일까?

김진솔님의 블로그 여기서 쓰는 이유를 설명해주시는 걸 참고하였습니다.

우선 sequential안에 있는 module들은 다 connection들을 가지는 반면 ModuleList의 경우 각각의 module간의 connection이 없습니다.

그렇기 때문에 nn.Sequential을 수행할 때는 자동적으로 forward method가 실행되기 때문에 한 단위로 수행이 된다면

nn.ModuleList의 경우 그렇지 않고 for 문을 통해 각각을 수행시키는 것입니다.