@@ -70,6 +70,40 @@ def __call__(self, outputs: torch.Tensor,
7070 pred_label_ids = max_idx .numpy ().tolist ()
7171 pred_label_scores = max_value .numpy ().tolist ()
7272
73+ # inference process do not have item in gt_label,
74+ # so select valid token with word_ids rather than
75+ # with gt_label_ids like official code.
76+ pred_words_biolabels = []
77+ word_biolabels = []
78+ pre_word_id = None
79+ for idx , cur_word_id in enumerate (word_ids ):
80+ if cur_word_id is not None :
81+ if cur_word_id != pre_word_id :
82+ if word_biolabels :
83+ pred_words_biolabels .append (word_biolabels )
84+ word_biolabels = []
85+ word_biolabels .append ((self .id2biolabel [pred_label_ids [idx ]],
86+ pred_label_scores [idx ]))
87+ else :
88+ pred_words_biolabels .append (word_biolabels )
89+ break
90+ pre_word_id = cur_word_id
91+ # record pred_label
92+ if self .only_label_first_subword :
93+ pred_label = LabelData ()
94+ pred_label .item = [
95+ pred_word_biolabels [0 ][0 ]
96+ for pred_word_biolabels in pred_words_biolabels
97+ ]
98+ pred_label .score = [
99+ pred_word_biolabels [0 ][1 ]
100+ for pred_word_biolabels in pred_words_biolabels
101+ ]
102+ merged_data_sample .pred_label = pred_label
103+ else :
104+ raise NotImplementedError (
105+ 'The `only_label_first_subword=False` is not support yet.' )
106+
73107 # inference process do not have item in gt_label,
74108 # so select valid token with word_ids rather than
75109 # with gt_label_ids like official code.
0 commit comments