diff --git a/test/test_writer.py b/test/test_writer.py
index 597c24f9..7e8fbee7 100644
--- a/test/test_writer.py
+++ b/test/test_writer.py
@@ -687,6 +687,22 @@ def test_response_warc_1_1(self, is_gzip, builder_factory):
         assert '.' in recs[0].rec_headers['WARC-Date']
         assert len(recs[0].rec_headers['WARC-Date']) == 27
 
+    def test_response_digest_sha256(self, is_gzip, builder_factory):
+        writer = BufferWARCWriter(gzip=is_gzip, warc_version='WARC/1.1', digest_algorithm='sha256')
+
+        builder = builder_factory(writer, warc_version='WARC/1.1', digest_algorithm='sha256')
+        resp = sample_response(builder)
+        writer.write_record(resp)
+
+        stream = writer.get_stream()
+
+        reader = ArchiveIterator(stream)
+        recs = list(reader)
+
+        assert len(recs) == 1
+        assert recs[0].rec_headers.get('WARC-Block-Digest') == 'sha256:HB3IP2QLBZJ45JMKAYFDDVME5MTC2WX2JJLQZSJYR575CFMRDDHA===='
+        assert recs[0].rec_headers.get('WARC-Payload-Digest') == 'sha256:5DS36RD4GUWABAHBIREZJMGMD67HUJPT5JRXYXEJ6WK3NKK4SJJQ===='
+
     def _conv_to_streaming_record(self, record_buff, rec_type):
         # strip-off the two empty \r\n\r\n added at the end of uncompressed record
         record_buff = record_buff[:-4]
diff --git a/warcio/recordbuilder.py b/warcio/recordbuilder.py
index ca3c0890..490e72cb 100644
--- a/warcio/recordbuilder.py
+++ b/warcio/recordbuilder.py
@@ -28,12 +28,16 @@ class RecordBuilder(object):
         }
 
     NO_PAYLOAD_DIGEST_TYPES = ('warcinfo', 'revisit')
+
+    # default digest algorithm
+    DIGEST_ALGORITHM = 'sha1'
 
 
 
-    def __init__(self, warc_version=None, header_filter=None):
+    def __init__(self, warc_version=None, header_filter=None, digest_algorithm=None):
         self.warc_version = self._parse_warc_version(warc_version)
         self.header_filter = header_filter
+        self.digest_algorithm = self._parse_digest_algorithm(digest_algorithm)
 
     def create_warcinfo_record(self, filename, info):
         warc_headers = StatusAndHeaders('', [], protocol=self.warc_version)
@@ -146,6 +150,11 @@ def _parse_warc_version(self, version):
             return version
 
         return 'WARC/' + version
+
+    def _parse_digest_algorithm(self, algorithm):
+        if not algorithm:
+            return self.DIGEST_ALGORITHM
+        return algorithm
 
     @classmethod
     def _make_warc_id(cls):
@@ -221,9 +230,8 @@ def _iter_stream(stream):
 
             yield buf
 
-    @staticmethod
-    def _create_digester():
-        return Digester('sha1')
+    def _create_digester(self):
+        return Digester(self.digest_algorithm)
 
     @staticmethod
     def _create_temp_file():
diff --git a/warcio/warcwriter.py b/warcio/warcwriter.py
index 6fb71be9..a99a0d05 100644
--- a/warcio/warcwriter.py
+++ b/warcio/warcwriter.py
@@ -13,7 +13,8 @@ class BaseWARCWriter(RecordBuilder):
 
     def __init__(self, gzip=True, *args, **kwargs):
         super(BaseWARCWriter, self).__init__(warc_version=kwargs.get('warc_version'),
-                                             header_filter=kwargs.get('header_filter'))
+                                             header_filter=kwargs.get('header_filter'),
+                                             digest_algorithm=kwargs.get('digest_algorithm'))
         self.gzip = gzip
         self.hostname = gethostname()
 