Skip to content

Commit 31eed14

Browse files
ppwwyyxx authored and
facebook-github-bot committed
add flattening utilities to deal with tracing outputs
Reviewed By: theschnitz Differential Revision: D26282744 fbshipit-source-id: 1380fde412270dba9b166aab32092801e9e3906a
1 parent 222b64a commit 31eed14

File tree

3 files changed

+230
-62
lines changed

3 files changed

+230
-62
lines changed

detectron2/export/flatten.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import collections
2+
from dataclasses import dataclass
3+
from typing import List
4+
import torch
5+
6+
from detectron2.structures import Boxes, Instances
7+
8+
9+
@dataclass
class Schema:
    """
    A Schema defines how to flatten a possibly hierarchical object into a tuple of
    primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.

    PyTorch does not support tracing a function that produces rich output
    structures (e.g. dict, Instances, Boxes). To trace such a function, we
    flatten the rich object into a tuple of tensors, and return this tuple of
    tensors instead. Meanwhile, we also need to know how to "rebuild" the original
    object from the flattened results, so we can evaluate the flattened results.
    A Schema defines how to flatten an object, and while flattening it, it records
    the necessary schemas so that the object can be rebuilt from the flattened outputs.

    The flattened object and the schema object are returned by the ``.flatten``
    classmethod. Then the original object can be rebuilt with the ``__call__``
    method of the schema.

    A Schema is a dataclass that can be serialized easily.
    """

    # inspired by FetchMapper in tensorflow/python/client/session.py

    @classmethod
    def flatten(cls, obj):
        """
        Flatten ``obj``. The base class does not implement this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def __call__(self, values):
        """
        Rebuild an object from flattened ``values``. The base class does not
        implement this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    @staticmethod
    def _concat(values):
        """
        Concatenate several flattened tuples into a single tuple.

        Args:
            values: an iterable of tuples.

        Returns:
            tuple: the concatenation of all input tuples, in order.
            list[list[int]]: one ``[start, end]`` pair per input tuple, giving the
                half-open range it occupies in the concatenated tuple.

        Raises:
            AssertionError: if any element of ``values`` is not a tuple.
        """
        # Accumulate in a list and convert once at the end: repeatedly building a
        # new tuple with ``ret + v`` copies everything gathered so far each time.
        flat = []
        idx_mapping = []
        for v in values:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; AssertionError is kept so the exception type that
            # callers observe does not change.
            if not isinstance(v, tuple):
                raise AssertionError("Flattened results must be a tuple")
            start = len(flat)
            flat.extend(v)
            idx_mapping.append([start, len(flat)])
        return tuple(flat), idx_mapping

    @staticmethod
    def _split(values, idx_mapping):
        """
        Split ``values`` back into the pieces described by ``idx_mapping``.

        Args:
            values: a sliceable sequence of flattened values.
            idx_mapping: ``[start, end]`` pairs as produced by ``_concat``.

        Returns:
            list: ``values[start:end]`` for every pair in ``idx_mapping``.

        Raises:
            AssertionError: if ``idx_mapping`` is non-empty and the length of
                ``values`` differs from the end index of the last pair.
        """
        if len(idx_mapping):
            expected_len = idx_mapping[-1][-1]
            # Explicit raise so a length mismatch is not silently turned into
            # truncated slices when assertions are disabled.
            if len(values) != expected_len:
                raise AssertionError(
                    f"Values has length {len(values)} but expect length {expected_len}."
                )
        return [values[start:end] for (start, end) in idx_mapping]
