Dr.Raum

Updated:

UNETR: Transformers for 3D Medical Image Segmentation

Vision Transformer

NLP에서의 Transformer가 1D sequence of input embeddings을 입력 받기 때문에 Vision 분야에서도 유사하게 2D, 3D입력을 1D로 변환해줄 필요가 있다. UNETR에서 제안하는 3D을 기준으로 설명하면 resolution $(H, W, D)$ Channel $C$의 3D input volume $\mathbf{x} \in \mathbb{R}^{H \times W \times D \times C}$ 을 Flatten 시켜서 $\mathbf{x}_v \in \mathbb{R}^{N \times (P^3 \cdot C)}$ 으로 만든다. 여기서 $(P, P, P)$는 image patch의 resolution이고, $N = (H \times W \times D) / P^3$ 은 Patch의 수, length of the sequence이다.

이렇게 Flatten한 Patch를 학습 가능한 Linear Projection($E$) ( $E \in \mathbb{R}^{(P^3 \cdot C) \times K}$ ) 을 통해 $K$ dimensional embedding space 으로 만든다. 이후 NLP 처럼 이미지의 위치 정보를 보존하기 위해 학습 가능한 1D positional embedding ($E_{pos} \in \mathbb{R}^{N \times K}$) 을 embedding결과에 더해준다.

\[z_0 = [\mathbf{x}_v^1E; \mathbf{x}_v^2E; ...; \mathbf{x}_v^NE] + E_{pos}\]

여기선 semantic segmentation이 목적이기 때문에 [CLS] Token은 사용되지 않는다.

embedding이 끝난 후 본격적으로 Transformer 네트워크를 통과한다. 여기서 부턴 NLP에서의 과정과 매우 유사하다.

\[z_i^{\prime} = MSA(Norm(z_{i-1})) + z_{i-1}, i=1...L\] \[z_i = MLP(Norm(z_i^{\prime})) + z_i^{\prime}, i=1...L\]

multilayer perceptron (MLP) 은 GELU activation를 사용하는 2개의 linear layers로 이루어져 있고, $L$은 transformer layers를 반복하는 수이다.

multi-head self-attention (MSA) 은 $n$ parallel self-attention(SA) heads로 이루여져 있으며 다음과 같이 계산한다. query, key, value는 입력 벡터를 3배의 dimension으로 늘린 뒤 각각을 query, key, value로 넣는다.

\[A = Softmax \left( \frac{qk^T}{\sqrt{K_h}} \right),   K_h=K/n\] \[SA(z) = Av\] \[MSA(z) = [SA_1(z); SA_2(z); ...; SA_n(z)]W_{msa}\]

$W_{msa} \in \mathbb{R}^{n \cdot K_h \times K}$ 는 학습가능한 파라미터 가중치이다.

Architecture

transformer를 활용한 다른 3D medical image segmentation 연구에서는 CNN을 feature extraction에 사용하고 Encoder와 Decoder를 잇는 bottleneck에 transformer를 두는 방식을 제안하였다. 하지만 본 논문에서는 transformer를 Encoder에 사용하고, 이를 바로 Decoder와 skip connections으로 연결하는 방법을 제안한다.
모델에서 Transformer를 Encoder할 때만 사용하고 Decoder할 때는 CNN-based를 사용하는 이유는 Transformer가 global information는 매우 잘 잡아내지만, localized information에 대해서는 부적합하기 때문이다.

U-net 구조와 유사하게 $\frac{H \times W \times D}{P^3} \times K$ 의 size를 갖는 transformer에서의 결과들 $z_i (i \in {3,6,9,12})$ 을 $\frac{H}{P} \times \frac{W}{P} \times \frac{D}{P} \times K$ 의 tensor로 reshape하여 decoder와 merge시킨다.

transformer의 마지막 layer 결과인 bottleneck에서는 deconvolutional layer를 통과시켜서 resolution을 2배 증가시킨다. 이 feature map을 이전 transformer output인 $z_9$의 feature map에 Deconv를 통과시킨 결과와 concat하고, $3 \times 3 \times 3$ conv에 통과시킨 뒤, 다시 upsample을 한다. 이러한 과정을 원래 input resolution까지 반복하고 최종적으로 $1 \times 1 \times 1$ conv와 softmax에 통과시켜 voxel-wise semantic predictions을 얻는다.

Loss Function

F1 Score

일반적인 Accuracy는 모델의 전체 예측 중에 맞은 예측의 비율을 말한다. 하지만 이는 분류 문제에서 클래스들의 분포가 균일하지 않을 때 부정확하다. 이러한 데이터 불균형을 문제를 해결하기 위해 F1 Score를 사용한다.

F1 Score는 Precision과 Recall을 사용하여 측정한다. Precision과 Recall은 각각 다음과 같이 구해진다.

\[Precision = \frac{TP}{TP + FP} ,  Recall = \frac{TP}{TP + FN}\]

