class SERPostprocessor(nn.Module):
    """PostProcessor for SER (semantic entity recognition).

    Converts per-token BIO-label logits into per-word label predictions
    and merges multiple truncation chunks back into one data sample.

    Args:
        classes (tuple or list): Entity class names used to build the
            id -> BIO-label mapping.
        only_label_first_subword (bool): Whether only the first subword
            of each word carries the label. Only ``True`` is supported
            at the moment. Defaults to True.
    """

    def __init__(self,
                 classes: Union[tuple, list],
                 only_label_first_subword: bool = True) -> None:
        super().__init__()
        # Name of the non-entity ("other") label inferred from `classes`.
        self.other_label_name = find_other_label_name_of_biolabel(classes)
        # Mapping from label id to BIO-label string.
        self.id2biolabel = self._generate_id2biolabel_map(classes)
        assert only_label_first_subword is True, \
            'Only support `only_label_first_subword=True` now.'
        self.only_label_first_subword = only_label_first_subword
        # Softmax over the label dimension of the raw logits.
        self.softmax = nn.Softmax(dim=-1)
2530 def _generate_id2biolabel_map (self , classes : Union [tuple , list ]) -> Dict :
@@ -40,62 +45,95 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
4045 def __call__ (self , outputs : torch .Tensor ,
4146 data_samples : Sequence [SERDataSample ]
4247 ) -> Sequence [SERDataSample ]:
43- # merge several truncation data_sample to one data_sample
4448 assert all ('truncation_word_ids' in d for d in data_samples ), \
4549 'The key `truncation_word_ids` should be specified' \
4650 'in PackSERInputs.'
47- truncation_word_ids = []
48- for data_sample in data_samples :
49- truncation_word_ids .append (data_sample .pop ('truncation_word_ids' ))
50- merged_data_sample = copy .deepcopy (data_samples [0 ])
51- merged_data_sample .set_metainfo (
52- dict (truncation_word_ids = truncation_word_ids ))
53- flattened_word_ids = [
54- word_id for word_ids in truncation_word_ids for word_id in word_ids
51+ truncation_word_ids = [
52+ data_sample .pop ('truncation_word_ids' )
53+ for data_sample in data_samples
54+ ]
55+ word_ids = [
56+ word_id for word_ids in truncation_word_ids
57+ for word_id in word_ids [1 :- 1 ]
5558 ]
5659
60+ # merge several truncation data_sample to one data_sample
61+ merged_data_sample = copy .deepcopy (data_samples [0 ])
62+
5763 # convert outputs dim from (truncation_num, max_length, label_num)
5864 # to (truncation_num * max_length, label_num)
5965 outputs = outputs .cpu ().detach ()
60- outputs = torch .reshape (outputs , (- 1 , outputs .size (- 1 )))
66+ outputs = torch .reshape (outputs [:, 1 : - 1 , :] , (- 1 , outputs .size (- 1 )))
6167 # get pred label ids/scores from outputs
6268 probs = self .softmax (outputs )
6369 max_value , max_idx = torch .max (probs , - 1 )
64- pred_label_ids = max_idx .numpy ()
65- pred_label_scores = max_value .numpy ()
70+ pred_label_ids = max_idx .numpy ().tolist ()
71+ pred_label_scores = max_value .numpy ().tolist ()
72+
73+ # inference process do not have item in gt_label,
74+ # so select valid token with word_ids rather than
75+ # with gt_label_ids like official code.
76+ pred_words_biolabels = []
77+ word_biolabels = []
78+ pre_word_id = None
79+ for idx , cur_word_id in enumerate (word_ids ):
80+ if cur_word_id is not None :
81+ if cur_word_id != pre_word_id :
82+ if word_biolabels :
83+ pred_words_biolabels .append (word_biolabels )
84+ word_biolabels = []
85+ word_biolabels .append ((self .id2biolabel [pred_label_ids [idx ]],
86+ pred_label_scores [idx ]))
87+ else :
88+ pred_words_biolabels .append (word_biolabels )
89+ break
90+ pre_word_id = cur_word_id
91+ # record pred_label
92+ if self .only_label_first_subword :
93+ pred_label = LabelData ()
94+ pred_label .item = [
95+ pred_word_biolabels [0 ][0 ]
96+ for pred_word_biolabels in pred_words_biolabels
97+ ]
98+ pred_label .score = [
99+ pred_word_biolabels [0 ][1 ]
100+ for pred_word_biolabels in pred_words_biolabels
101+ ]
102+ merged_data_sample .pred_label = pred_label
103+ else :
104+ raise NotImplementedError (
105+ 'The `only_label_first_subword=False` is not support yet.' )
66106
67107 # determine whether it is an inference process
68108 if 'item' in data_samples [0 ].gt_label :
69109 # merge gt label ids from data_samples
70110 gt_label_ids = [
71- data_sample .gt_label .item for data_sample in data_samples
111+ data_sample .gt_label .item [ 1 : - 1 ] for data_sample in data_samples
72112 ]
73113 gt_label_ids = torch .cat (
74- gt_label_ids , dim = 0 ).cpu ().detach ().numpy ()
75- gt_biolabels = [
76- self .id2biolabel [g ]
77- for (w , g ) in zip (flattened_word_ids , gt_label_ids )
78- if w is not None
79- ]
114+ gt_label_ids , dim = 0 ).cpu ().detach ().numpy ().tolist ()
115+ gt_words_biolabels = []
116+ word_biolabels = []
117+ pre_word_id = None
118+ for idx , cur_word_id in enumerate (word_ids ):
119+ if cur_word_id is not None :
120+ if cur_word_id != pre_word_id :
121+ if word_biolabels :
122+ gt_words_biolabels .append (word_biolabels )
123+ word_biolabels = []
124+ word_biolabels .append (self .id2biolabel [gt_label_ids [idx ]])
125+ else :
126+ gt_words_biolabels .append (word_biolabels )
127+ break
128+ pre_word_id = cur_word_id
80129 # update merged gt_label
81- merged_data_sample .gt_label .item = gt_biolabels
82-
83- # inference process do not have item in gt_label,
84- # so select valid token with flattened_word_ids
85- # rather than with gt_label_ids like official code.
86- pred_biolabels = [
87- self .id2biolabel [p ]
88- for (w , p ) in zip (flattened_word_ids , pred_label_ids )
89- if w is not None
90- ]
91- pred_biolabel_scores = [
92- s for (w , s ) in zip (flattened_word_ids , pred_label_scores )
93- if w is not None
94- ]
95- # record pred_label
96- pred_label = LabelData ()
97- pred_label .item = pred_biolabels
98- pred_label .score = pred_biolabel_scores
99- merged_data_sample .pred_label = pred_label
130+ if self .only_label_first_subword :
131+ merged_data_sample .gt_label .item = [
132+ gt_word_biolabels [0 ]
133+ for gt_word_biolabels in gt_words_biolabels
134+ ]
135+ else :
136+ raise NotImplementedError (
137+ 'The `only_label_first_subword=False` is not support yet.' )
100138
101139 return [merged_data_sample ]
0 commit comments