60+
61+
62+
@dataclass
class ListSchema(Schema):
    """
    Schema for a list or tuple: every element is flattened on its own and the
    flattened pieces are concatenated.
    """

    schemas: List[Schema]  # schema of each element, in order
    idx_mapping: List[List[int]]  # [start, end] range of each element in the flat tuple
    is_tuple: bool  # True if the flattened container was a tuple

    def __call__(self, values):
        """
        Rebuild the list/tuple from flattened ``values``.

        Raises:
            ValueError: if the number of pieces after splitting does not match
                the number of element schemas.
        """
        chunks = self._split(values, self.idx_mapping)
        num_chunks, num_schemas = len(chunks), len(self.schemas)
        if num_chunks != num_schemas:
            raise ValueError(
                f"Values has length {num_chunks} but schemas has length {num_schemas}!"
            )
        rebuilt = [schema(chunk) for schema, chunk in zip(self.schemas, chunks)]
        return tuple(rebuilt) if self.is_tuple else rebuilt

    @classmethod
    def flatten(cls, obj):
        """
        Flatten every element of ``obj`` with ``flatten_to_tuple`` and
        concatenate the results.

        Returns:
            tuple: the concatenated flattened values.
            ListSchema: the schema that rebuilds ``obj`` from those values.
        """
        flat_parts = []
        element_schemas = []
        for element in obj:
            flat, schema = flatten_to_tuple(element)
            flat_parts.append(flat)
            element_schemas.append(schema)
        values, idx_mapping = cls._concat(flat_parts)
        return values, cls(element_schemas, idx_mapping, isinstance(obj, tuple))
83+
84+
85+
@dataclass
class IdentitySchema(Schema):
    """
    Schema for an object that is kept as-is: flattening wraps it in a
    one-element tuple, rebuilding takes the first element back out.
    """

    def __call__(self, values):
        """Return the first element of ``values``."""
        return values[0]

    @classmethod
    def flatten(cls, obj):
        """Return ``(obj,)`` together with a new ``IdentitySchema``."""
        wrapped = (obj,)
        return wrapped, cls()
93+
94+
95+
@dataclass
class DictSchema(Schema):
    """
    Schema for a dictionary with ``str`` keys. The values are flattened as a
    list ordered by sorted key.
    """

    keys: List[str]  # sorted keys of the flattened dictionary
    value_schema: ListSchema  # schema of the values, in the order of ``keys``

    def __call__(self, values):
        """Rebuild the dictionary from flattened ``values``."""
        rebuilt_values = self.value_schema(values)
        return {key: value for key, value in zip(self.keys, rebuilt_values)}

    @classmethod
    def flatten(cls, obj):
        """
        Flatten the values of ``obj`` in sorted-key order.

        Returns:
            tuple: the flattened values.
            DictSchema: the schema that rebuilds the dictionary.

        Raises:
            KeyError: if any key of ``obj`` is not a ``str``.
        """
        if any(not isinstance(key, str) for key in obj.keys()):
            raise KeyError("Only support flattening dictionaries if keys are str.")
        ordered_keys = sorted(obj.keys())
        flat, schema = ListSchema.flatten([obj[key] for key in ordered_keys])
        return flat, cls(ordered_keys, schema)
113+
114+
115+
@dataclass
class InstancesSchema(Schema):
    """
    Schema for ``Instances``: the fields are flattened in sorted-name order and
    the image size is appended as the last flattened value.
    """

    field_names: List[str]  # sorted names of the fields
    field_schema: ListSchema  # schema of the field values, in the order of ``field_names``

    def __call__(self, values):
        """Rebuild the ``Instances`` from flattened ``values``."""
        # The last value is the image size; everything before it belongs to the fields.
        image_size = values[-1]
        flat_fields = values[:-1]
        rebuilt_fields = self.field_schema(flat_fields)
        return Instances(image_size, **dict(zip(self.field_names, rebuilt_fields)))

    @classmethod
    def flatten(cls, obj):
        """
        Flatten the fields of ``obj`` followed by its ``image_size``.

        Returns:
            tuple: flattened field values, then the image size as a tensor.
            InstancesSchema: the schema that rebuilds the ``Instances``.
        """
        names = sorted(obj.get_fields().keys())
        flat, schema = ListSchema.flatten([obj.get(name) for name in names])
        image_size = obj.image_size
        # A non-tensor size is converted so the flattened output holds a tensor.
        if not isinstance(image_size, torch.Tensor):
            image_size = torch.tensor(image_size)
        return flat + (image_size,), cls(names, schema)
135+
136+
137+
@dataclass
class BoxesSchema(Schema):
    """
    Schema for ``Boxes``: flattened to a one-element tuple holding its
    ``tensor`` attribute.
    """

    def __call__(self, values):
        """Wrap the first element of ``values`` in a ``Boxes``."""
        box_tensor = values[0]
        return Boxes(box_tensor)

    @classmethod
    def flatten(cls, obj):
        """Return ``(obj.tensor,)`` together with a new ``BoxesSchema``."""
        flattened = (obj.tensor,)
        return flattened, cls()
