 import queue
 from xml.etree import ElementTree
 from urllib.parse import urlparse
+import json
+from ..api_sdk.clients import http_client
+from ..api_sdk.config import config
+from ..cli_constants import CLI_PS_CLIENT_NAME
 
 import halo
 import requests
@@ -557,24 +560,114 @@ def update_status():
 
 class PutDatasetFilesCommand(BaseDatasetFilesCommand):
 
-    @classmethod
-    def _put(cls, path, url, content_type):
+    def _put(self, path, url, content_type, dataset_version_id=None, key=None):
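+        """Upload the file at `path` to the pre-signed `url`.
+
+        Files larger than 500 MB are sent as an S3 multipart upload,
+        which needs `dataset_version_id` and `key` to request
+        per-part pre-signed URLs from the API.
+        """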
         size = os.path.getsize(path)
         with requests.Session() as session:
             headers = {'Content-Type': content_type}
 
             try:
-                if size > 0:
+                if size <= 0:
+                    headers.update({'Content-Size': '0'})
+                    r = session.put(url, data='', headers=headers, timeout=5)
+                # for files under half a GB (500 MB)
+                elif size <= int(5e8):
                     with open(path, 'rb') as f:
                         r = session.put(
                             url, data=f, headers=headers, timeout=5)
+                # for chonky files, use a multipart upload
                 else:
-                    headers.update({'Content-Size': '0'})
-                    r = session.put(url, data='', headers=headers, timeout=5)
-
-                cls.validate_s3_response(r)
+                    # Chunks need to be at least 5 MB or AWS throws an
+                    # EntityTooSmall error; we'll arbitrarily choose a
+                    # 15 MB chunk size.
+                    #
+                    # Note also that AWS limits the max number of chunks
+                    # in a multipart upload to 10,000, so this setting
+                    # currently enforces a hard limit of 150 GB per file.
+                    #
+                    # We can dynamically assign a larger part size if
+                    # needed, but for the majority of use cases we should
+                    # be fine as-is.
+                    part_minsize = int(15e6)
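+                    # For example, a 40 MB file yields
+                    # ceil(40e6 / 15e6) = 3 parts (15 MB, 15 MB, 10 MB),
+                    # each uploaded with its own pre-signed URL.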
+                    dataset_id, _, version = dataset_version_id.partition(":")
+                    mpu_url = f'/datasets/{dataset_id}/versions/{version}/s3/preSignedUrls'
+
+                    api_client = http_client.API(
+                        api_url=config.CONFIG_HOST,
+                        api_key=self.api_key,
+                        ps_client_name=CLI_PS_CLIENT_NAME
+                    )
+
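+                    # S3 multipart uploads are a three-step flow:
+                    # 1. createMultipartUpload returns an UploadId,
+                    # 2. uploadPart sends each chunk and returns an ETag,
+                    # 3. completeMultipartUpload assembles the parts.
+                    # The API brokers each step as a pre-signed URL call.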
+                    mpu_create_res = api_client.post(
+                        url=mpu_url,
+                        json={
+                            'datasetId': dataset_id,
+                            'version': version,
+                            'calls': [{
+                                'method': 'createMultipartUpload',
+                                'params': {'Key': key}
+                            }]
+                        }
+                    )
+                    mpu_data = json.loads(mpu_create_res.text)[0]['url']
+
+                    parts = []
+                    with open(path, 'rb') as f:
+                        # Ceiling division gives the exact part count
+                        # (floor division would drop a trailing chunk
+                        # smaller than part_minsize, and an even multiple
+                        # would add a spurious empty part), and we start
+                        # the range at 1 to match AWS part numbering.
+                        num_parts = -(-size // part_minsize)
+                        for part in range(1, num_parts + 1):
+                            presigned_url_res = api_client.post(
+                                url=mpu_url,
+                                json={
+                                    'datasetId': dataset_id,
+                                    'version': version,
+                                    'calls': [{
+                                        'method': 'uploadPart',
+                                        'params': {
+                                            'Key': key,
+                                            'UploadId': mpu_data['UploadId'],
+                                            'PartNumber': part
+                                        }
+                                    }]
+                                }
+                            )
+
+                            presigned_url = json.loads(
+                                presigned_url_res.text
+                            )[0]['url']
+
+                            chunk = f.read(part_minsize)
+                            part_res = session.put(
+                                presigned_url,
+                                data=chunk,
+                                timeout=5)
+                            etag = part_res.headers['ETag'].replace('"', '')
+                            parts.append({'ETag': etag, 'PartNumber': part})
+
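+                    # The completion call must list every part's ETag and
+                    # PartNumber; S3 assembles the final object only after
+                    # it receives and validates this manifest.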
+                    r = api_client.post(
+                        url=mpu_url,
+                        json={
+                            'datasetId': dataset_id,
+                            'version': version,
+                            'calls': [{
+                                'method': 'completeMultipartUpload',
+                                'params': {
+                                    'Key': key,
+                                    'UploadId': mpu_data['UploadId'],
+                                    'MultipartUpload': {'Parts': parts}
+                                }
+                            }]
+                        }
+                    )
+
+                self.validate_s3_response(r)
             except requests.exceptions.ConnectionError as e:
-                return cls.report_connection_error(e)
+                return self.report_connection_error(e)
+            except Exception as e:
+                return e
 
     @staticmethod
     def _list_files(source_path):
@@ -599,8 +692,13 @@ def _sign_and_put(self, dataset_version_id, pool, results, update_status):
 
         for pre_signed, result in zip(pre_signeds, results):
             update_status()
-            pool.put(self._put, url=pre_signed.url,
-                     path=result['path'], content_type=result['mimetype'])
+            pool.put(
+                self._put,
+                url=pre_signed.url,
+                path=result['path'],
+                content_type=result['mimetype'],
+                dataset_version_id=dataset_version_id,
+                key=result['key'])
 
     def execute(self, dataset_version_id, source_paths, target_path):
         self.assert_supported(dataset_version_id)