diff --git a/hands_on/pyanno_voting/pyanno/tests/test_voting.py b/hands_on/pyanno_voting/pyanno/tests/test_voting.py index f21a10c..7e89da1 100644 --- a/hands_on/pyanno_voting/pyanno/tests/test_voting.py +++ b/hands_on/pyanno_voting/pyanno/tests/test_voting.py @@ -3,8 +3,11 @@ from pyanno import voting from pyanno.voting import MISSING_VALUE as MV +from math import isclose + def test_labels_count(): + #given annotations = [ [1, 2, MV, MV], [MV, MV, 3, 3], @@ -13,10 +16,49 @@ def test_labels_count(): ] nclasses = 5 expected = [0, 3, 1, 3, 0] + + #when result = voting.labels_count(annotations, nclasses) + + #then assert result == expected +def test_labels_frequency(): + #given + matrix = [ + [1, 2, 2, -1], + [2, 2, 2, 2], + [1, 1, 3, 3], + [1, 3, 3, 2], + [-1, 2, 3, 1], + [-1, -1, -1, 3], + ] + + matrix2 = [ + [-1, -1, -1, -1], + [-1, -1, -1, -1] + ] + + lowerlimit = 0 + upperlimit = 1 + nclasses = 4 + + expected2 = np.zeros(nclasses) + + + #when + result = voting.labels_frequency(matrix, nclasses) + result2 = voting.labels_frequency(matrix2, nclasses) + + #then + assert np.all([res != None for res in result]) + assert len(result) == nclasses + assert np.all(result2 == expected2) + assert np.all([i >= lowerlimit and i <= upperlimit for i in result]) + assert isclose(np.sum(result),upperlimit) or isclose(np.sum(result), lowerlimit,abs_tol=1e-12) + + def test_majority_vote(): annotations = [ [1, 2, 2, MV], diff --git a/hands_on/pyanno_voting/pyanno/voting.py b/hands_on/pyanno_voting/pyanno/voting.py index d5b5747..3694cc7 100644 --- a/hands_on/pyanno_voting/pyanno/voting.py +++ b/hands_on/pyanno_voting/pyanno/voting.py @@ -100,3 +100,20 @@ def labels_frequency(annotations, nclasses): freq[k] is the frequency of elements of class k in `annotations`, i.e. their count over the number of total of observed (non-missing) elements """ + annotations_array = np.ravel(annotations) + result = np.zeros(nclasses) + + dim = 0 + for number in annotations_array: + if number != -1: + dim = dim + 1 + if dim !=0: + for cl in np.arange(nclasses): + aux = 0 + for anot in annotations_array: + if cl == anot: + aux = aux + 1 + result[cl] = aux / dim + return(result) + else: + return(0)