This repository has been archived on 2024-05-31. You can view files and clone it, but cannot push or open issues or pull requests.
devicehub-teal/ereuse_devicehub/ereuse_utils/cli.py

302 lines
8.1 KiB
Python
Raw Normal View History

2023-03-21 16:31:43 +00:00
import enum as _enum
import getpass
import itertools
import os
import pathlib
import threading
from contextlib import contextmanager
from time import sleep
from typing import Any, Iterable, Type
from boltons import urlutils
from click import types as click_types
from colorama import Fore
from tqdm import tqdm
from ereuse_devicehub.ereuse_utils import if_none_return_none
COMMON_CONTEXT_S = {'help_option_names': ('-h', '--help')}
"""Common Context settings used for our implementations of the
Click cli.
"""
# Py2/3 compat. Empty conditional to avoid coverage
try:
_unicode = unicode
except NameError:
_unicode = str
class Enum(click_types.Choice):
"""
Enum support for click.
Use it as a collection: @click.option(..., type=cli.Enum(MyEnum)).
Then, this expects you to pass the *name* of a member of the enum.
From `this github issue <https://github.com/pallets/click/issues/
605#issuecomment-277539425>`_.
"""
def __init__(self, enum: Type[_enum.Enum]):
self.__enum = enum
super().__init__(enum.__members__)
def convert(self, value, param, ctx):
return self.__enum[super().convert(value, param, ctx)]
class Path(click_types.Path):
"""Like click.Path but returning ``pathlib.Path`` objects."""
def convert(self, value, param, ctx):
return pathlib.Path(super().convert(value, param, ctx))
class URL(click_types.StringParamType):
"""Returns a bolton's URL."""
name = 'url'
def __init__(
self,
scheme=None,
username=None,
password=None,
host=None,
port=None,
path=None,
query_params=None,
fragment=None,
) -> None:
super().__init__()
"""Creates the type URL. You can require or enforce parts
of the URL by setting parameters of this constructor.
If the param is...
- None, no check is performed (default).
- True, it is then required as part of the URL.
- False, it is then required NOT to be part of the URL.
- Any other value, then such value is required to be in
the URL.
"""
self.attrs = (
('scheme', scheme),
('username', username),
('password', password),
('host', host),
('port', port),
('path', path),
('query_params', query_params),
('fragment', fragment),
)
@if_none_return_none
def convert(self, value, param, ctx):
url = urlutils.URL(super().convert(value, param, ctx))
for name, attr in self.attrs:
if attr is True:
if not getattr(url, name):
self.fail(
'URL {} must contain {} but it does not.'.format(url, name)
)
elif attr is False:
if getattr(url, name):
self.fail('URL {} cannot contain {} but it does.'.format(url, name))
elif attr:
if getattr(url, name) != attr:
self.fail('{} form {} can only be {}'.format(name, url, attr))
return url
def password(service: str, username: str, prompt: str = 'Password:') -> str:
"""Gets a password from the keyring or the terminal."""
import keyring
return keyring.get_password(service, username) or getpass.getpass(prompt)
class Line(tqdm):
spinner_cycle = itertools.cycle(['-', '/', '|', '\\'])
def __init__(
self,
total=None,
desc=None,
leave=True,
file=None,
ncols=None,
mininterval=0.2,
maxinterval=10.0,
miniters=None,
ascii=None,
disable=False,
unit='it',
unit_scale=False,
dynamic_ncols=True,
smoothing=0.3,
bar_format=None,
initial=0,
position=None,
postfix=None,
unit_divisor=1000,
write_bytes=None,
gui=False,
close_message: Iterable = None,
error_message: Iterable = None,
**kwargs,
):
"""This cannot work with iterables. Iterable use is considered
backward-compatibility in tqdm and inconsistent in Line.
Manually call ``update``.
"""
self._close_message = close_message
self._error_message = error_message
if total:
bar_format = '{desc}{percentage:.1f}% |{bar}| {n:1g}/{total:1g} {elapsed}<{remaining}'
super().__init__(
None,
desc,
total,
leave,
file,
ncols,
mininterval,
maxinterval,
miniters,
ascii,
disable,
unit,
unit_scale,
dynamic_ncols,
smoothing,
bar_format,
initial,
position,
postfix,
unit_divisor,
write_bytes,
gui,
**kwargs,
)
def write_at_line(self, *args):
self.clear()
with self._lock:
self.display(''.join(str(arg) for arg in args))
def close_message(self, *args):
self._close_message = args
def error_message(self, *args):
self._error_message = args
2023-03-27 08:55:26 +00:00
def close(self): # noqa: C901
2023-03-21 16:31:43 +00:00
"""
Cleanup and (if leave=False) close the progressbar.
"""
if self.disable:
return
# Prevent multiple closures
self.disable = True
# decrement instance pos and remove from internal set
pos = abs(self.pos)
self._decr_instances(self)
# GUI mode
if not hasattr(self, "sp"):
return
# annoyingly, _supports_unicode isn't good enough
def fp_write(s):
self.fp.write(_unicode(s))
try:
fp_write('')
except ValueError as e:
if 'closed' in str(e):
return
raise # pragma: no cover
with self._lock:
if self.leave:
if self._close_message:
self.display(
''.join(str(arg) for arg in self._close_message), pos=pos
)
elif self.last_print_n < self.n:
# stats for overall rate (no weighted average)
self.avg_time = None
self.display(pos=pos)
if not max(
[abs(getattr(i, "pos", 0)) for i in self._instances] + [pos]
):
# only if not nested (#477)
fp_write('\n')
else:
if self._close_message:
self.display(
''.join(str(arg) for arg in self._close_message), pos=pos
)
else:
self.display(msg='', pos=pos)
if not pos:
fp_write('\r')
@contextmanager
def spin(self, prefix: str):
self._stop_running = threading.Event()
spin_thread = threading.Thread(target=self._spin, args=[prefix])
spin_thread.start()
try:
yield
finally:
self._stop_running.set()
spin_thread.join()
def _spin(self, prefix: str):
while not self._stop_running.is_set():
self.write_at_line(prefix, next(self.spinner_cycle))
sleep(0.50)
@classmethod
@contextmanager
def reserve_lines(self, n):
try:
yield
finally:
self.move_down(n - 1)
@classmethod
def move_down(cls, n: int):
print('\n' * n)
def __exit__(self, *exc):
if exc[0]:
self._close_message = self._error_message
return super().__exit__(*exc)
def clear():
os.system('clear')
def title(text: Any, ljust=32) -> str:
# Note that is 38 px + 1 extra space = 39 min
return str(text).ljust(ljust) + ' '
def danger(text: Any) -> str:
return '{}{}{}'.format(Fore.RED, text, Fore.RESET)
def warning(text: Any) -> str:
return '{}{}{}'.format(Fore.YELLOW, text, Fore.RESET)
def done(text: Any = 'done.') -> str:
return '{}{}{}'.format(Fore.GREEN, text, Fore.RESET)