본문 바로가기
Broad AI without NLP/Education

[논문 리뷰] CRKT 논문 리뷰 - 코드편

by beeny-ds 2025. 1. 11.
지난 포스팅에서는 객관식 문제에 대해 다양한 Input 을 활용하여 Knowledge Tracing task 를 수행한 CRKT 논문에 대해 소개했다.

해당 논문은 github 코드가 굉장히 친절하게 짜여졌다.
본 포스팅에서는 CRKT 논문 저자가 개발한 코드에 대해 설명하고자 한다.

코드의 Flow 와 그 의미에 대해 상세히 다루도록 하겠다.

목차

1. CRKT github 소개

2. Preprocessed DBE_KT22 dataset info. 

3. Model Architecture 

4. Model Train

5. 필자 리뷰


 

1. CRKT github 소개

코드 파악을 위해 DBE_KT22 데이터를 활용하였다.

CRKT/data_preprocess.py 를 확인하면 데이터 별 전처리 코드가 잘 짜여져 있다.

필자는 DBE_KT22 데이터를 먼저 파악한 뒤, 어떻게 전처리하는지 코드를 분석하고 이를 활용하여 모델 Train 및 Inference 를 통해 Model Architecture 를 상세히 파악하였다.

 

다시 한번 귀한 코드를 배포해주신 박순욱님께 감사드린다.


 

2. Preprocessed DBE_KT22 dataset info. 

CRKT/data_preprocess.py 를 통해 전처리된 데이터의 정보 및 그 예시다.

Return Description example
question - 학생이 푼 문제 id 모음 [34, 35, 0, 3, 1, 55, … , 83, 8]
concept - 문제와 연관된 개념 id
- total length 는 학생이 푼 문제 id 들의 length 와 동일
- 각 index 의 값은 아래와 같이 채워짐
> 개념 id 외에는 max_concept_len(=200) 만큼 -1을 채워줌
> 각 index 의 length 는 max_concept_len 와 같음
[
[9, -1, -1, -1, … , -1],
# len=200
… ,
[5, 6, -1, -1, … , -1]
# len=200
]
score - 학생이 문제를 맞췄는지 틀렸는지 그 결과 [1, 0, 1, 1, 0, … , 0]
option - 학생이 문제의 몇 번을 선택했는지
> 객관식 문항 중 몇 번을 선택했는지와 같음
[2, 2, 2, 1, 2, … , 1]
unchosen - 학생이 선택하지 않은 문제 번호 중 랜덤한 하나
> 선택하지 않은 단 하나의 옵션을 기록
[1, 3, 0, 2, 0, … , 3]
pos_score - Contrastive Learning 학습을 위해 기존의 Score 를 변경
> 정답률이 낮은 기본 질문에 대해 정답으로 변경
(오답이 있으면 정답으로 변경)
[1. 0. 1. 1. 1. … . 0]
pos_option - Contrastive Learning 학습을 위해 기존의 Score 에 대한 option 을 변경
> 정답률이 낮은 기본 질문에 대해 정답 option 으로 변경
(오답이 있으면 정답 option 으로 변경)
[2, 2, 2, 1, 3, … , 1]
neg_score - Contrastive Learning 학습을 위해 기존의 Score 를 변경 [1. 0. 1. 0. 0. … . 0]
neg_option - Contrastive Learning 학습을 위해 기존의 Score 에 대한 option 을 변경 [2, 2, 2, 0, 2, … , 1]

 

  • concept 결과에 대한 해석
    • 문제와 연관된 개념 id
    • 개념 id 외에는 max_seq_length(=200) 만큼 -1을 채워줌
    • example: [5, 6, -1, -1, … , -1]
  • pos_option&score, neg_option&score 를 구하는 이유
    • 더보기
      • 문제 상황: 이러한 질문들은 KT 모델이 예측하기에 특히 어려운 문제로, 주로 KT 모델의 예측이 질문의 평균 정답률과 과적합되는 경향이 있음
      • 해결 방안: 정답을 맞춘 학생과 그렇지 않은 학생 사이의 Knowledge states 차이를 강화하는 것을 목표로 함
      • 학습 방법: Positive 지식 상태와 Negative 지식 상태를 도출하여 Contrastive Learning 을 통해 학생의 원래 지식 상태(=h_s)를 Positive 지식 상태에 가깝게, Negative 지식 상태에 멀게 정렬하도록 학습을 시킴
      • 학습 Detail: https://github.com/Soonwook34/CRKT/issues/1
        • 위 issue 에 가면 해당 사상에 대한 저자의 설명을 확인할 수 있다.

 

