def test_label_frequency():
    """Check that labels_frequency returns per-class shares of observed labels.

    Five annotations are observed (one is missing): label 1 occurs three
    times and label 2 twice, so the expected frequencies over four classes
    are [0, 3/5, 2/5, 0].
    """
    # Given
    # MV marks a missing annotation, as in the other tests of this module.
    annotations = [
        [1, 1, 2],
        [MV, 1, 2],
    ]
    nclasses = 4
    expected = [0.0, 0.6, 0.4, 0.0]

    # When
    # Pass the fixtures defined above instead of repeating them as literals,
    # so the "Given" section is the single source of truth for the inputs.
    result = voting.labels_frequency(annotations, nclasses)

    # Then
    assert_array_almost_equal(result, expected, decimal=6)
def labels_frequency(annotations, nclasses):
    """Compute the relative frequency of each label class in `annotations`.

    Parameters
    ----------
    annotations : list of lists (or array-like)
        Label values. Entries equal to -1 are treated as missing and are
        excluded from the computation.
    nclasses : int
        Number of classes; frequencies are reported for labels
        0, 1, ..., nclasses - 1.

    Returns
    -------
    freq : numpy.ndarray of float, shape (nclasses,)
        freq[k] is the frequency of elements of class k in `annotations`,
        i.e. their count over the total number of observed (non-missing)
        elements. All entries are 0 when there are no observed elements.
    """
    # NOTE(review): -1 presumably equals the module-level MISSING_VALUE
    # constant -- confirm and substitute the constant if so.
    missing = -1

    annot = np.asarray(annotations)

    # Drop missing entries up front. Looking up the count of -1 in a table
    # of counts fails with KeyError when no annotation is missing.
    observed = annot[annot != missing]
    n_observed = observed.size
    if n_observed == 0:
        # Nothing observed: every class has frequency 0 (avoids dividing by 0).
        return np.zeros(nclasses)

    labels, counts = np.unique(observed, return_counts=True)
    count_by_label = dict(zip(labels.tolist(), counts.tolist()))

    # Labels outside range(nclasses) still count towards the denominator but
    # are not reported, as in the previous implementation.
    class_counts = np.array(
        [count_by_label.get(k, 0) for k in range(nclasses)],
        dtype=float,
    )
    return class_counts / n_observed