|
8 | 8 |
|
9 | 9 | from detectron2 import model_zoo |
10 | 10 | from detectron2.config import get_cfg |
| 11 | +from detectron2.export.flatten import flatten_to_tuple |
11 | 12 | from detectron2.export.torchscript import dump_torchscript_IR, export_torchscript_with_instances |
12 | 13 | from detectron2.export.torchscript_patch import patch_builtin_len |
13 | 14 | from detectron2.layers import ShapeSpec |
@@ -83,77 +84,45 @@ def _test_retinanet_model(self, config_path): |
83 | 84 | class TestTracing(unittest.TestCase): |
84 | 85 | @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") |
85 | 86 | def testMaskRCNN(self): |
86 | | - class WrapModel(nn.ModuleList): |
87 | | - def forward(self, image): |
88 | | - inputs = [{"image": image}] |
89 | | - outputs = self[0].inference(inputs, do_postprocess=False)[0] |
90 | | - size = outputs.image_size |
91 | | - if torch.jit.is_tracing(): |
92 | | - assert isinstance(size, torch.Tensor) |
93 | | - else: |
94 | | - size = torch.as_tensor(size) |
95 | | - return ( |
96 | | - size, |
97 | | - outputs.pred_classes, |
98 | | - outputs.pred_boxes.tensor, |
99 | | - outputs.scores, |
100 | | - outputs.pred_masks, |
101 | | - ) |
102 | | - |
103 | | - @staticmethod |
104 | | - def convert_output(output): |
105 | | - r = Instances(tuple(output[0])) |
106 | | - r.pred_classes = output[1] |
107 | | - r.pred_boxes = Boxes(output[2]) |
108 | | - r.scores = output[3] |
109 | | - r.pred_masks = output[4] |
110 | | - return r |
| 87 | + def inference_func(model, image): |
| 88 | + inputs = [{"image": image}] |
| 89 | + outputs = model.inference(inputs, do_postprocess=False)[0] |
| 90 | + return outputs |
111 | 91 |
|
112 | | - self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", WrapModel) |
| 92 | + self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func) |
113 | 93 |
|
114 | 94 | @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") |
115 | 95 | def testRetinaNet(self): |
116 | | - class WrapModel(nn.ModuleList): |
117 | | - def forward(self, image): |
118 | | - inputs = [{"image": image}] |
119 | | - outputs = self[0].forward(inputs)[0]["instances"] |
120 | | - size = outputs.image_size |
121 | | - if torch.jit.is_tracing(): |
122 | | - assert isinstance(size, torch.Tensor) |
123 | | - else: |
124 | | - size = torch.as_tensor(size) |
125 | | - return ( |
126 | | - size, |
127 | | - outputs.pred_classes, |
128 | | - outputs.pred_boxes.tensor, |
129 | | - outputs.scores, |
130 | | - ) |
131 | | - |
132 | | - @staticmethod |
133 | | - def convert_output(output): |
134 | | - r = Instances(tuple(output[0])) |
135 | | - r.pred_classes = output[1] |
136 | | - r.pred_boxes = Boxes(output[2]) |
137 | | - r.scores = output[3] |
138 | | - return r |
| 96 | + def inference_func(model, image): |
| 97 | + return model.forward([{"image": image}])[0]["instances"] |
139 | 98 |
|
140 | | - self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml", WrapModel) |
| 99 | + self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml", inference_func) |
141 | 100 |
|
142 | | - def _test_model(self, config_path, WrapperCls): |
143 | | - # TODO wrapper should be handled by export API in the future |
| 101 | + def _test_model(self, config_path, inference_func): |
144 | 102 | model = model_zoo.get(config_path, trained=True) |
145 | 103 | image = get_sample_coco_image() |
146 | 104 |
|
147 | | - model = WrapperCls([model]) |
148 | | - model.eval() |
| 105 | + class Wrapper(nn.ModuleList): # a wrapper to make the model traceable |
| 106 | + def forward(self, image): |
| 107 | + outputs = inference_func(self[0], image) |
| 108 | + flattened_outputs, schema = flatten_to_tuple(outputs) |
| 109 | + if not hasattr(self, "schema"): |
| 110 | + self.schema = schema |
| 111 | + return flattened_outputs |
| 112 | + |
| 113 | + def rebuild(self, flattened_outputs): |
| 114 | + return self.schema(flattened_outputs) |
| 115 | + |
| 116 | + wrapper = Wrapper([model]) |
| 117 | + wrapper.eval() |
149 | 118 | with torch.no_grad(), patch_builtin_len(): |
150 | 119 | small_image = nn.functional.interpolate(image, scale_factor=0.5) |
151 | 120 | # trace with a different image, and the trace must still work |
152 | | - traced_model = torch.jit.trace(model, (small_image,)) |
| 121 | + traced_model = torch.jit.trace(wrapper, (small_image,)) |
153 | 122 |
|
154 | | - output = WrapperCls.convert_output(model(image)) |
155 | | - traced_output = WrapperCls.convert_output(traced_model(image)) |
156 | | - assert_instances_allclose(output, traced_output) |
| 123 | + output = inference_func(model, image) |
| 124 | + traced_output = wrapper.rebuild(traced_model(image)) |
| 125 | + assert_instances_allclose(output, traced_output, size_as_tensor=True) |
157 | 126 |
|
158 | 127 | def testKeypointHead(self): |
159 | 128 | class M(nn.Module): |
@@ -214,3 +183,22 @@ def forward(self, x): |
214 | 183 | for name in ["model_ts_code", "model_ts_IR", "model_ts_IR_inlined", "model"]: |
215 | 184 | fname = os.path.join(d, name + ".txt") |
216 | 185 | self.assertTrue(os.stat(fname).st_size > 0, fname) |
| 186 | + |
| 187 | + def test_flatten_basic(self): |
| 188 | + obj = [3, ([5, 6], {"name": [7, 9], "name2": 3})] |
| 189 | + res, schema = flatten_to_tuple(obj) |
| 190 | + self.assertEqual(res, (3, 5, 6, 7, 9, 3)) |
| 191 | + new_obj = schema(res) |
| 192 | + self.assertEqual(new_obj, obj) |
| 193 | + |
| 194 | + def test_flatten_instances_boxes(self): |
| 195 | + inst = Instances( |
| 196 | + torch.tensor([5, 8]), pred_masks=torch.tensor([3]), pred_boxes=Boxes(torch.ones((1, 4))) |
| 197 | + ) |
| 198 | + obj = [3, ([5, 6], inst)] |
| 199 | + res, schema = flatten_to_tuple(obj) |
| 200 | + self.assertEqual(res[:3], (3, 5, 6)) |
| 201 | + for r, expected in zip(res[3:], (inst.pred_boxes.tensor, inst.pred_masks, inst.image_size)): |
| 202 | + self.assertIs(r, expected) |
| 203 | + new_obj = schema(res) |
| 204 | + assert_instances_allclose(new_obj[1][1], inst, rtol=0.0, size_as_tensor=True) |
0 commit comments