3. Model Architecture

먼저 데이터의 index 를 변환해주는 처리를 거친다.

변환해주는 method get_index 이다. 해당 method 는 output 만 소개하겠다.

이후 연산 시 자주 등장하기 때문에 상세히 살펴보기 바란다.

  • qt_idx
    • 더보기
      • Description
        • question 으로부터 생성
          • question: 학생이 푼 문제 sequence 를 숫자로 나열
        • pad_id 를 -1 에서 문제의 수만큼 변경
        • index 는 0부터 시작하기 때문에 0~question 수 까지로 표현되고, question 수는 pad_id 가 됨
        • shape: [128, 200]
          • 128: batch_size 로 학생의 수를 의미
          • 200: max_seq_len 로 학생 당 몇 문제 푼 history 를 몇개까지 볼지를 의미

       

  • c_qt_idx
    • 더보기
      • Description
        • concept 으로부터 생성
        • pad_id 를 -1 에서 concept 수만큼 변경
        • index 는 0부터 시작하기 때문에 0~concept 수 까지로 표현되고, concept 수는 pad_id 가 됨
        • shape: [128, 200, 4]
          • 128: batch_size 로 학생의 수를 의미
          • 200: max_seq_len 로 학생 당 몇 문제 푼 history 를 몇개까지 볼지를 의미
          • 4: 한 문제에서 연결될 수 있는 최대 concept 수가 4라서
  • ot_idx
    • 더보기
      • Description
        • option 값을 unique index 로 변경
        • 문제의 수가 212개, 옵션의 max 수가 5개이기 때문에 모든 문제를 대상으로 option 값을 unique 하게 계산한다면 212 * 5 = 1060 이 됨
        • 학생이 답변한 문제에 대한 답을 unique 하게 표현
        • pad_id 는 1060 이 됨 (index 는 0부터 시작하기 때문에 0~1059 까지가 문제의 답변이고, 1060 은 pad_id 가 됨)
        • shape: [128, 200]
          • 128: batch_size 로 학생의 수를 의미
          • 200: max_seq_len 로 학생 당 몇 문제 푼 history 를 몇개까지 볼지를 의미
  • ut_mask
    • 더보기
      • Description
        • 학생이 선택하지 않은 option 이 무엇인지 표현하기 위한 정보
        • 학생이 선택하지 않은 option 에 대해서는 1, 학생이 선택했거나 option 의 수가 부족할 경우에는 0으로 표현
        • shape: [128, 200, 5]
          • 128: batch_size 로 학생의 수를 의미
          • 200: max_seq_len 로 학생 당 몇 문제 푼 history 를 몇개까지 볼지를 의미
          • 5: max_option 수
  • ut_idx
    • 더보기
      • Description
        • ut_mask 에서 0에 대한 값을 1061 로 변환하여 pad_id 로 활용
        • ut_mask 에서 1에 대한 값을 0부터 sequence 하게 하나씩 늘려서 활용
          • ot_idx 의 index 를 제외하고 학생이 선택하지 않은 index 번호를 표현하기 위함
          • ot_idx 의 index 는 ut_idx 에 1061 로 무의미하게 표현
        • shape: [128, 200, 5]
          • 128: batch_size 로 학생의 수를 의미
          • 200: max_seq_len 로 학생 당 몇 문제 푼 history 를 몇개까지 볼지를 의미
          • 5: max_option 수

 

출처: CRKT 논문 Model Architecture

a. Disentangled Response Encoder

a. Disentangled Response Encoder 구조를 코드 Level 로 분석한 결과

  • input: ot_idx, score, ut_idx, ut_mask
  • 연산 과정
    • encode_disentangled_response method 를 참고하기 바란다.
    • 더보기
      1. ot_idx & ut_idx 를 response embedding 에 넣는다.
        • ot_idx shape: [128, 200] / ut_idx shape: [128, 200, 5]
        • ot: ot_idx 를 input 으로 넣은 output
          • shape: [128, 200, 32]
        • ut: ut_idx 를 input 으로 넣은 output
          • shape: [128, 200, 5, 32]
      2. ot_prime 계산
        • 학생이 선택한 option 에 대해 각각 정답, 오답에 대한 embedding만 살림
        • 학생의 선택이 정답인 경우, 정답인 embedding 만 표현, 오답인 경우, 오답인 embedding 만 표현
      3. ut_prime 계산
        • ut 를 layer 에 통과시킨 뒤, unchosen response 에 대해서만 embedding 을 살리고 chosen response 에 대해서는 삭제
      4. dt 계산
        • ot_prime - lambda*ut_prime
      5. dt_hat 계산
        • dt 를 query, key, value 로 사용하여 Attention layer 를 attention output 을 얻음
  • output: dt_hat
    • 더보기
      • Description
        • 학생이 문제를 맞춘 정보와 못맞춘 정보, 그리고 학생이 선택하지 않은 option 들에 대한 정보를 표현한 결과
        • 즉, 학생의 정오답 정보 + 정오답 시 선택하지 않은 option 에 대한 정보를 추출한 결과

 

