Skip to content

Commit 7f60dc7

Browse files
authored
Merge pull request #293 from Paperspace/PS-13966-Fix_streaming_metrics_presentation_layer
Fix output of 'metrics stream'
2 parents 4d7e2ee + 00b6bcb commit 7f60dc7

File tree

8 files changed

+230
-136
lines changed

8 files changed

+230
-136
lines changed

gradient/api_sdk/repositories/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
import collections
33
import datetime
44
import json
5-
import time
65

76
import dateutil
87
import six
98
import websocket
109

11-
from .. import serializers
10+
from .. import serializers, sdk_exceptions
1211
from ..clients import http_client
1312
from ..config import config
1413
from ..sdk_exceptions import ResourceFetchingError, ResourceCreatingDataError, ResourceCreatingError, GradientSdkError
@@ -408,6 +407,8 @@ def stream(self, **kwargs):
408407
yield data
409408
except websocket.WebSocketConnectionClosedException as e:
410409
self.logger.debug("WebSocketConnectionClosedException: {}".format(e))
410+
except sdk_exceptions.EndWebsocketStream:
411+
return
411412

412413
def _get_connection(self, kwargs):
413414
url = self._get_full_url(kwargs)

gradient/api_sdk/sdk_exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,7 @@ class PresignedUrlConnectionError(ArchiveUploadError):
6060

6161
class InvalidParametersError(GradientSdkError):
6262
pass
63+
64+
65+
class EndWebsocketStream(Exception):
66+
pass

gradient/cliutils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import curses
12
import json
23
import shutil
34

@@ -48,3 +49,36 @@ def validate_auth_options(kwargs):
4849

4950
def none_strings_to_none_objects(lst):
5051
return [elem if elem != "none" else None for elem in lst]
52+
53+
54+
class TerminalPrinter(object):
55+
def __init__(self):
56+
self.screen_writer = None
57+
58+
def init(self):
59+
self.screen_writer = curses.initscr()
60+
return self
61+
62+
def cleanup(self):
63+
"""Cleanup before program ends or terminal may stop working correlctly"""
64+
curses.nocbreak()
65+
self.screen_writer.keypad(False)
66+
curses.echo()
67+
curses.endwin()
68+
69+
def rewrite_screen(self, s):
70+
self.clear()
71+
self.add_line(s)
72+
self.refresh()
73+
74+
def add_line(self, s):
75+
self.add_str(s + "\n")
76+
77+
def add_str(self, s):
78+
self.screen_writer.addstr(s)
79+
80+
def refresh(self):
81+
self.screen_writer.refresh()
82+
83+
def clear(self):
84+
self.screen_writer.clear()

gradient/commands/common.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import json
44
import pydoc
55

6-
import click
76
import six
87
import terminaltables
98
from halo import halo
109

1110
from gradient.clilogger import CliLogger
12-
from gradient.cliutils import get_terminal_lines
11+
from gradient.cliutils import get_terminal_lines, TerminalPrinter
1312
from gradient.exceptions import ApplicationError
1413

1514

@@ -131,6 +130,18 @@ def __init__(self, *args, **kwargs):
131130
super(StreamMetricsCommand, self).__init__(*args, **kwargs)
132131
# {"metricName": {"pod_id": "value"}}
133132
self._recent_values = collections.OrderedDict()
133+
self.terminal_printer = None
134+
135+
def execute(self, *args, **kwargs):
136+
self.terminal_printer = TerminalPrinter()
137+
self.terminal_printer.init()
138+
try:
139+
rv = super(StreamMetricsCommand, self).execute(*args, **kwargs)
140+
finally:
141+
self.terminal_printer.clear()
142+
self.terminal_printer.cleanup()
143+
144+
return rv
134145

135146
def _get_instances(self, kwargs):
136147
metrics_stream = self.client.stream_metrics(**kwargs)
@@ -154,8 +165,7 @@ def _update_recent_values(self, metric_data):
154165
self._recent_values[metric_name][pod_name] = data["value"]
155166

156167
def _print_table_to_terminal(self, table_str):
157-
click.clear()
158-
super(StreamMetricsCommand, self)._print_table_to_terminal(table_str)
168+
self.terminal_printer.rewrite_screen(table_str)
159169

160170
def _get_table_data(self, objects):
161171
metrics = list(self._recent_values.keys())