145+
146+
147+
# if more custom structures are needed in the future, this function can
# allow passing in extra schemas for custom types
def flatten_to_tuple(obj):
    """
    Flatten an object so it can be used for PyTorch tracing.
    Also returns how to rebuild the original object from the flattened outputs.

    Returns:
        res (tuple): the flattened results that can be used as tracing outputs
        schema: an object with a ``__call__`` method such that ``schema(res) == obj``.
             It is a pure dataclass that can be serialized.
    """
    # Checked in order, first match wins. str/bytes come first because they are
    # themselves Sequences and would otherwise be dispatched to ListSchema.
    dispatch = [
        ((str, bytes), IdentitySchema),
        (collections.abc.Sequence, ListSchema),
        (collections.abc.Mapping, DictSchema),
        (Instances, InstancesSchema),
        (Boxes, BoxesSchema),
    ]
    # Anything that matches no entry is kept as-is via IdentitySchema.
    schema_cls = next(
        (schema for klass, schema in dispatch if isinstance(obj, klass)),
        IdentitySchema,
    )
    return schema_cls.flatten(obj)

detectron2/utils/testing.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,25 @@ def get_sample_coco_image(tensor=True):
6262
return ret
6363

6464

65-
def assert_instances_allclose(input, other, rtol=1e-5, msg=""):
65+
def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
6666
"""
6767
Args:
6868
input, other (Instances):
69+
size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
70+
Useful for comparing outputs of tracing.
6971
"""
7072
if not msg:
7173
msg = "Two Instances are different! "
7274
else:
7375
msg = msg.rstrip() + " "
74-
assert input.image_size == other.image_size, (
75-
msg + f"image_size is {input.image_size} vs. {other.image_size}!"
76-
)
76+
77+
size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
78+
if size_as_tensor:
79+
assert torch.equal(
80+
torch.tensor(input.image_size), torch.tensor(other.image_size)
81+
), size_error_msg
82+
else:
83+
assert input.image_size == other.image_size, size_error_msg
7784
fields = sorted(input.get_fields().keys())
7885
fields_other = sorted(other.get_fields().keys())
7986
assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"

tests/test_export_torchscript.py

Lines changed: 46 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from detectron2 import model_zoo
1010
from detectron2.config import get_cfg
11+
from detectron2.export.flatten import flatten_to_tuple
1112
from detectron2.export.torchscript import dump_torchscript_IR, export_torchscript_with_instances
1213
from detectron2.export.torchscript_patch import patch_builtin_len
1314
from detectron2.layers import ShapeSpec
@@ -83,77 +84,45 @@ def _test_retinanet_model(self, config_path):
8384
class TestTracing(unittest.TestCase):
8485
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
8586
def testMaskRCNN(self):
86-
class WrapModel(nn.ModuleList):
87-
def forward(self, image):
88-
inputs = [{"image": image}]
89-
outputs = self[0].inference(inputs, do_postprocess=False)[0]
90-
size = outputs.image_size
91-
if torch.jit.is_tracing():
92-
assert isinstance(size, torch.Tensor)
93-
else:
94-
size = torch.as_tensor(size)
95-
return (
96-
size,
97-
outputs.pred_classes,
98-
outputs.pred_boxes.tensor,
99-
outputs.scores,
100-
outputs.pred_masks,
101-
)
102-
103-
@staticmethod
104-
def convert_output(output):
105-
r = Instances(tuple(output[0]))
106-
r.pred_classes = output[1]
107-
r.pred_boxes = Boxes(output[2])
108-
r.scores = output[3]
109-
r.pred_masks = output[4]
110-
return r
87+
def inference_func(model, image):
88+
inputs = [{"image": image}]
89+
outputs = model.inference(inputs, do_postprocess=False)[0]
90+
return outputs
11191

112-
self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", WrapModel)
92+
self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func)
11393

