 import re
 import threading
 import uuid
+import json
 try:
     import queue
 except ImportError:
     import Queue as queue
 from xml.etree import ElementTree
+from urllib.parse import urlparse

 import halo
 import requests
 import six

 from gradient import api_sdk
 from gradient.api_sdk.sdk_exceptions import ResourceFetchingError
+from gradient.api_sdk.utils import base64_encode
 from gradient.cli_constants import CLI_PS_CLIENT_NAME
+from gradient.cli.jobs import get_workspace_handler
+from gradient.commands import jobs as jobs_commands
 from gradient.commands.common import BaseCommand, DetailsCommandMixin, ListCommandPagerMixin
+from gradient.commands.jobs import BaseCreateJobCommandMixin, BaseJobCommand, CreateJobCommand
 from gradient.exceptions import ApplicationError

 S3_XMLNS = 'http://s3.amazonaws.com/doc/2006-03-01/'
+DATASET_IMPORTER_IMAGE = "paperspace/dataset-importer:latest"
+PROJECT_NAME = "Job Builder"
+SUPPORTED_URL = ['https', 'http']
+IMPORTER_COMMAND = "go-getter"
+HTTP_SECRET = "HTTP_AUTH"
+S3_ACCESS_KEY = "AWS_ACCESS_KEY_ID"
+S3_SECRET_KEY = "AWS_SECRET_ACCESS_KEY"
+S3_REGION_KEY = "AWS_DEFAULT_REGION"
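+# Example of the importer invocation these constants produce (bucket and object
+# names here are purely hypothetical); the AWS_* variables above are resolved
+# from ephemeral secrets at runtime:
+#   go-getter s3::https://my-bucket.s3.amazonaws.com/dataset.zip /data/output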


 class WorkerPool(object):
@@ -676,3 +690,92 @@ def update_status(): |
         for pre_signed in pre_signeds:
             update_status()
             pool.put(self._delete, url=pre_signed.url)
+
+
+class ImportDatasetCommand(BaseCreateJobCommandMixin, BaseJobCommand):
+    def create_secret(self, key, value, expires_in=86400):
+        client = api_sdk.clients.SecretsClient(
+            api_key=self.api_key,
+            logger=self.logger,
+            ps_client_name=CLI_PS_CLIENT_NAME,
+        )
+
+        response = client.ephemeral(key, value, expires_in)
+        return response
+
+    def get_command(self, s3_url, http_url, http_auth):
+        # Default: hand the URL to go-getter and extract into the output dataset mount
+        command = "%s %s /data/output" % (IMPORTER_COMMAND, (s3_url or http_url))
+        if s3_url:
+            # The "s3::" prefix forces go-getter's S3 getter
+            command = "%s s3::%s /data/output" % (IMPORTER_COMMAND, s3_url)
+
+        if http_url and http_auth is not None:
+            # Embed the HTTP_AUTH secret as basic-auth credentials in the fetched URL
+            url = urlparse(http_url)
+            command_string = "%s %s://${HTTP_AUTH}@%s%s /data/output" % (
+                IMPORTER_COMMAND, url.scheme, url.netloc, url.path)
+            command = base64_encode(command_string)
+
+        return command
+
+    def get_env_vars(self, s3_url, http_url, secrets):
+        if s3_url is not None:
+            if secrets[S3_ACCESS_KEY] is None or secrets[S3_SECRET_KEY] is None:
+                self.logger.log('Error: s3AccessKey and s3SecretKey are required with an S3 URL')
+                return
+
+            access_key_secret = self.create_secret(S3_ACCESS_KEY, secrets[S3_ACCESS_KEY])
+            secret_key_secret = self.create_secret(S3_SECRET_KEY, secrets[S3_SECRET_KEY])
+
+            # Reference the ephemeral secrets instead of passing raw credentials to the job
+            access_key_value = "secret:ephemeral:%s" % access_key_secret[S3_ACCESS_KEY]
+            secret_key_value = "secret:ephemeral:%s" % secret_key_secret[S3_SECRET_KEY]
+
+            return {
+                S3_ACCESS_KEY: access_key_value,
+                S3_SECRET_KEY: secret_key_value,
+            }
+
+        if http_url and secrets[HTTP_SECRET] is not None:
+            http_auth_secret = self.create_secret(HTTP_SECRET, secrets[HTTP_SECRET])
+            return {
+                HTTP_SECRET: "secret:ephemeral:%s" % http_auth_secret[HTTP_SECRET],
+            }
+
+        return {}
+
+    def _create(self, workflow):
+        client = api_sdk.clients.JobsClient(
+            api_key=self.api_key,
+            ps_client_name=CLI_PS_CLIENT_NAME,
+        )
+        return client.create(**workflow)
+
+    def execute(self, cluster_id, machine_type, dataset_id, s3_url, http_url, http_auth, access_key, secret_key):
+        if s3_url is None and http_url is None:
+            self.logger.log('Error: --s3Url or --httpUrl required')
+            return
+
+        workflow = {
+            "cluster_id": cluster_id,
+            "container": DATASET_IMPORTER_IMAGE,
+            "machine_type": machine_type,
+            "project": PROJECT_NAME,
+            "datasets": [{"id": dataset_id, "name": "output", "output": True}],
+            "project_id": None,
+        }
+
+        dataset_url = s3_url or http_url
+
+        url = urlparse(dataset_url)
+        if url.scheme not in SUPPORTED_URL:
+            self.logger.log('Error: unsupported URL scheme "{}" in {}; supported schemes: {}'
+                            .format(url.scheme, dataset_url, ', '.join(SUPPORTED_URL)))
+            return
+
+        command = self.get_command(s3_url, http_url, http_auth)
+        if command:
+            workflow["command"] = command
+
+        env_vars = self.get_env_vars(
+            s3_url, http_url,
+            {HTTP_SECRET: http_auth, S3_ACCESS_KEY: access_key, S3_SECRET_KEY: secret_key})
+        if s3_url and not env_vars:
+            # get_env_vars already logged the missing-credentials error; do not submit the job
+            return
+        if env_vars:
+            workflow["env_vars"] = env_vars
+
+        create_job_command = CreateJobCommand(api_key=self.api_key, workspace_handler=get_workspace_handler())
+        create_job_command.execute(workflow)
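+    # Typical programmatic use (argument values below are hypothetical, and the
+    # constructor arguments depend on what BaseJobCommand expects):
+    #   ImportDatasetCommand(api_key=api_key, workspace_handler=get_workspace_handler()).execute(
+    #       cluster_id="cluster-id", machine_type="P4000", dataset_id="dsabc123",
+    #       s3_url="https://my-bucket.s3.amazonaws.com/dataset.zip", http_url=None,
+    #       http_auth=None, access_key=aws_access_key, secret_key=aws_secret_key)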