b. Knowledge Retriever

b. Knowledge Retriever 구조를 코드 Level 로 분석한 결과

  • input: qt, dt_hat
  • 연산 과정
    • get_knowledge_state method 를 참고하기 바란다  
    • 더보기
      1. question 정보를 embedding 시켜 qt_emb 를 생성한다.
      2. qt_emb 를 CRKTLayer 에 넣어 attention output 을 구한다.
        • q,k,v 모두 qt_emb 사용
        • output: qt_hat
      3. qt_hat 과 dt_hat 을 CRKTLayer 에 넣어 attention output 을 구한다.
        • q,k: qt_hat
        • v: dt_hat
      4. qt_hat 과 ht 를 [:, :-1, :] 만 추출하여 output 을 얻는다.
        • 그래서 output dim 이 [128, 199, 32]
        • [:, :-1, :] 의미: 학생들이 푼 마지막 문제 sequence 에 대한 index 를 제거
  • output: qt_hat, ht
    • 더보기
      • Description
        • 학생이 문제를 푼 sequence 정보(=qt_hat)와 학생의 정오답 및 선택하지 않은 답변에 대한 정보(=dt_hat) 를 병합
        • 즉, 학생이 문제를 푼 sequence 정보를 좀 더 상세하게 어떤 문제를 맞췄는지, 틀렸는지, 그리고 어떤 option 을 선택하지 않았는지에 대한 정보를 merge 하여 represenation

 

c. Concept Map Encoder

