|
17 | 17 | import tensorflow.compat.v2 as tf |
18 | 18 |
|
19 | 19 | import itertools |
| 20 | +import math |
20 | 21 | import os |
21 | 22 | import random |
22 | 23 | import string |
@@ -955,33 +956,67 @@ def test_one_hot_output_shape(self): |
955 | 956 | outputs = layer(inputs) |
956 | 957 | self.assertAllEqual(outputs.shape.as_list(), [16, 2]) |
957 | 958 |
|
958 | | - def test_multi_hot_output_hard_maximum(self): |
959 | | - """Check binary output when pad_to_max_tokens=True.""" |
960 | | - vocab_data = ["earth", "wind", "and", "fire"] |
961 | | - input_array = np.array([["earth", "wind", "and", "fire", ""], |
962 | | - ["fire", "fire", "and", "earth", "michigan"]]) |
963 | | - expected_output = [ |
964 | | - [0, 1, 1, 1, 1, 0], |
965 | | - [1, 1, 0, 1, 1, 0], |
966 | | - ] |
| 959 | + @parameterized.product( |
| 960 | + sparse=[True, False], |
| 961 | + adapt=[True, False], |
| 962 | + pad_to_max=[True, False], |
| 963 | + mode=["multi_hot", "count", "tf_idf"], |
| 964 | + ) |
| 965 | + def test_binned_output(self, sparse, adapt, pad_to_max, mode): |
| 966 | + """Check "multi_hot", "count", and "tf_idf" output.""" |
| 967 | + # Adapt breaks ties with sort order. |
| 968 | + vocab_data = ["wind", "fire", "earth", "and"] |
| 969 | + # IDF weight for a term in 1 out of 1 document is log(1 + 1/2). |
| 970 | + idf_data = [math.log(1.5)] * 4 |
| 971 | + input_data = np.array([["and", "earth", "fire", "and", ""], |
| 972 | + ["michigan", "wind", "and", "ohio", ""]]) |
| 973 | + |
| 974 | + if mode == "count": |
| 975 | + expected_output = np.array([ |
| 976 | + [0, 0, 1, 1, 2], |
| 977 | + [2, 1, 0, 0, 1], |
| 978 | + ]) |
| 979 | + elif mode == "tf_idf": |
| 980 | + expected_output = np.array([ |
| 981 | + [0, 0, 1, 1, 2], |
| 982 | + [2, 1, 0, 0, 1], |
| 983 | + ]) * math.log(1.5) |
| 984 | + else: |
| 985 | + expected_output = np.array([ |
| 986 | + [0, 0, 1, 1, 1], |
| 987 | + [1, 1, 0, 0, 1], |
| 988 | + ]) |
| 989 | + expected_output_shape = [None, 5] |
| 990 | + if pad_to_max: |
| 991 | + expected_output = np.concatenate((expected_output, [[0], [0]]), axis=1) |
| 992 | + expected_output_shape = [None, 6] |
967 | 993 |
|
968 | | - input_data = keras.Input(shape=(None,), dtype=tf.string) |
| 994 | + inputs = keras.Input(shape=(None,), dtype=tf.string) |
969 | 995 | layer = index_lookup.IndexLookup( |
970 | 996 | max_tokens=6, |
971 | 997 | num_oov_indices=1, |
972 | 998 | mask_token="", |
973 | 999 | oov_token="[OOV]", |
974 | | - output_mode=index_lookup.MULTI_HOT, |
975 | | - pad_to_max_tokens=True, |
| 1000 | + output_mode=mode, |
| 1001 | + pad_to_max_tokens=pad_to_max, |
| 1002 | + sparse=sparse, |
| 1003 | + vocabulary=None if adapt else vocab_data, |
| 1004 | + idf_weights=None if adapt or mode != "tf_idf" else idf_data, |
976 | 1005 | dtype=tf.string) |
977 | | - layer.set_vocabulary(vocab_data) |
978 | | - binary_data = layer(input_data) |
979 | | - model = keras.Model(inputs=input_data, outputs=binary_data) |
980 | | - output_dataset = model.predict(input_array) |
981 | | - self.assertAllEqual(expected_output, output_dataset) |
| 1006 | + if adapt: |
| 1007 | + layer.adapt(vocab_data) |
| 1008 | + outputs = layer(inputs) |
| 1009 | + model = keras.Model(inputs, outputs) |
| 1010 | + output_data = model.predict(input_data) |
| 1011 | + if sparse: |
| 1012 | + output_data = tf.sparse.to_dense(output_data) |
| 1013 | + # Check output data. |
| 1014 | + self.assertAllClose(expected_output, output_data) |
| 1015 | + # Check symbolic output shape. |
| 1016 | + self.assertAllEqual(expected_output_shape, outputs.shape.as_list()) |
982 | 1017 |
|
983 | 1018 | def test_multi_hot_output_no_oov(self): |
984 | | - """Check binary output when pad_to_max_tokens=True.""" |
| 1019 | + """Check multi hot output when num_oov_indices=0.""" |
985 | 1020 | vocab_data = ["earth", "wind", "and", "fire"] |
986 | 1021 | valid_input = np.array([["earth", "wind", "and", "fire"], |
987 | 1022 | ["fire", "and", "earth", ""]]) |
@@ -1050,188 +1085,6 @@ def test_multi_hot_output_hard_maximum_multiple_adapts(self): |
1050 | 1085 | self.assertAllEqual(first_expected_output, first_output) |
1051 | 1086 | self.assertAllEqual(second_expected_output, second_output) |
1052 | 1087 |
|
1053 | | - def test_multi_hot_output_soft_maximum(self): |
1054 | | - """Check multi_hot output when pad_to_max_tokens=False.""" |
1055 | | - vocab_data = ["earth", "wind", "and", "fire"] |
1056 | | - input_array = np.array([["earth", "wind", "and", "fire", ""], |
1057 | | - ["fire", "and", "earth", "michigan", ""]]) |
1058 | | - expected_output = [ |
1059 | | - [0, 1, 1, 1, 1], |
1060 | | - [1, 1, 0, 1, 1], |
1061 | | - ] |
1062 | | - |
1063 | | - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1064 | | - layer = index_lookup.IndexLookup( |
1065 | | - max_tokens=None, |
1066 | | - num_oov_indices=1, |
1067 | | - mask_token="", |
1068 | | - oov_token="[OOV]", |
1069 | | - output_mode=index_lookup.MULTI_HOT, |
1070 | | - dtype=tf.string) |
1071 | | - layer.set_vocabulary(vocab_data) |
1072 | | - binary_data = layer(input_data) |
1073 | | - model = keras.Model(inputs=input_data, outputs=binary_data) |
1074 | | - output_dataset = model.predict(input_array) |
1075 | | - self.assertAllEqual(expected_output, output_dataset) |
1076 | | - |
1077 | | - def test_multi_hot_output_shape(self): |
1078 | | - input_data = keras.Input(batch_size=16, shape=(4,), dtype=tf.string) |
1079 | | - layer = index_lookup.IndexLookup( |
1080 | | - max_tokens=2, |
1081 | | - num_oov_indices=1, |
1082 | | - mask_token="", |
1083 | | - oov_token="[OOV]", |
1084 | | - output_mode=index_lookup.MULTI_HOT, |
1085 | | - vocabulary=["foo"], |
1086 | | - dtype=tf.string) |
1087 | | - binary_data = layer(input_data) |
1088 | | - self.assertAllEqual(binary_data.shape.as_list(), [16, 2]) |
1089 | | - |
1090 | | - def test_count_output_hard_maxiumum(self): |
1091 | | - """Check count output when pad_to_max_tokens=True.""" |
1092 | | - vocab_data = ["earth", "wind", "and", "fire"] |
1093 | | - input_array = np.array([["earth", "wind", "and", "wind", ""], |
1094 | | - ["fire", "fire", "fire", "michigan", ""]]) |
1095 | | - expected_output = [ |
1096 | | - [0, 1, 2, 1, 0, 0], |
1097 | | - [1, 0, 0, 0, 3, 0], |
1098 | | - ] |
1099 | | - |
1100 | | - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1101 | | - layer = index_lookup.IndexLookup( |
1102 | | - max_tokens=6, |
1103 | | - num_oov_indices=1, |
1104 | | - mask_token="", |
1105 | | - oov_token="[OOV]", |
1106 | | - output_mode=index_lookup.COUNT, |
1107 | | - pad_to_max_tokens=True, |
1108 | | - dtype=tf.string) |
1109 | | - layer.set_vocabulary(vocab_data) |
1110 | | - count_data = layer(input_data) |
1111 | | - model = keras.Model(inputs=input_data, outputs=count_data) |
1112 | | - output_dataset = model.predict(input_array) |
1113 | | - self.assertAllEqual(expected_output, output_dataset) |
1114 | | - |
1115 | | - def test_count_output_soft_maximum(self): |
1116 | | - """Check count output when pad_to_max_tokens=False.""" |
1117 | | - vocab_data = ["earth", "wind", "and", "fire"] |
1118 | | - input_array = np.array([["earth", "wind", "and", "wind", ""], |
1119 | | - ["fire", "fire", "fire", "michigan", ""]]) |
1120 | | - expected_output = [ |
1121 | | - [0, 1, 2, 1, 0], |
1122 | | - [1, 0, 0, 0, 3], |
1123 | | - ] |
1124 | | - |
1125 | | - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1126 | | - layer = index_lookup.IndexLookup( |
1127 | | - max_tokens=None, |
1128 | | - num_oov_indices=1, |
1129 | | - mask_token="", |
1130 | | - oov_token="[OOV]", |
1131 | | - output_mode=index_lookup.COUNT, |
1132 | | - dtype=tf.string) |
1133 | | - layer.set_vocabulary(vocab_data) |
1134 | | - count_data = layer(input_data) |
1135 | | - model = keras.Model(inputs=input_data, outputs=count_data) |
1136 | | - output_dataset = model.predict(input_array) |
1137 | | - self.assertAllEqual(expected_output, output_dataset) |
1138 | | - |
1139 | | - def test_count_output_shape(self): |
1140 | | - input_data = keras.Input(batch_size=16, shape=(4,), dtype=tf.string) |
1141 | | - layer = index_lookup.IndexLookup( |
1142 | | - max_tokens=2, |
1143 | | - num_oov_indices=1, |
1144 | | - mask_token="", |
1145 | | - oov_token="[OOV]", |
1146 | | - output_mode=index_lookup.COUNT, |
1147 | | - vocabulary=["foo"], |
1148 | | - dtype=tf.string) |
1149 | | - count_data = layer(input_data) |
1150 | | - self.assertAllEqual(count_data.shape.as_list(), [16, 2]) |
1151 | | - |
1152 | | - @parameterized.named_parameters( |
1153 | | - ("sparse", True), |
1154 | | - ("dense", False), |
1155 | | - ) |
1156 | | - def test_ifidf_output_hard_maximum(self, sparse): |
1157 | | - """Check tf-idf output when pad_to_max_tokens=True.""" |
1158 | | - vocab_data = ["earth", "wind", "and", "fire"] |
1159 | | - # OOV idf weight (bucket 0) should be 0.5, the average of passed weights. |
1160 | | - idf_weights = [.4, .25, .75, .6] |
1161 | | - input_array = np.array([["earth", "wind", "and", "earth", ""], |
1162 | | - ["ohio", "fire", "earth", "michigan", ""]]) |
1163 | | - expected_output = [ |
1164 | | - [0.00, 0.80, 0.25, 0.75, 0.00, 0.00], |
1165 | | - [1.00, 0.40, 0.00, 0.00, 0.60, 0.00], |
1166 | | - ] |
1167 | | - |
1168 | | - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1169 | | - layer = index_lookup.IndexLookup( |
1170 | | - max_tokens=6, |
1171 | | - num_oov_indices=1, |
1172 | | - mask_token="", |
1173 | | - oov_token="[OOV]", |
1174 | | - output_mode=index_lookup.TF_IDF, |
1175 | | - pad_to_max_tokens=True, |
1176 | | - dtype=tf.string, |
1177 | | - sparse=sparse, |
1178 | | - vocabulary=vocab_data, |
1179 | | - idf_weights=idf_weights) |
1180 | | - layer_output = layer(input_data) |
1181 | | - model = keras.Model(inputs=input_data, outputs=layer_output) |
1182 | | - output_dataset = model.predict(input_array) |
1183 | | - if sparse: |
1184 | | - output_dataset = tf.sparse.to_dense(output_dataset) |
1185 | | - self.assertAllClose(expected_output, output_dataset) |
1186 | | - |
1187 | | - @parameterized.named_parameters( |
1188 | | - ("sparse", True), |
1189 | | - ("dense", False), |
1190 | | - ) |
1191 | | - def test_ifidf_output_soft_maximum(self, sparse): |
1192 | | - """Check tf-idf output when pad_to_max_tokens=False.""" |
1193 | | - vocab_data = ["earth", "wind", "and", "fire"] |
1194 | | - # OOV idf weight (bucket 0) should be 0.5, the average of passed weights. |
1195 | | - idf_weights = [.4, .25, .75, .6] |
1196 | | - input_array = np.array([["earth", "wind", "and", "earth", ""], |
1197 | | - ["ohio", "fire", "earth", "michigan", ""]]) |
1198 | | - expected_output = [ |
1199 | | - [0.00, 0.80, 0.25, 0.75, 0.00], |
1200 | | - [1.00, 0.40, 0.00, 0.00, 0.60], |
1201 | | - ] |
1202 | | - |
1203 | | - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1204 | | - layer = index_lookup.IndexLookup( |
1205 | | - max_tokens=None, |
1206 | | - num_oov_indices=1, |
1207 | | - mask_token="", |
1208 | | - oov_token="[OOV]", |
1209 | | - output_mode=index_lookup.TF_IDF, |
1210 | | - dtype=tf.string, |
1211 | | - sparse=sparse, |
1212 | | - vocabulary=vocab_data, |
1213 | | - idf_weights=idf_weights) |
1214 | | - layer_output = layer(input_data) |
1215 | | - model = keras.Model(inputs=input_data, outputs=layer_output) |
1216 | | - output_dataset = model.predict(input_array) |
1217 | | - if sparse: |
1218 | | - output_dataset = tf.sparse.to_dense(output_dataset) |
1219 | | - self.assertAllClose(expected_output, output_dataset) |
1220 | | - |
1221 | | - def test_ifidf_output_shape(self): |
1222 | | - input_data = keras.Input(batch_size=16, shape=(4,), dtype=tf.string) |
1223 | | - layer = index_lookup.IndexLookup( |
1224 | | - max_tokens=2, |
1225 | | - num_oov_indices=1, |
1226 | | - mask_token="", |
1227 | | - oov_token="[OOV]", |
1228 | | - output_mode=index_lookup.TF_IDF, |
1229 | | - dtype=tf.string, |
1230 | | - vocabulary=["foo"], |
1231 | | - idf_weights=[1.0]) |
1232 | | - layer_output = layer(input_data) |
1233 | | - self.assertAllEqual(layer_output.shape.as_list(), [16, 2]) |
1234 | | - |
1235 | 1088 | def test_int_output_file_vocab(self): |
1236 | 1089 | vocab_data = ["earth", "wind", "and", "fire"] |
1237 | 1090 | input_array = np.array([["earth", "wind", "and", "fire"], |
|
0 commit comments