tests/functional/test_deployments.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def generator(self):
2929
"pod_metrics": {"desgffa3mtgepvm-0": {"time_stamp": 1587673820, "value": "34914304"},
3030
"desgffa3mtgepvm-1": {"time_stamp": 1587673820, "value": "35942400"}}}"""
3131

32-
raise sdk_exceptions.GradientSdkError()
32+
raise sdk_exceptions.EndWebsocketStream()
3333

3434
return generator
3535

@@ -53,7 +53,7 @@ def generator(self):
5353
"pod_metrics": {"desgffa3mtgepvm-0": {"time_stamp": 1587640740, "value": "1234"},
5454
"desgffa3mtgepvm-1": {"time_stamp": 1587640740, "value": "234"}}}"""
5555

56-
raise sdk_exceptions.GradientSdkError()
56+
raise sdk_exceptions.EndWebsocketStream()
5757

5858
return generator
5959

@@ -1525,7 +1525,7 @@ def test_should_print_valid_error_message_when_error_code_was_returned_without_e
15251525
assert result.exit_code == 0, result.exc_info
15261526

15271527

1528-
class TestExperimentsMetricsStreamCommand(object):
1528+
class TestDeploymentssMetricsStreamCommand(object):
15291529
LIST_DEPLOYMENTS_URL = "https://api.paperspace.io/deployments/getDeploymentList/"
15301530
GET_METRICS_URL = "https://aws-testing.paperspace.io/metrics/api/v1/stream"
15311531
BASIC_OPTIONS_COMMAND = [
@@ -1559,52 +1559,47 @@ class TestExperimentsMetricsStreamCommand(object):
15591559
+-------------------+---------------+-------------+
15601560
| desgffa3mtgepvm-0 | | 34914304 |
15611561
| desgffa3mtgepvm-1 | | 35942400 |
1562-
+-------------------+---------------+-------------+
1563-
"""
1562+
+-------------------+---------------+-------------+"""
15641563
EXPECTED_TABLE_2 = """+-------------------+----------------------+-------------+
15651564
| Pod | cpuPercentage | memoryUsage |
15661565
+-------------------+----------------------+-------------+
15671566
| desgffa3mtgepvm-0 | 0.044894188888835944 | 34914304 |
15681567
| desgffa3mtgepvm-1 | 0.048185748888916656 | 35942400 |
1569-
+-------------------+----------------------+-------------+
1570-
"""
1568+
+-------------------+----------------------+-------------+"""
15711569
EXPECTED_TABLE_3 = """+-------------------+----------------------+-------------+
15721570
| Pod | cpuPercentage | memoryUsage |
15731571
+-------------------+----------------------+-------------+
15741572
| desgffa3mtgepvm-0 | 0.044894188888835944 | 34914304 |
15751573
| desgffa3mtgepvm-1 | 0.048185748888916656 | 35942400 |
1576-
+-------------------+----------------------+-------------+
1577-
"""
1574+
+-------------------+----------------------+-------------+"""
15781575

15791576
ALL_OPTIONS_EXPECTED_TABLE_1 = """+-------------------+---------------+---------------+
15801577
| Pod | gpuMemoryFree | gpuMemoryUsed |
15811578
+-------------------+---------------+---------------+
15821579
| desgffa3mtgepvm-0 | | 0 |
15831580
| desgffa3mtgepvm-1 | | 0 |
1584-
+-------------------+---------------+---------------+
1585-
"""
1581+
+-------------------+---------------+---------------+"""
15861582
ALL_OPTIONS_EXPECTED_TABLE_2 = """+-------------------+---------------+---------------+
15871583
| Pod | gpuMemoryFree | gpuMemoryUsed |
15881584
+-------------------+---------------+---------------+
15891585
| desgffa3mtgepvm-0 | | 321 |
15901586
| desgffa3mtgepvm-1 | | 432 |
1591-
+-------------------+---------------+---------------+
1592-
"""
1587+
+-------------------+---------------+---------------+"""
15931588
ALL_OPTIONS_EXPECTED_TABLE_3 = """+-------------------+---------------+---------------+
15941589
| Pod | gpuMemoryFree | gpuMemoryUsed |
15951590
+-------------------+---------------+---------------+
15961591
| desgffa3mtgepvm-0 | 1234 | 321 |
15971592
| desgffa3mtgepvm-1 | 234 | 432 |
1598-
+-------------------+---------------+---------------+
1599-
"""
1593+
+-------------------+---------------+---------------+"""
16001594

16011595
EXPECTED_STDOUT_WHEN_INVALID_API_KEY_WAS_USED = "Failed to fetch data: Incorrect API Key provided\nForbidden\n"
16021596
EXPECTED_STDOUT_WHEN_DEPLOYMENT_WAS_NOT_FOUND = "Deployment not found\n"
16031597

1598+
@mock.patch("gradient.commands.common.TerminalPrinter")
16041599
@mock.patch("gradient.api_sdk.repositories.common.websocket.create_connection")
16051600
@mock.patch("gradient.api_sdk.clients.http_client.requests.get")
16061601
def test_should_read_all_available_metrics_when_metrics_get_command_was_used_with_basic_options(
1607-
self, get_patched, create_ws_connection_patched,
1602+
self, get_patched, create_ws_connection_patched, terminal_printer_cls_patched,
16081603
basic_options_metrics_stream_websocket_connection_iterator):
16091604
get_patched.return_value = MockResponse(self.GET_LIST_OF_DEPLOYMENTS_RESPONSE_JSON)
16101605

@@ -1615,9 +1610,13 @@ def test_should_read_all_available_metrics_when_metrics_get_command_was_used_wit
16151610
runner = CliRunner()
16161611
result = runner.invoke(cli.cli, self.BASIC_OPTIONS_COMMAND)
16171612

1618-
assert self.EXPECTED_TABLE_1 in result.output, result.exc_info
1619-
assert self.EXPECTED_TABLE_2 in result.output, result.exc_info
1620-
assert self.EXPECTED_TABLE_3 in result.output, result.exc_info
1613+
terminal_printer_cls_patched().init.assert_called_once()
1614+
terminal_printer_cls_patched().rewrite_screen.assert_has_calls([
1615+
mock.call(self.EXPECTED_TABLE_1),
1616+
mock.call(self.EXPECTED_TABLE_2),
1617+
mock.call(self.EXPECTED_TABLE_3),
1618+
])
1619+
terminal_printer_cls_patched().cleanup.assert_called_once()
16211620

16221621
get_patched.assert_called_once_with(
16231622
self.LIST_DEPLOYMENTS_URL,
@@ -1628,10 +1627,11 @@ def test_should_read_all_available_metrics_when_metrics_get_command_was_used_wit
16281627
ws_connection_instance_mock.send.assert_called_once_with(self.BASIC_COMMAND_CHART_DESCRIPTOR)
16291628
assert result.exit_code == 0, result.exc_info
16301629

1630+
@mock.patch("gradient.commands.common.TerminalPrinter")
16311631
@mock.patch("gradient.api_sdk.repositories.common.websocket.create_connection")
16321632
@mock.patch("gradient.api_sdk.clients.http_client.requests.get")
16331633
def test_should_read_metrics_when_metrics_get_command_was_used_with_all_options(
1634-
self, get_patched, create_ws_connection_patched,
1634+
self, get_patched, create_ws_connection_patched, terminal_printer_cls_patched,
16351635
all_options_metrics_stream_websocket_connection_iterator):
16361636
get_patched.return_value = MockResponse(self.GET_LIST_OF_DEPLOYMENTS_RESPONSE_JSON)
16371637

@@ -1642,9 +1642,13 @@ def test_should_read_metrics_when_metrics_get_command_was_used_with_all_options(
16421642
runner = CliRunner()
16431643
result = runner.invoke(cli.cli, self.ALL_OPTIONS_COMMAND)
16441644

1645-
assert self.ALL_OPTIONS_EXPECTED_TABLE_1 in result.output, result.exc_info
1646-
assert self.ALL_OPTIONS_EXPECTED_TABLE_2 in result.output, result.exc_info
1647-
assert self.ALL_OPTIONS_EXPECTED_TABLE_3 in result.output, result.exc_info
1645+
terminal_printer_cls_patched().init.assert_called_once()
1646+
terminal_printer_cls_patched().rewrite_screen.assert_has_calls([
1647+
mock.call(self.ALL_OPTIONS_EXPECTED_TABLE_1),
1648+
mock.call(self.ALL_OPTIONS_EXPECTED_TABLE_2),
1649+
mock.call(self.ALL_OPTIONS_EXPECTED_TABLE_3),
1650+
])
1651+
terminal_printer_cls_patched().cleanup.assert_called_once()
16481652

16491653
get_patched.assert_called_once_with(
16501654
self.LIST_DEPLOYMENTS_URL,
@@ -1656,10 +1660,11 @@ def test_should_read_metrics_when_metrics_get_command_was_used_with_all_options(
16561660
ws_connection_instance_mock.send.assert_called_once_with(self.ALL_COMMANDS_CHART_DESCRIPTOR)
16571661
assert result.exit_code == 0, result.exc_info
16581662

1663+
@mock.patch("gradient.commands.common.TerminalPrinter")
16591664
@mock.patch("gradient.api_sdk.repositories.common.websocket.create_connection")
16601665
@mock.patch("gradient.api_sdk.clients.http_client.requests.get")
16611666
def test_should_read_metrics_when_metrics_get_was_executed_and_options_file_was_used(
1662-
self, get_patched, create_ws_connection_patched,
1667+
self, get_patched, create_ws_connection_patched, terminal_printer_cls_patched,
16631668
all_options_metrics_stream_websocket_connection_iterator,
16641669
deployments_metrics_stream_config_path):
16651670
get_patched.return_value = MockResponse(self.GET_LIST_OF_DEPLOYMENTS_RESPONSE_JSON)
@@ -1671,9 +1676,13 @@ def test_should_read_metrics_when_metrics_get_was_executed_and_options_file_was_
16711676
runner = CliRunner()
16721677
result = runner.invoke(cli.cli, command)
16731678

1674-
assert self.ALL_OPTIONS_EXPECTED_TABLE_1 in result.output, result.exc_info
1675-
assert self.ALL_OPTIONS_EXPECTED_TABLE_2 in result.output, result.exc_info
1676-
assert self.ALL_OPTIONS_EXPECTED_TABLE_3 in result.output, result.exc_info
1679+
terminal_printer_cls_patched().init.assert_called_once()
1680+
terminal_printer_cls_patched().rewrite_screen.assert_has_calls([
1681+
mock.call(self.ALL_OPTIONS_EXPECTED_TABLE_1),
1682+
mock.call(self.ALL_OPTIONS_EXPECTED_TABLE_2),
1683+
mock.call(self.ALL_OPTIONS_EXPECTED_TABLE_3),
1684+
])
1685+
terminal_printer_cls_patched().cleanup.assert_called_once()
16771686

16781687
get_patched.assert_called_once_with(
16791688
self.LIST_DEPLOYMENTS_URL,
@@ -1685,10 +1694,11 @@ def test_should_read_metrics_when_metrics_get_was_executed_and_options_file_was_
16851694
ws_connection_instance_mock.send.assert_called_once_with(self.ALL_COMMANDS_CHART_DESCRIPTOR)
16861695
assert result.exit_code == 0, result.exc_info
16871696

1697+
@mock.patch("gradient.commands.common.TerminalPrinter")
16881698
@mock.patch("gradient.api_sdk.repositories.common.websocket.create_connection")
16891699
@mock.patch("gradient.api_sdk.clients.http_client.requests.get")
16901700
def test_should_print_valid_error_message_when_invalid_api_key_was_used(
1691-
self, get_patched, create_ws_connection_patched):
1701+
self, get_patched, create_ws_connection_patched, terminal_printer_cls_patched):
16921702
get_patched.return_value = MockResponse({"status": 400, "message": "Invalid API token"}, 400)
16931703

16941704
runner = CliRunner()
@@ -1706,10 +1716,11 @@ def test_should_print_valid_error_message_when_invalid_api_key_was_used(
17061716
create_ws_connection_patched.assert_not_called()
17071717
assert result.exit_code == 0, result.exc_info
17081718

1719+
@mock.patch("gradient.commands.common.TerminalPrinter")
17091720
@mock.patch("gradient.api_sdk.repositories.common.websocket.create_connection")
17101721
@mock.patch("gradient.api_sdk.clients.http_client.requests.get")
17111722
def test_should_print_valid_error_message_when_deployment_was_not_found(
1712-
self, get_patched, create_ws_connection_patched):
1723+
self, get_patched, create_ws_connection_patched, terminal_printer_cls_patched):
17131724
get_patched.return_value = MockResponse({"deploymentList": []})
17141725

17151726
runner = CliRunner()

0 commit comments

Comments
 (0)