Precision은 모델이 True라고 분류한 것 중에서 실제로 True인 것의 비율이다. 예를들면 암에 걸린 것을 target(positive)라 할 때 Precision은 모델이 암에 걸렸다고 예측한 사람들 중 실제 암에 걸린 사람이 몇 명인지를 나타내는 비율이다.

Recall은 positive인 것 중에서 모델이 positive라고 예측한 것의 비율이다. 예를들면 실제 암에 걸린 사람들 중에 모델이 암에 걸렸다고 판단한 사람들이 몇 명인지를 나타내는 비율이다.

이상적인 모델은 실제 positive에 대해 최대한 많은 positive를 찾아내고 모델이 찾아낸 positive 중 실제 positive인 값들이 많을수록 좋다. 즉 Precision과 Recall이 모두 높으면 좋지만 두 가지를 모두 높이는 건 힘들다.

F1 Score는 다음과 같이 Precision과 Recall로 구성되며 데이터가 불균형한 환경에서 잘 동작하는 평가지표가 된다.

\[f_1 score = \frac{2}{\frac{1}{recall} + \frac{1}{precision}} = 2 \cdot \frac{precision \cdot recall}{precision + recall}\]

이 처럼 조화평균으로 구해지는데, 그 이유는 Precision과 Recall중 더 작은 값에 영향을 많이 받게 하기 위함이다. 예를 들어 precision = 0.9 recall = 0.1 이라면 f1_score = 0.18이지만 precision = 0.6 recall = 0.4라면 f1_score = 0.48이다.

Dice Loss

dice loss는 데이터가 불균형적인 특징이 존재하는 semantic segmentation에서 많이 사용하는 loss 함수 이다. sm-segmentation 오픈라이브러리에서 dice loss(1 - Dice Coefficient) 는 1 - F1 Score로 계산되는데, 이는 Dice Coefficient가 F1 Score와 동일하기 때문이다.

$p$를 모델의 prediction $\hat{p}$를 ground truth라고 할 때, dice coefficient는 아래와 같다.

\[DSC = 2 \cdot \frac{\vert p \vert \cap \vert \hat{p} \vert}{\vert p \vert + \vert \hat{p} \vert}\]

앞서 precision과 recall을 $p$와 $\hat{p}$로 표현하면 다음과 같다.

\[precision = \frac{TP}{TP + FP} = \frac{\vert p \hat{p} \vert}{\vert p \vert} ,  recall = \frac{TP}{TP + FN} = \frac{\vert p \hat{p} \vert}{\vert \hat{p} \vert}\]

precison은 정의 상 전체 예측한 $\vert p \vert$ 중에 맞춘 것 $\vert p \hat{p} \vert$, recall은 정의 상 정답 $\vert \hat{p} \vert$ 중에 맞춘 것 $\vert p \hat{p} \vert$ 이므로 위와 같이 표현할 수 있다.

이 precision과 recall로 F1 Score를 계산하면 아래와 같이 DSC와 동일하다는 것을 확인할 수 있다.

\[f_1 score = 2 \cdot \frac{\vert p \hat{p} \vert}{\vert p \vert + \vert \hat{p} \vert} = 2 \cdot \frac{\vert p \vert \cap \vert \hat{p} \vert}{\vert p \vert + \vert \hat{p} \vert}\] \[s_v = 2 \cdot \frac{\vert p \cdot \hat{p} \vert}{\vert p \vert^2 + \vert \hat{p} \vert^2}\]

DSC를 벡터 변수에 대해 표현하면 $s_v$와 같고, 최종 Dice Loss는 아래와 같이 구해진다.

\[Dice \ Loss = 1 - 2 \cdot \frac{\sum p \cdot \hat{p}}{\sum p^2 + \sum \hat{p}^2}\]

UNETR Loss

본 논문에서는 Loss Function으로 dice loss와 cross-entropy loss를 사용한다.

\[\mathcal{L}(G, Y) = 1 - \frac{2}{J} \sum^J_{j=1} \frac{\sum^I_{i=1} G_{i,j}Y_{i,j}}{\sum^I_{i=1}G^2_{i,j} + \sum^I_{i=1}Y^2_{i,j}} - \frac{1}{I}\sum^I_{i=1}\sum^J_{j=1}G_{i,j} \log Y_{i,j}\]
  • $I$: number of voxels
  • $J$: number of classes
  • $Y_{i,j}$: probability output for class $j$ at voxel $i$
  • $G_{i,j}$: ground truth for class $j$ at voxel $i$


Project: Brain Tumor Segmentation using UNETR