c. Concept Map Encoder 구조 중 get_concept_mastery method 를 코드 Level 로 분석한 결과

  • input: ht
  • 연산 과정
    • get_concept_mastery method 를 참고하기 바란다.
    • 더보기
      1. ht 를 활용하여 ht_concept 을 구한다.
        • dim: [128, 199, #concept=93, 32]
        • ht 의 dim: [128, 199, 32]
        • #concept 만큼의 차원이 추가되는데 이는 동일한 값을 복붙한 것과 같음
      2. concept embedding weight 를 활용하여 빈 곽인 concept_batch 를 생성한다.
      3. ht 와 빈 곽인 ci_batch 를 concat 하여 concept encode 에 넣어 mt 를 만든다.
  • output: mt
    • 더보기
      • Description
        • concept map 구축을 위해 빈 곽인 concept map 을 1차 생성함
        • 1차 생성의 목표는 concept map 업데이트를 위해 빈 곽을 생성하는 것.
        • 빈 곽 생성 → 업데이트 를 반복하며 concept map 이 정교해짐
        • shape: [128, 199, 93, 32]

 

c. Concept Map Encoder 구조 중 get_concept_map method 를 코드 Level 로 분석한 결과

  • input: q_target
  • 연산 과정
    • get_concept_map method 를 참고하기 바란다.
    • 더보기
      1. qt 에서 첫 번째 문제에 대한 정보를 제외한 나머지 정보만 q_target 으로 저장
      2. [#concept x #concept] matrix 인 concept_map 을 {학생의 수 * (학생의 문제 풀이 이력 수-1)} 만큼 expand 하여 batch_adj 생성
        • shape: [25472, 93, 93]
          • 25472 개의 노드 간의 연결 matrix 생성
        • 모든 문제에 대한 정보에 KM 에 대한 정보를 추가하기 위함
      3. batch_adj 를 torch_geometric.utils 의 to_edge_index 함수에 input 으로 넣어 batch_edge_index & edge_weight 생성
        • 배치 처리된 그래프 데이터를 효율적으로 표현 (그래프 신경망에서 메모리 효율적인 계산을 가능하게 함)
        • batch_edge_index: 배치 당 시작 노드와 끝 노드의 연결 관계를 표현
          • shape: [3, 4381184]
          • 첫 번째 차원 의미: : 배치 인덱스, 시작 노드, 끝 노드를 나타냄
          • 두 번째 차원 의미: 총 edge 수
        • edge_weight: 각 edge 의 속성(가중치 등)을 나타냄
          • 4381184 ≒ 25472 * 93 * 93 * 0.02
          • 각 인접 행렬에서 약 2% 의 edge 가 존재함을 의미
          • 가중치는 초기에는 전부 1로 설정됨
      4. 학생마다 node 의 index 를 다르게 하기 위해 edge_index 생성
        • edge_index = batch_edge_index[1:] + (batch_index * self.num_c)
      5. q_target 을 edge 수 만큼 expand 하여 q_target_edge 생성
        • shape: [128, 199, 172, 32]
        • 172: edge 수
        • 각 학생의 문제 풀이 이력에 대한 각 embedding 을 복사하여 172개만큼 늘림
      6. concept_map 을 시작 노드와 끝 노드로 표현하는 cij_idx 생성
      7. cij_idx 를 concept_emb 에 넣어 concept_map 정보를 representation 하는 cij 생성
      8. cij 에서 시작 노드와 끝 노드를 concat 하는 cij_concat 생성
      9. cij_concat 을 batch_size, max_seq_len-1 만큼 expand 하는 cij_batch 생성
      10. q_target_edge 와 cij_batch 를 concat 한 뒤 enc_intensity layer 에 넣고, relu 를 태워 edge weight 를 생성하준 뒤 flatten 해줌
        • q_target_edge 인 학생의 문제 풀이 이력에 대한 embedding 과 cij_batch 인 시작 노드와 끝 노드 및 edge 관계에 대한 embedding 을 병합
        • flatten 시켜 병합한 embedding 을 쫙 펼쳐줘 1차원으로 바꿔줌
      11. num_nodes = self.batch_size * (self.seq_len - 1) * self.num_c 계산
        • self_loop 를 위해 node 의 개수를 구하여 add_remaining_self_loops 의 num_nodes 로 활용
      12. add_remaining_self_loops 를 통해 edge_index 와 edge_weight 를 구함
  • output: edge_index, edge_weight
    • 더보기
      • Description
        • 지식맵(node&edge) 정보와 학생의 문제 풀이 이력 정보를 merge 하여 node 정보와 edge 에 대한 weight 정보를 정의하여 표현
        • edge_index: 각 학생의 시작 node 와 끝 node 를 각각 다른 index 로 표현
        • edge_weight: node 간의 edge weight 를 표현

 

c. Concept Map Encoder 구조 중 update_concept_mastery method 를 코드 Level 로 분석한 결과

  • input: mt, edge_index, edge_weight, c_target_idx
  • 연산 과정
    • update_concept_mastery method 를 참고하기 바란다.
    • 더보기
      1. c_qt_idx 로부터 첫 번째 index 를 제거한 나머지 199개만 추출하여 c_target_idx 를 생성
      2. c_target_idx 를 dim 을 #concept 만큼의 one-hot 으로 표현한 g_target 생성한 뒤 flatten 한 g_target_flat 생성
        • g_target shape: [128, 199, 93]
        • g_target_flat shape: [2368896, 1]
      3. mt 를 쫙 펼쳐 새로운 mt 생성
        • 기존 mt shape: [128, 199, 93, 32]
        • 새로운 mt shape: [2368896, 1]
      4. gcn 에 쫙 펼친 mt 와 edge_index, edge_weight 를 input 으로 넣어 mt_hat 생성
        • mt_hat shape: [2368896, 1]
        • gcn 이란: Graph Convolutional Network 로서 "Semi-supervised Classification with Graph Convolutional Networks" 논문에서 소개된 그래프 신경망 아키텍처
          • 노드 분류, 그래프 분류 등 다양한 그래프 관련 작업에 사용
          • 그래프의 구조적 정보와 노드 특성을 동시에 활용하여 학습을 수행
          • 각 레이어에서 이웃 노드들의 정보를 집계하고, 이를 통해 노드의 표현을 갱신
          • 이 과정을 통해 그래프의 전역적인 구조와 지역적인 특성을 효과적으로 학습
      5. mt_hat 을 batch, seq, concept 으로 차원을 나눈 새로운 mt_hat 생성
        • mt_hat shape: [128, 199, 93]
  • output: mt_hat, g_target
    • 더보기
      • Description
        • mt_hat: 학생이 문제를 푼 정보와 KM 사이의 관계를 더욱 심층적으로 표현하기 위해 GCN 을 통해 연산을 진행
          • shape: [128, 199, 93]
        • g_target: 학생이 푼 문제들의 개념을 one-hot 으로 표현
          • shape: [128, 199, 93]

 

d. IRT-based Prediction

d. IRT-based Prediction 구조 중 get_concept_weight method 를 코드 Level 로 분석한 결과

  •  
  • input: q_target
  • 연산 과정
    • get_concept_weight method  를 참고하기 바란다.
    • 더보기
      1. concept weight 인 ck 를 [128,199,93,32] dim 으로 expand 해서 ck_batch 를 생성한다.
      2. q_target 을 linear layer 에 넣고, unsqueeze(2) 하여 num_c 만큼 차원을 추가하는 expand 해서 q_target_batch 를 생성한다.
      3. ck_batch 와 q_target_batch 를 곱하고 마지막 차원을 하나의 값으로 더해서 r_target 을 생성한다.
      4. r_target 의 값을 sort 하여 가장 큰 값을 기준으로 Top-k 를 추출하여 topk_values 로 생성한다.
      5. topk_values 가 True 이면 topk_r_target 의 값을 그대로 사용하고 False 이면 -1e32 값으로 변경하여 topk_r_target 으로 업데이트 한다.
      6. 5번에서 업데이트한 topk_r_target 값을 softmax 하여 topk_r_target 으로 업데이트 한다.
      7. 6번에서 업데이트한 topk_r_target 값을 topk_mask 가 False 이면 0으로, True 이면 값을 그대로 사용하는 topk_r_target 으로 업데이트 한다.
      8. 3번에서 생성한 r_target 값을 sigmoid 하여 r_target 으로 업데이트한다.
  • output: r_target, topk_r_target
    • 더보기
      • Description
        • r_target: 모든 학생의 첫 번째 문제를 제외한 나머지 문제들을 개념 정보와 연관성을 계산
          • 문제와 개념 간 연관성을 embedding 하여 표현
          • shape: [128, 199, 93]
        • topk_r_target: 모든 학생의 첫 번째 문제를 제외한 나머지 문제들이 개념들 중 가장 연관성이 큰 top-k개(ex. 10개)를 추출
          • 문제와 개념 간 연관성 정도를 sort 하여 top-k 만 사용 (나머지는 0으로 변환)
          • shape: [128, 199, 93]

 

d. IRT-based Prediction 구조 중 predict method 를 코드 Level 로 분석한 결과

  • input: mt_hat, topk_r_target, q_target
  • 연산 과정
    • predict method 를 참고하기 바란다.
    • 더보기
       
      1. r_target 과 mt_hat 을 곱한 뒤 마지막 차원을 합해줘서 학생의 능력인 ability 를 생성한다.
      2. 첫 번째 문제를 제외한 나머지 문제들의 정보를 mlp_diff layer 에 넣어 문제의 난이도 정보인 difficulty 를 생성한다.
      3. 학생의 능력인 ability 에서 문제의 난이도인 difficulty 를 빼준 뒤 sigmoid 에 넣어줘 최종 결과물인 output 을 생성한다.
  • output: output
    • 더보기
      • Description
        • 학생 별 첫 문제를 제외하고 2번째 문제부터 학생이 해당 문제를 맞출 확률이 output 이 됨
        • output 을 구하기 위해 문제 별 학생의 능력에서 해당 문제의 난이도를 빼준 값을 활용하였음
        • shape: [128, 199]

 

4. Model Train

Loss 는 아래와 같이 계산한다.

출처: CRKT 논문

a. Loss for KT 계산 목적

 

  • 학습 목적
    • 더보기
      • 학생이 한 문제를 풀었다고 가정했을 때부터 2번째 문제부터 학생의 어떤 문제 및 개념에 대한 성취 수준을 측정하기 위한 목적으로 계산
      • 학생이 어떤 문제 및 개념에 대한 성취 수준을 학습하기 위해 활용
      • BCELoss 활용
  • loss cal target: output&score
    • 더보기
      • Description
        • 학생 별 첫 문제를 제외하고 2번째 문제부터 학생이 해당 문제를 맞출 확률이 output 이 됨
        • output 을 구하기 위해 문제 별 학생의 능력에서 해당 문제의 난이도를 빼준 값을 활용하였음
        • shape: [128, 199] 를 쫙 펼쳐 [14983, 1] 로 변환
          • 변환 시 max_seq_len 에서 pad_id 에 해당되는 부분은 버림

 

b. Loss for topK 계산 목적

  • 학습 목적
    • 더보기
      • 4-D. IRT-based Prediction 에서 구한 r_target 과 4-C. Concept Map Encoder 에서 구한 g_target 의 값을 최대한 가깝게 만들어주는 방향으로 학습이 진행됨
      • 문제와 개념관의 관계를 바르게 예측하기 위한 방식의 학습
        • 이는 곧 문제와 top_k 개념을 추출할 때 올바르게 식별되도록 유도하기 위한 목적으로 활용됨
        • 양성 Label 과 음성 Label 간 imbalance 를 해결하기 위해 양성 Label 에 가중치를 적용
          • 가중치: (concept 수 - top_k) / top_k ex) (93-10)/10 = 8.3
      • BCEWithLogitsLoss 활용
  • loss cal target: r_target, g_target
    • 더보기
      r_target
      • 모든 학생의 첫 번째 문제를 제외한 나머지 문제들을 개념 정보와 연관성을 계산
        • 문제와 개념 간 연관성을 embedding 하여 표현
        • shape: [128, 199, 93] 를 쫙 펼쳐 [1393419, 1] 로 변환
          • 변환 시 max_seq_len 에서 pad_id 에 해당되는 부분은 버림

      g_target
      • 학생이 푼 문제들의 개념을 one-hot 으로 표현
        • shape: [128, 199, 93] 를 쫙 펼쳐 [1393419, 1] 로 변환
          • 변환 시 max_seq_len 에서 pad_id 에 해당되는 부분은 버림

 

c. Loss for CL 계산 목적

  • 학습 목적
    • 더보기
      • Description
        • 목적: problem-solving history (interaction sequence)로부터 knowledge state를 더 잘 추출하기 위함
          • KT 모델의 어떤 문제에 대한 예측 accuracy가 문제의 평균 정답률에 영향을 받는 일종의 overfit 문제를 해결하기 위함
          • 예를 들어, 어떤 문제의 평균 정답률이 지나치게 높을 경우 모델은 어떠한 problem-solving history (interaction sequence)가 들어와도 그 문제를 정답률에 따라 높은 확률로 맞춘다고 예측하는 경향
        • 모델의 예측 성능이 떨어지는 구간인 평균 정답률 40-60% 에서 성능 개선을 위해 도입
        • 평균 정답률 40-60%인 문제들을 맞춘/틀린 학생의 knowledge state를 더 잘 구분하여 예측 정확도를 높이고자 하였음
          • Item Response Theory (IRT)에 따라 숙련도가 높으면 어려운(난이도가 높은, 평균 정답률이 낮은) 문제를 더 잘 푼다.
      • 가정
        • 평균 정답률 40-60%의 base question을 맞춘 학생은 base question보다 쉬운 문제를 틀린 가상의 학생(negative sample)의 지식 상태보다, base question보다 어려운 문제를 맞춘 가상의 학생(positive sample)의 지식 상태에 더 가까울 것이다.
        • 평균 정답률 40-60%의 base question을 틀린 학생은 base question보다 어려운 문제를 맞춘 가상의 학생(negative sample)의 지식 상태보다, base question보다 쉬운 문제를 틀린 가상의 학생(negative sample)의 지식 상태에 더 가까울 것이다.
        • Contrative Learning의 학습 대상이 모델의 final output인 y_hat이 아닌 지식 상태 h이므로, 위와 같이 정오답 결과를 변경해 학습한 모델이 지식 상태의 미묘한 차이를 좀 더 잘 구분지을 수 있을 것이다.
      • 방법
        • interaction sequence로부터 knowledge state를 추출할 때, 그 문제를 맞춘 학생의 knowledge state를 더 좋게(mastery가 높게), 틀린 학생의 knowledge state는 더 안좋게(mastery가 낮게) 추출
          • interaction sequence에서 positive, negative sample을 생성해 학습
          • knowledge state는 개념들의 숙련도로 표현되며, 각 값이 클수록 숙련도가 높고 작을수록 숙련도가 낮다는 의미를 갖는다.
        • 이를 구현 시, 같은 concept set에 대한 base question이 여러개 존재할 수 있으므로 가장 최근에 푼 문제가 base question이 되어 target_info에 저장
        • 과정
          • 학생이 푼 문제 중 평균 정답률 40-60%의 문제들을 base questions으로 설정
          • 학생이 푼 문제의 정오답 결과를 Positive-sample, Negative-sample 로 구성하기 위한 변경
            • 학생이 푼 문제의 정오답 결과를 Positive-sample, Negative-sample 로 구성하기 위한 변경
              • base question을 맞췄을 경우, base quenstion과 관련있는(같은 concept set을 공유하는) 문제의 interaction은 아래와 같이 변경
                • 문제의 평균 정답률이 base question의 평균 정답률보다 낮은 경우, 그 문제의 score를 1로 바꿔 positive sample을 생성
                • 문제의 평균 정답률이 base question의 평균 정답률보다 높은 경우, 그 문제의 score를 0으로 바꿔 negative sample을 생성
              • base question을 틀렸을 경우, base quenstion과 관련있는 문제의 interaction은 아래와 같이 변경
                • 문제의 평균 정답률이 base question의 평균 정답률보다 낮은 경우, 그 문제의 score를 0로 바꿔 positive sample을 생성
                • 문제의 평균 정답률이 base question의 평균 정답률보다 높은 경우, 그 문제의 score를 1으로 바꿔 negative sample을 생성
            • 변경을 쉽게 풀어쓰기
              • Positive sample
                • 이렇게 해석할 수 있음
                  • 어려운 문제를 맞춘 학생은 그 문제를 틀린 학생보다 문제와 관련된 개념에 대한 숙련도가 높다.
                • base question 맞춘 경우: 더 어려운 문제를 정답으로 변경
                • base question 틀린 경우: 더 어려운 문제를 오답으로 변경
              • Negative sample
                • 이렇게 해석할 수 있음
                  • 어려운 문제를 틀린 학생은 그 문제를 맞춘 학생보다 문제와 관련된 개념에 대한 숙련도가 낮다.
                • base question 맞춘 경우: 더 쉬운 문제를 오답으로 변경
                • base question 틀린 경우: 더 쉬운 문제를 정답으로 변경
  • 연산 과정
    • 더보기
      1. 데이터 전처리 시 생성한 Positive option & Positive score 및 Negative option & Negative score 를 get_index 메소드에 넣어 모델 학습에 사용할 수 있도록 처리한다.
      2. 1번에서 생성한 데이터를 모델 추론과 동일한 방식으로 값을 추출한다.
        1. encode_disentangled_response 메소드(4-a)에 넣어 Positive sample 및 Negative sample 에 대한 representation 을 생성한다.
          • representation 의미: 학생이 문제를 맞춘 정보와 못맞춘 정보, 그리고 학생이 선택하지 않은 option 들에 대한 정보를 표현한 결과
        2. get_knowledge_state 메소드(4-b)에 넣어 Positive sample 및 Negative sample 에 대한 representation 을 생성한다.
          • representation 의미: 학생이 문제를 푼 sequence 정보를 좀 더 상세하게 어떤 문제를 맞췄는지, 틀렸는지, 그리고 어떤 option 을 선택하지 않았는지에 대한 정보를 merge 하여 표현한 결과
            • 학생이 문제를 푼 sequence 정보(=qt_hat)와 학생의 정오답 및 선택하지 않은 답변에 대한 정보(=dt_hat) 를 병합
      3. 학생의 지식 수준 표현인 2-b의 결과와 seq_mask 를 곱한 뒤, 학생이 문제를 푼 부분만을 대상으로 평균을 구한다.
        • 학생 별 평균 지식 수준을 표현하기 위함
        • 본래 학생의 평균 지식 수준과 Positive sample 및 Negative sample 에 대한 평균 지식 수준을 각각 구한다.
        • output 은 다음과 같음
          • pooled_score, pooled_pos_score, pooled_neg_score shape: [128, 28]
      4. 본래 학생의 평균 지식 수준과 각각의 Positive sample 및 Negative sample 에 대한 평균 지식 수준의 유사도를 구한다.
        • cosine 유사도 사용
        • output 은 다음과 같음 (shape: [128, 128])
          • pos_cossim: 본래 학생과 가상의 Positive sample 학생의 유사도
          • neg_cossim: 본래 학생과 가상의 negative sample 학생의 유사도
      5. 128x128 dim 의 단위 행렬(identity matrix)을 생성한 뒤, 단위행렬과 neg_cossim 을 더해준다.
        • 자기 자신에 대한 유사도에 1씩 더해주는 것과 동일
      6. d와 e에서 구해준 pos_cossim 과 neg_cossim 을 concat 한다.
        • output dim: [128, 256]
      7. 0부터 batch_size 까지인 학생 수까지를 0부터 1씩 sequence 하게 늘려줘 inter_label 을 생성한다.
      8. 구해진 [128, 256] 차원의 inter_cossim 과 inter_label 간의 CrossEntropyLoss 를 구한다.
        • CrossEntropyLoss 란
          • 다중 클래스 분류 문제에서 널리 사용되는 손실 함수
          • 소프트맥스 활성화와 음의 로그 우도 손실(Negative Log-Likelihood Loss)을 결합한 것
        • 특징
          • 내부적으로 소프트맥스 함수를 적용합니다.
          • 클래스 레이블은 정수로 인코딩되어야 합니다.
          • 모델의 원시 출력(로짓)을 직접 받아 처리합니다.
          • 모델이 정답 클래스에 높은 확률을 할당하도록 학습을 유도
        • 의미
          • pred [128, 256]: 128개의 샘플에 대해 각각 256개 클래스의 예측 점수(로짓)를 나타냅니다.
          • label : 각 샘플의 실제 클래스 레이블(0-255 사이의 정수)을 나타냅니다.
        • CrossEntropyLoss는 각 샘플에 대해:
          • pred에 소프트맥스를 적용하여 확률 분포로 변환합니다.
          • 실제 레이블에 해당하는 예측 확률의 음의 로그를 취합니다.
          • 모든 샘플에 대해 평균을 계산합니다.
  • loss cal target: inter_cossim&inter_label
    • 더보기
      inter_cossim
      • 본래 학생의 평균 지식 수준과 각각의 Positive sample 및 Negative sample 에 대한 평균 지식 수준의 유사도를 concat 한 결과
      • shape: [128, 256]

      inter_label
      • 0부터 batch_size 까지인 학생 수까지를 0부터 1씩 sequence 하게 늘려준 결과
      • shape: [128]

 

5. 필자 리뷰

가장 먼저 극찬하고 싶은 바는 CRKT 논문은 `잘` 쓰여졌다는 점이다.

여기서 `잘` 쓰여졌다는 의미는 논문의 내용을 보고 어떻게 구현했는지 상상되도록 구체화되었다는 뜻이다.

필자는 `잘` 쓰여진 논문이란 논문의 내용만 보고 어떻게 구현해야 할지 알 수 있는 논문이라 생각한다.

그런 의미에서 CRKT 논문은 `잘` 쓰여졌다.

 

또한 CRKT git 코드 또한 매우 `잘` 짜여졌다.

여기서 `잘` 짜여졌다는 의미는 논문의 내용과 매핑하여 파악하기 용이한 형태로 구조화되었다는 뜻이다.

각 연산 단계 별 차원의 변화와 각 단계의 의미 또한 주석으로 상세하게 달려있다.

 

하지만 단점 또한 존재한다.

  1. 주관식 문제에 대한 지식 측정이 제한된다.
  2. Contrastive Learning 을 통한 지식 상태 표현 성능 향상은 코드와 이론 Level 이 다른 것 같다.
  3.  RQ3 의 접근 방법에 대한 내용이 부족하다. ('왜' 는 이해 되었으나 '어떻게' 가 부족)
    • 이 부분도 상당히 큰 Task 중 하나로 판단되기 때문에 아예 빼거나 Future task 로 남겨뒀으면 어땠을까? 싶다.

 

이러한 단점에도 불구하고 EduTech 의 꽃(?)인 KT 모델의 발전에 큰 기여를 했다고 생각한다.


 

마무리,,

지금까지 CRKT 논문의 이론 뿐만 아니라 코드까지 상세하게 알아봤다.
이 글을 읽는 독자가 만약 EduTech 에 종사하고 있고, KT 모델을 개발하고자 한다면 CRKT 논문은 꼭 상세하게 읽고 코드로 재현해보는걸 추천한다.

다음 포스팅은 뤼이드에서 투고한 논문인 SAINT 에 대해 다룰 예정이다.

SAINT 는 모델 구조 특성 상 SAINT+ 와 동일한 사상을 갖기 때문에 SAINT & SAINT+ 를 파악하고 있는 독자라면 꼭 확인해보는걸 추천한다.

 

참고로 필자가 이전에 포스팅한 SAINT 이론 내용보다 더 상세한 내용을 소개할 예정이다.

배포된 코드들을 조합하여 SAINT 모델을 재현(≒구현)했기 때문에 상세히 다룰 수 있게 되었다.

반응형

댓글