11494
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
11595
def testRetinaNet(self):
116-
class WrapModel(nn.ModuleList):
117-
def forward(self, image):
118-
inputs = [{"image": image}]
119-
outputs = self[0].forward(inputs)[0]["instances"]
120-
size = outputs.image_size
121-
if torch.jit.is_tracing():
122-
assert isinstance(size, torch.Tensor)
123-
else:
124-
size = torch.as_tensor(size)
125-
return (
126-
size,
127-
outputs.pred_classes,
128-
outputs.pred_boxes.tensor,
129-
outputs.scores,
130-
)
131-
132-
@staticmethod
133-
def convert_output(output):
134-
r = Instances(tuple(output[0]))
135-
r.pred_classes = output[1]
136-
r.pred_boxes = Boxes(output[2])
137-
r.scores = output[3]
138-
return r
96+
def inference_func(model, image):
97+
return model.forward([{"image": image}])[0]["instances"]
13998

140-
self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml", WrapModel)
99+
self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml", inference_func)
141100

142-
def _test_model(self, config_path, WrapperCls):
143-
# TODO wrapper should be handled by export API in the future
101+
def _test_model(self, config_path, inference_func):
144102
model = model_zoo.get(config_path, trained=True)
145103
image = get_sample_coco_image()
146104

147-
model = WrapperCls([model])
148-
model.eval()
105+
class Wrapper(nn.ModuleList): # a wrapper to make the model traceable
106+
def forward(self, image):
107+
outputs = inference_func(self[0], image)
108+
flattened_outputs, schema = flatten_to_tuple(outputs)
109+
if not hasattr(self, "schema"):
110+
self.schema = schema
111+
return flattened_outputs
112+
113+
def rebuild(self, flattened_outputs):
114+
return self.schema(flattened_outputs)
115+
116+
wrapper = Wrapper([model])
117+
wrapper.eval()
149118
with torch.no_grad(), patch_builtin_len():
150119
small_image = nn.functional.interpolate(image, scale_factor=0.5)
151120
# trace with a different image, and the trace must still work
152-
traced_model = torch.jit.trace(model, (small_image,))
121+
traced_model = torch.jit.trace(wrapper, (small_image,))
153122

154-
output = WrapperCls.convert_output(model(image))
155-
traced_output = WrapperCls.convert_output(traced_model(image))
156-
assert_instances_allclose(output, traced_output)
123+
output = inference_func(model, image)
124+
traced_output = wrapper.rebuild(traced_model(image))
125+
assert_instances_allclose(output, traced_output, size_as_tensor=True)
157126

158127
def testKeypointHead(self):
159128
class M(nn.Module):
@@ -214,3 +183,22 @@ def forward(self, x):
214183
for name in ["model_ts_code", "model_ts_IR", "model_ts_IR_inlined", "model"]:
215184
fname = os.path.join(d, name + ".txt")
216185
self.assertTrue(os.stat(fname).st_size > 0, fname)
186+
187+
def test_flatten_basic(self):
    """Round-trip a nested list/tuple/dict of ints through ``flatten_to_tuple``."""
    original = [3, ([5, 6], {"name": [7, 9], "name2": 3})]
    flattened, schema = flatten_to_tuple(original)
    self.assertEqual(flattened, (3, 5, 6, 7, 9, 3))
    rebuilt = schema(flattened)
    self.assertEqual(rebuilt, original)
193+
194+
def test_flatten_instances_boxes(self):
    """Flatten and rebuild a nested structure containing ``Instances`` and ``Boxes``."""
    inst = Instances(
        torch.tensor([5, 8]), pred_masks=torch.tensor([3]), pred_boxes=Boxes(torch.ones((1, 4)))
    )
    obj = [3, ([5, 6], inst)]
    res, schema = flatten_to_tuple(obj)
    self.assertEqual(res[:3], (3, 5, 6))
    # Instances fields are flattened in sorted-name order, followed by image_size.
    expected_tensors = (inst.pred_boxes.tensor, inst.pred_masks, inst.image_size)
    # zip() stops at the shorter input, so check the length explicitly;
    # otherwise missing or extra flattened outputs would go unnoticed.
    self.assertEqual(len(res), 3 + len(expected_tensors))
    for r, expected in zip(res[3:], expected_tensors):
        self.assertIs(r, expected)
    new_obj = schema(res)
    # The plain (non-Instances) parts must round-trip as well.
    self.assertEqual(new_obj[0], 3)
    self.assertEqual(new_obj[1][0], [5, 6])
    assert_instances_allclose(new_obj[1][1], inst, rtol=0.0, size_as_tensor=True)

0 commit comments

Comments
 (0)