딥러닝/파이토치
nn.ModuleList에 대해 알아보기
지웅쓰
2022. 12. 31. 00:47
기존에 제가 쓴 코드들을 다시 보는데 nn.ModuleList를 이용해 짠 코드들이 있더라고요.
뭔가 nn.Sequential처럼 묶어주는건 기억이 나는데 정확히 기억이 나지 않아 다시 한번 공부해보려고 합니다.
nn.ModuleList란
list 형태로 layer들을 묶어줄 수 있습니다. list형태로 묶여있기 때문에 layer들의 iterator를 생성하는 것이라 할 수 있다.
하지만 nn.Sequential과는 달리 forward method가 없기 때문에 바로 사용할 수가 없다.
이게 무슨 말이냐하면 이러한 에러가 발생한다.
modulelist = nn.ModuleList(
[nn.Linear(10, 10), nn.Linear(10, 10)]
)
input =torch.rand((10))
modulelist(input)
NotImplementedError: Module [ModuleList] is missing the required "forward" function
nn.Sequential로 묶었을 때는 바로 결과를 알 수 있다.
sequential = nn.Sequential(
nn.Linear(10, 10), nn.Linear(10, 10)
)
input =torch.rand((10))
sequential(input)
tensor([ 0.2838, 0.2909, 0.0314, 0.0563, 0.2876, 0.1190, 0.6015, 0.0277,
-0.0411, -0.3799], grad_fn=<AddBackward0>)
어떨 때, 왜 써야하는 것일까?
김진솔님의 블로그 여기서 쓰는 이유를 설명해주시는 걸 참고하였습니다.
우선 sequential안에 있는 module들은 다 connection들을 가지는 반면 ModuleList의 경우 각각의 module간의 connection이 없습니다.
그렇기 때문에 nn.Sequential을 수행할 때는 자동적으로 forward method가 실행되기 때문에 한 단위로 수행이 된다면
nn.ModuleList의 경우 그렇지 않고 for 문을 통해 각각을 수행시키는 것입니다.