Dataset

  1. Native (T1): T1 weighted image는 T2 weighted image보다 신호강도가 높아 해부학적 구조물을 좀 더 명확하게 구별 가능하다. T1영상에서 fluid는 어둡게 보이며, 낭종은 저신호강도로 검게, 피하지방조직은 고신호강도로 희게, 근육은 중신호강도를 보인다.
  2. Post-contrast T1-weighted (T1ce, also known as T1Gd): contrast agent (Gadolinium)를 T1 스캔 중에 주입하여 T1을 단축하고 신호 강도를 변경한다. Gad에 의해 강화된 이미지는 혈관 구조, 혈뇌 장벽 붕괴(종양, 농양, 염증)를 관찰하는데 유용하다.
  3. T2-weighted (T2): T2 weighted image는 fluid(e.g. CSF)가 매우 하얗게 보인다. 수분 함유량에 대체로 비례하여 낭종이 가장 하얗게 보이고, 이어서 부종이 동반된 조직, 정상 조직 순으로 보인다.
  4. T2-FLAIR (T2 - Fluid Attenuated Inversion Recovery): FLAIR는 T2에서 정상적인 CSF를 어둡게 나타낸다. 문제가 있는 CSF는 밝게 남아있으며 정상 CSF와 abnormality를 정말 쉽게 구별해낼 수 있다.

모델을 학습시킬 때는 T1 보다 T1ce가 더 유용하게 때문에 T1은 제외시킨다. T2는 밝게 보이는 fluid가 모델의 학습을 방해할 수 있으므로 제외하고 T2-FLAIR를 사용한다.

Images Format

사용할 data의 shape를 살펴보면 다음과 같다.

# Modality shape
print("Modality: ", test_image_t1.shape)

# Segmentation shape
print("Segmentation: ", test_image_seg.shape)

> Modality:  (240, 240, 155)

> Segmentation:  (240, 240, 155)

data는 3Dimage로 마지막 차원은 slices이다. (240, 240) 크기의 이미지를 155개 쌓아서 3D representation을 만들어낸다. 3개의 차원을 각각 width, height, depth라 하며 medical imaging에서는 slice 기준축에 따라 Y축 수평면(axial), Z축 관상면(coronal), X축 시상면(sagittal) 으로 나뉜다.

이러한 slice 평면은 segmentation task를 적절히 수행하는데 중요하다. 각각의 평면은 다른 관점을 보여주며, 해부학적 구조와 이상 부위를 특정하는데 필요하다.

slice = 95

print("Slice number: " + str(slice))

plt.figure(figsize=(12, 8))

# Apply a 90° rotation with an automatic resizing, otherwise the display is less obvious to analyze
# T1 - Transverse View
plt.subplot(1, 3, 1)
plt.imshow(test_image_t1ce[:,:,slice], cmap='gray')
plt.title('T1 - Axial View')

# T1 - Frontal View
plt.subplot(1, 3, 2)
plt.imshow(rotate(test_image_t1ce[:,slice,:], 90, resize=True), cmap='gray')
plt.title('T1 - Coronal View')

# T1 - Sagittal View
plt.subplot(1, 3, 3)
plt.imshow(rotate(test_image_t1ce[slice,:,:], 90, resize=True), cmap='gray')
plt.title('T1 - Sagittal View')
plt.show()

Annotations (labels)

# Isolation of class 0
seg_0 = test_image_seg.copy()
seg_0[seg_0 != 0] = np.nan

# Isolation of class 1
seg_1 = test_image_seg.copy()
seg_1[seg_1 != 1] = np.nan

# Isolation of class 2
seg_2 = test_image_seg.copy()
seg_2[seg_2 != 2] = np.nan

# Isolation of class 4
seg_4 = test_image_seg.copy()
seg_4[seg_4 != 4] = np.nan

  1. Label 0: Not Tumor (NT) volume
  2. Label 1: Necrotic and non-enhancing tumor core (NCR/NET)
  3. Label 2: Peritumoral edema (ED)
  4. Label 3: Missing (No pixels in all the volumes contain label 3)
  5. Label 4: GD-enhancing tumor (ET)

Label 3에는 pixel이 없기 때문에 Label 4로 처리한다.

Model Test

Dice Coefficient

# Train
Loss = 0.2376
Dice_Score = 0.7878

# Validation
Loss = 0.2280
Dice_Score = 0.7977

Git Repository

Dr.Raum

Pretrained Weights

UNETR 200 Epochs


Reference

Web

  1. Preprocess : https://www.kaggle.com/code/zeeshanlatif/brain-tumor-segmentation-using-u-net
  2. F1 Score: https://velog.io/@jadon/F1-score%EB%9E%80
  3. Dice Loss: https://attagungho.tistory.com/11#index
  4. UNETR : https://kimbg.tistory.com/33

Paper

  1. Transformer : https://arxiv.org/abs/1409.0473
  2. 3D U-net : https://arxiv.org/abs/1606.06650
  3. Vision Transformer : https://arxiv.org/abs/2010.11929
  4. UNETR : https://arxiv.org/abs/2103.10504

Leave a comment