|
2 | 2 |
|
3 | 3 | import os |
4 | 4 | import threading |
| 5 | +import json |
5 | 6 |
|
6 | 7 | import tensorflow as tf |
7 | 8 | import tensorflow_hub as hub |
8 | 9 |
|
9 | 10 | from http.server import BaseHTTPRequestHandler, HTTPServer |
10 | 11 | from test.support.os_helper import EnvironmentVarGuard |
| 12 | +from contextlib import contextmanager |
| 13 | +from kagglehub.exceptions import BackendError |
11 | 14 |
|
| 15 | +MOUNT_PATH = "/kaggle/input" |
12 | 16 |
|
13 | | -class TestKaggleModuleResolver(unittest.TestCase): |
14 | | - class HubHTTPHandler(BaseHTTPRequestHandler): |
15 | | - def do_GET(self): |
16 | | - self.send_response(200) |
17 | | - self.send_header('Content-Type', 'application/gzip') |
18 | | - self.end_headers() |
| 17 | +@contextmanager |
| 18 | +def create_test_server(handler_class): |
| 19 | + hostname = 'localhost' |
| 20 | + port = 8080 |
| 21 | + addr = f"http://{hostname}:{port}" |
19 | 22 |
|
20 | | - with open('/input/tests/data/model.tar.gz', 'rb') as model_archive: |
21 | | - self.wfile.write(model_archive.read()) |
| 23 | + # Simulates we are inside a Kaggle environment. |
| 24 | + env = EnvironmentVarGuard() |
| 25 | + env.set('KAGGLE_KERNEL_RUN_TYPE', 'Interactive') |
| 26 | + env.set('KAGGLE_USER_SECRETS_TOKEN', 'foo jwt token') |
| 27 | + env.set('KAGGLE_DATA_PROXY_TOKEN', 'foo proxy token') |
| 28 | + env.set('KAGGLE_DATA_PROXY_URL', addr) |
22 | 29 |
|
23 | | - def _test_client(self, client_func, handler): |
24 | | - with HTTPServer(('localhost', 8080), handler) as test_server: |
| 30 | + with env: |
| 31 | + with HTTPServer((hostname, port), handler_class) as test_server: |
25 | 32 | threading.Thread(target=test_server.serve_forever).start() |
26 | 33 |
|
27 | 34 | try: |
28 | | - client_func() |
| 35 | + yield addr |
29 | 36 | finally: |
30 | 37 | test_server.shutdown() |
31 | 38 |
|
32 | | - def test_kaggle_resolver_succeeds(self): |
33 | | - # Simulates we are inside a Kaggle environment. |
34 | | - env = EnvironmentVarGuard() |
35 | | - env.set('KAGGLE_CONTAINER_NAME', 'foo') |
36 | | - # Attach model to right directory. |
37 | | - os.makedirs('/kaggle/input/foomodule/tensorflow2/barvar') |
38 | | - os.symlink('/input/tests/data/saved_model/', '/kaggle/input/foomodule/tensorflow2/barvar/2', target_is_directory=True) |
| 39 | +class HubHTTPHandler(BaseHTTPRequestHandler): |
| 40 | + def do_GET(self): |
| 41 | + self.send_response(200) |
| 42 | + self.send_header('Content-Type', 'application/gzip') |
| 43 | + self.end_headers() |
| 44 | + |
| 45 | + with open('/input/tests/data/model.tar.gz', 'rb') as model_archive: |
| 46 | + self.wfile.write(model_archive.read()) |
39 | 47 |
|
40 | | - with env: |
| 48 | +class KaggleJwtHandler(BaseHTTPRequestHandler): |
| 49 | + def do_POST(self): |
| 50 | + if self.path.endswith("AttachDatasourceUsingJwtRequest"): |
| 51 | + content_length = int(self.headers["Content-Length"]) |
| 52 | + request = json.loads(self.rfile.read(content_length)) |
| 53 | + model_ref = request["modelRef"] |
| 54 | + |
| 55 | + self.send_response(200) |
| 56 | + self.send_header("Content-type", "application/json") |
| 57 | + self.end_headers() |
| 58 | + |
| 59 | + if model_ref['ModelSlug'] == 'unknown': |
| 60 | + self.wfile.write(bytes(json.dumps({ |
| 61 | + "wasSuccessful": False, |
| 62 | + }), "utf-8")) |
| 63 | + return |
| 64 | + |
| 65 | + # Load the files |
| 66 | + mount_slug = f"{model_ref['ModelSlug']}/{model_ref['Framework']}/{model_ref['InstanceSlug']}/{model_ref['VersionNumber']}" |
| 67 | + os.makedirs(os.path.dirname(os.path.join(MOUNT_PATH, mount_slug))) |
| 68 | + os.symlink('/input/tests/data/saved_model/', os.path.join(MOUNT_PATH, mount_slug), target_is_directory=True) |
| 69 | + |
| 70 | + # Return the response |
| 71 | + self.wfile.write(bytes(json.dumps({ |
| 72 | + "wasSuccessful": True, |
| 73 | + "result": { |
| 74 | + "mountSlug": mount_slug, |
| 75 | + }, |
| 76 | + }), "utf-8")) |
| 77 | + else: |
| 78 | + self.send_response(404) |
| 79 | + self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8")) |
| 80 | + |
| 81 | +class TestKaggleModuleResolver(unittest.TestCase): |
| 82 | + def test_kaggle_resolver_succeeds(self): |
| 83 | + with create_test_server(KaggleJwtHandler) as addr: |
41 | 84 | test_inputs = tf.ones([1,4]) |
42 | 85 | layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2") |
43 | 86 | self.assertEqual([1, 1], layer(test_inputs).shape) |
44 | 87 |
|
45 | 88 | def test_kaggle_resolver_not_attached_throws(self): |
46 | | - # Simulates we are inside a Kaggle environment. |
47 | | - env = EnvironmentVarGuard() |
48 | | - env.set('KAGGLE_CONTAINER_NAME', 'foo') |
49 | | - with env: |
50 | | - with self.assertRaisesRegex(RuntimeError, '.*attach.*'): |
51 | | - hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2") |
| 89 | + with create_test_server(KaggleJwtHandler) as addr: |
| 90 | + with self.assertRaises(BackendError): |
| 91 | + hub.KerasLayer("https://kaggle.com/models/foo/unknown/frameworks/TensorFlow2/variations/barvar/versions/2") |
52 | 92 |
|
53 | 93 | def test_http_resolver_succeeds(self): |
54 | | - def call_hub(): |
| 94 | + with create_test_server(HubHTTPHandler) as addr: |
55 | 95 | test_inputs = tf.ones([1,4]) |
56 | | - layer = hub.KerasLayer('http://localhost:8080/model.tar.gz') |
| 96 | + layer = hub.KerasLayer(f'{addr}/model.tar.gz') |
57 | 97 | self.assertEqual([1, 1], layer(test_inputs).shape) |
58 | 98 |
|
59 | | - self._test_client(call_hub, TestKaggleModuleResolver.HubHTTPHandler) |
60 | | - |
61 | 99 | def test_local_path_resolver_succeeds(self): |
62 | 100 | test_inputs = tf.ones([1,4]) |
63 | 101 | layer = hub.KerasLayer('/input/tests/data/saved_model') |
|
0 commit comments