안녕하세요. 이번 포스팅은 ViT 코드 구현을 해보려고 합니다.
ViT에 대해서는 Transformer 포스팅에서 살짝 언급했었는데요, ViT는 이제 많은 vision task의 backbone으로 쓰이고 있습니다.
위의 대략적인 overview를 보면 image를 patch(or token)로 나누고 position encoding과 summation후에 Transformer Encoder를 거치게 됩니다.
그럼 우선 Image data를 patch로 나누는 코드를 살펴보겠습니다.
def generate_patch(patch_size, image):
num_channels = image.size(0) # Dataloader거친 후 Batch_size, Channel, W, H
patches_w = image.unfold(
1, patch_size, patch_size
) # dimension, size, step, w기준으로 patch나누는과정
patches = patches_w.unfold(2, patch_size, patch_size) # h기준으로 patch화
patches = patches.reshape(
num_channels, -1, patch_size, patch_size
) # 3, patch개수 , patchsize_w, patchsize_h
patches = patches.permute(
1, 0, 2, 3
).contiguous() # patch개수, 3, patchsize_w, patchsize_h
flatten_patches = patches.reshape(
patches.size(0), -1
) # patch개수, 3 x patch_w x patch_h
return flatten_patches
def __getitem__(self, index):
if self.training:
input_data = self.train_set[index]
label = self.labels[index]
else:
input_data = self.test_set[index]
label = None
if self.cfg.PATCH:
patched_input_data = generate_patch(self.cfg.PATCH.PATCH_SIZE, input_data)
return patched_input_data, label
__getitem__에서 input_data인 image데이터와( W x H x C )와 config파일에서 patch_size를 인자로 하여 generate_patch로 넘겨줍니다. 실제로는 generate_patch에서 patch화가 진행됩니다. patch로 나눌때 unfold라는 torch 기능을 사용합니다. (1,patch_size,patch_size)는 w방향으로 patch_size로 patch_size의 step으로 split하는 걸 의미합니다. patches_w.unfold(2, patch_size, patch_size) 단계를 거치면 x,y방향으로 모두 patch화를하게 되고 이를 처음의 channel을 기준으로 다시 reshape 하게 됩니다. 이를 최종 (patch개수, 채널수 x patch_w x patch_h)의 형태로 flatten합니다.
*view, reshape, transpose, permute의 차이
view. vs reshape
x = torch.rand(2, 3, 4) # [2, 3, 4]
y = x.view(2, -1) # [2, 12]
x = torch.rand(2, 3, 4) # [2, 3, 4]
x = x.reshape(2, -1) # [2, 12]
우선 view와 reshape은 동일한 기능을 합니다. 차이는 view는 pointer처럼 수정시 원본 data도 변경이 되고 reshape은 copy 혹은 pointer를 반환한다고 합니다.(어떤걸 받을지는 모른다는 이야기)
permute vs transpose
transpose는 두차원을 서로 맞교환할 수 있고, permute는 모든 차원을 맞교환 할 수 있습니다.
x = torch.rand(16, 32, 3)
y = x.tranpose(0, 2) # [3, 32, 16]
z = x.permute(2, 1, 0) # [3, 32, 16]
또한 permute, transpose 둘다 모두 view와 같이 원본 data를 공유하는 포인터로 활용된다. 또한 permute, transpose와 view의 차이점은 view는 contiguous tensor에서만 작용할 수 있다는 것 입니다. 예를 들면 transpose를 contiguous하게 하려면 transpose().contiguous()를 불러와야 합니다. reshape은 contiguous와 상관없이 사용할 수 있습니다.
Contiguous란 메모리에 연속적인 array라는 뜻입니다. n차원의 array도 추상화되어있을 뿐 물리적으로는 일렬의 메모리에 저장되어있습니다.즉 coutiguous란 array에서 물리적으로 행방향으로 연속된 형태를 의미합니다.
다음으로 overview상의 linear projection을 보겠습니다.
generate_patch함수를 거쳐서 나온 patch data가 batch형태로 (patch개수, 3 x patch_w x patch_h) input으로 들어가게됩니다.
class EmbeddingMudule(nn.Module):
def __init__(
self, patch_vector_size, patches_num, latent_vector_dimension, drop_rate=0.1
):
super(EmbeddingMudule, self).__init__()
self.linear_projection = nn.Linear(patch_vector_size, latent_vector_dimension)
self.class_token = nn.Parameter(torch.randn(1, latent_vector_dimension))
self.positional_emdedding = nn.Parameter(
torch.randn(1, patches_num + 1, latent_vector_dimension)
)
self.dropout = nn.Dropout(drop_rate)
def forward(self, x):
# x : B, patch개수, 3 x patch_w x patch_h
# repeat == expand
# linear_proj : B, patch개수, latent_vector_dimension, cls_token : B, 1, 1
# B, patch개수 + 1, latent_vector_dimension
batch_size = x.size(0)
x = torch.cat(
[self.class_token.repeat(batch_size, 1, 1), self.linear_projection(x)],
dim=1,
)
# B, patch개수 + 1, latent_vector_dimension
x += self.positional_emdedding
x = self.dropout(x)
return x
x = torch.cat([self.cls_token.repeat(batch_size, 1, 1), self.linear_proj(x)], dim=1) 이 부분을 보면 input_data를linear_projection에 넣어 latent_vector_dimentsion으로 shape을 변경함과 동시에 mlp를 한번 태웁니다. class_token의 경우 랜덤하게 latent_vector_dimension만큼 파라미터를 생성한 후에 (batch_size, 1, latent_vector_dimentsion)로 shape을 expand하고 linear_projection의 (B, patch개수, latent_vector_dimension)과 concat하여 최종적으로 (B, patch개수+1, latent_vector_dimension)으로 shape이 만들어집니다.
이 후 positional_emdedding과정을 거칩니다. position encoding방법은 여러가지가 있고 보통 cos,sin 파형을 이용해서 하는데 여기서는 learnable parameter를 추가하는걸로 embedding을 진행하였습니다. (B, patch개수+1, latent_vector_dimension) 형식의 leanable parameter를 만든후에 이를 summation하였습니다.
*repeat, expand 차이
repeat의 경우 특정 tensor의 size 차원의 데이터를 반복합니다.
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3]])
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])
예를 들면 [1,2,3]을 dim=0으로 4만큼 dim=1로 2만큼 반복하면, (4, 6)의 차원이 나오게 됩니다. 그리고 repeat은 원래의 data를 copy합니다.
expand의 경우는 마찬가지로 특정 tensor를 반복하여 생성하는데 개수가 1인 차원만 적용할 수 있습니다.
>>> x = torch.tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[ 1, 1, 1, 1],
[ 2, 2, 2, 2],
[ 3, 3, 3, 3]])
>>> x.expand(-1, 4) # -1 means not changing the size of that dimension
tensor([[ 1, 1, 1, 1],
[ 2, 2, 2, 2],
[ 3, 3, 3, 3]])
expand는 원본을 참조합니다.
다음단계는 Transformer Encoding입니다.
class TransformerEncoder(nn.Module):
def __init__(self, latent_vector_dim, head_num, mlp_hidden_dim, drop_rate=0.1):
super().__init__()
self.ln1 = nn.LayerNorm(latent_vector_dim)
self.ln2 = nn.LayerNorm(latent_vector_dim)
self.msa = MultiheadedSelfAttention(
latent_vector_dim=latent_vector_dim,
head_num=head_num,
drop_rate=drop_rate,
)
self.dropout = nn.Dropout(drop_rate)
self.mlp = nn.Sequential(
nn.Linear(latent_vector_dim, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(drop_rate),
nn.Linear(mlp_hidden_dim, latent_vector_dim),
nn.Dropout(drop_rate),
)
def forward(self, x):
# B, patch개수 + 1, latent_vector_dim
z = self.ln1(x) # B, patch개수 + 1, latent_vector_dim
z, attention_vector = self.msa(
z
) # z : B, patch개수 + 1 ,latent_vector_dim(head_number x head_dimension)
z = self.dropout(z)
x = x + z
z = self.ln2(x)
z = self.mlp(z)
x = x + z # B, patch개수 + 1 ,latent_vector_dim
return x, attention_vector
Embedding 모듈 이후에 (B, patch개수 + 1, latent_vector_dim)의 shape의 data를 normalization을 태우고나서 multi-head-attention과정을 거치게 됩니다. mha은 아래 자세히 설명하겠습니다. 이 후 다시 normalization과정과 mlp과정을 거치고 나서 residual block처럼 원래의 input feature를 summation하는 것이 transformer encoding과정의 전부입니다.
*layer norm vs batch norm
layer norm은 batch에 어떤 크기의 데이터들이 있든 관계없이 샘플데이터 단위로 normalization을 시켜줍니다. batch norm은 mini-batch내부 feature값들의 mean,std로 normalization을 진행합니다.
class MultiheadedSelfAttention(nn.Module):
def __init__(self, latent_vector_dim, head_num, drop_rate, device):
super().__init__()
self.device = device
self.head_num = head_num
self.latent_vector_dim = latent_vector_dim
self.head_dim = int(latent_vector_dim / head_num)
self.query = nn.Linear(latent_vector_dim, latent_vector_dim)
self.key = nn.Linear(latent_vector_dim, latent_vector_dim)
self.value = nn.Linear(latent_vector_dim, latent_vector_dim)
self.dropout = nn.Dropout(drop_rate)
def forward(self, x):
# B, patch개수 + 1, latent_vector_dim
batch_size = x.size(0)
q = self.query(x) # q : B, patch개수 + 1, latent_vector_dim
k = self.key(x) # k : B, patch개수 + 1, latent_vector_dim
v = self.value(x) # v : B, patch개수 + 1, latent_vector_dim
q = q.view(batch_size, -1, self.head_num, self.head_dim).permute(
0, 2, 1, 3
) # B, patch개수 + 1 ,head number, head dimension -> B, head_number, patch개수 + 1, head_dimension(int(latent_vector_dim / num_heads)
k = k.view(batch_size, -1, self.head_num, self.head_dim).permute(
0, 2, 3, 1
) # B, patch개수 + 1 ,head number, head dimension -> B, head_number, head_dimension(int(latent_vector_dim / num_heads), patch개수 + 1
v = v.view(batch_size, -1, self.head_num, self.head_dim).permute(
0, 2, 1, 3
) # B, patch개수 + 1 ,head number, head dimension -> B, head_number, patch개수 + 1, head_dimension(int(latent_vector_dim / num_heads)
attention = torch.softmax(
q @ k / torch.sqrt(self.head_dim * torch.ones(1)), dim=-1
) # B, head_number, patch개수 + 1, patch개수 + 1
x = (
self.dropout(attention) @ v
) # B, head_number, patch개수 + 1, patch개수 + 1 x B, head_number, patch개수 + 1, head_dimension => B, head_number, patch개수 + 1, head_dimension
x = x.permute(0, 2, 1, 3).reshape(
batch_size, -1, self.latent_vector_dim
) # B, patch개수 + 1, head_number, head_dimension -> B, patch개수 + 1 ,latent_vector_dim(head_number x head_dimension)
return x, attention
앞서 transformer encoder에서 msa로 (B, patch개수 + 1, latent_vector_dim) shape의 normalized된 data가 input으로 들어가는 것을 봤습니다.
이제 그림과 같은 self-attention의 과정을 거치게 됩니다. input data을 각각 learnable parameter와의 linear layer를 태워서 동일한 shape의 (B, patch개수 + 1, latent_vector_dim)를 query, key, value vector를 생성하고 최종적으로 아래와 같은 shape을 만듭니다.
q : B, head_number, patch개수 + 1, head_dimension
k : B, head_number, head_dimension, patch개수 + 1
v : B, head_number, patch개수 + 1, head_dimension
이 후 softmax(q,kT / root(D))*v 을 계산하여 최종적으로 B, patch개수 + 1, head_number, head_dimension 가 나오게 되고 이를 다시 B, patch개수 + 1 ,latent_vector_dim(head_number x head_dimension) 형태로 reshape합니다.
처음의 shape과 동일합니다. self-attention은 input shape과 동일한 모습입니다.
mlp 헤드 단계입니다. transformer encoder에서 나온 output을 단순히 norm과 linear layer를 태우는 것이 전부입니다. 그래서 최종적으로는 class만큼의 layer가 output shape으로 나오게 됩니다.
self.mlp_head = nn.Sequential(
nn.LayerNorm(latent_vector_dim), nn.Linear(latent_vector_dim, num_classes)
)
여기까지 ViT코드 설명이였습니다. 전체 코드는 아래 github에서 볼 수 있습니다.