From bbb2466b07e14b177ea093ad2628c9a7adacbefa Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 17 Nov 2023 23:02:22 +0000
Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20format?=
=?UTF-8?q?=20from=20pre-commit.com=20hooks?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
docs/conf.py | 12 +-
papermill/__init__.py | 3 +-
papermill/__main__.py | 2 +-
papermill/abs.py | 44 +-
papermill/adl.py | 11 +-
papermill/cli.py | 184 ++++----
papermill/clientwrap.py | 32 +-
papermill/engines.py | 118 +++---
papermill/exceptions.py | 20 +-
papermill/execute.py | 71 ++--
papermill/inspection.py | 34 +-
papermill/iorw.py | 167 ++++----
papermill/log.py | 2 +-
papermill/models.py | 10 +-
papermill/parameterize.py | 32 +-
papermill/s3.py | 146 +++----
papermill/tests/__init__.py | 6 +-
papermill/tests/test_abs.py | 107 ++---
papermill/tests/test_adl.py | 61 ++-
papermill/tests/test_autosave.py | 36 +-
papermill/tests/test_cli.py | 470 +++++++++------------
papermill/tests/test_clientwrap.py | 31 +-
papermill/tests/test_engines.py | 341 +++++++--------
papermill/tests/test_exceptions.py | 20 +-
papermill/tests/test_execute.py | 352 ++++++----------
papermill/tests/test_gcs.py | 110 ++---
papermill/tests/test_hdfs.py | 20 +-
papermill/tests/test_inspect.py | 106 +++--
papermill/tests/test_iorw.py | 232 +++++------
papermill/tests/test_parameterize.py | 160 +++----
papermill/tests/test_s3.py | 90 ++--
papermill/tests/test_translators.py | 603 +++++++++++++--------------
papermill/tests/test_utils.py | 38 +-
papermill/translators.py | 267 ++++++------
papermill/utils.py | 21 +-
papermill/version.py | 2 +-
36 files changed, 1763 insertions(+), 2198 deletions(-)
diff --git a/docs/conf.py b/docs/conf.py
index 00ddfde6..50adcd0f 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -80,7 +80,7 @@
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'UPDATE.md']
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = "sphinx"
+pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
@@ -90,14 +90,14 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = "furo"
+html_theme = 'furo'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
- "sidebar_hide_name": True,
+ 'sidebar_hide_name': True,
}
# Add any paths that contain custom static files (such as style sheets) here,
@@ -105,7 +105,7 @@
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
-html_logo = "_static/images/papermill.png"
+html_logo = '_static/images/papermill.png'
# -- Options for HTMLHelp output ------------------------------------------
@@ -132,9 +132,7 @@
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
-latex_documents = [
- (master_doc, 'papermill.tex', 'papermill Documentation', 'nteract team', 'manual')
-]
+latex_documents = [(master_doc, 'papermill.tex', 'papermill Documentation', 'nteract team', 'manual')]
# -- Options for manual page output ---------------------------------------
diff --git a/papermill/__init__.py b/papermill/__init__.py
index af32a9d3..e3b98fb6 100644
--- a/papermill/__init__.py
+++ b/papermill/__init__.py
@@ -1,5 +1,4 @@
-from .version import version as __version__
-
from .exceptions import PapermillException, PapermillExecutionError
from .execute import execute_notebook
from .inspection import inspect_notebook
+from .version import version as __version__
diff --git a/papermill/__main__.py b/papermill/__main__.py
index 1f08dacb..c386c2ff 100644
--- a/papermill/__main__.py
+++ b/papermill/__main__.py
@@ -1,4 +1,4 @@
from papermill.cli import papermill
-if __name__ == "__main__":
+if __name__ == '__main__':
papermill()
diff --git a/papermill/abs.py b/papermill/abs.py
index 2c5d4a45..0378d45f 100644
--- a/papermill/abs.py
+++ b/papermill/abs.py
@@ -1,9 +1,9 @@
"""Utilities for working with Azure blob storage"""
-import re
import io
+import re
-from azure.storage.blob import BlobServiceClient
from azure.identity import EnvironmentCredential
+from azure.storage.blob import BlobServiceClient
class AzureBlobStore:
@@ -20,7 +20,7 @@ class AzureBlobStore:
def _blob_service_client(self, account_name, sas_token=None):
blob_service_client = BlobServiceClient(
- account_url=f"{account_name}.blob.core.windows.net",
+ account_url=f'{account_name}.blob.core.windows.net',
credential=sas_token or EnvironmentCredential(),
)
@@ -32,17 +32,15 @@ def _split_url(self, url):
see: https://docs.microsoft.com/en-us/azure/storage/common/storage-dotnet-shared-access-signature-part-1 # noqa: E501
abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken
"""
- match = re.match(
- r"abs://(.*)\.blob\.core\.windows\.net\/(.*?)\/([^\?]*)\??(.*)$", url
- )
+ match = re.match(r'abs://(.*)\.blob\.core\.windows\.net\/(.*?)\/([^\?]*)\??(.*)$', url)
if not match:
raise Exception(f"Invalid azure blob url '{url}'")
else:
params = {
- "account": match.group(1),
- "container": match.group(2),
- "blob": match.group(3),
- "sas_token": match.group(4),
+ 'account': match.group(1),
+ 'container': match.group(2),
+ 'blob': match.group(3),
+ 'sas_token': match.group(4),
}
return params
@@ -50,32 +48,22 @@ def read(self, url):
"""Read storage at a given url"""
params = self._split_url(url)
output_stream = io.BytesIO()
- blob_service_client = self._blob_service_client(
- params["account"], params["sas_token"]
- )
- blob_client = blob_service_client.get_blob_client(
- params["container"], params["blob"]
- )
+ blob_service_client = self._blob_service_client(params['account'], params['sas_token'])
+ blob_client = blob_service_client.get_blob_client(params['container'], params['blob'])
blob_client.download_blob().readinto(output_stream)
output_stream.seek(0)
- return [line.decode("utf-8") for line in output_stream]
+ return [line.decode('utf-8') for line in output_stream]
def listdir(self, url):
"""Returns a list of the files under the specified path"""
params = self._split_url(url)
- blob_service_client = self._blob_service_client(
- params["account"], params["sas_token"]
- )
- container_client = blob_service_client.get_container_client(params["container"])
- return list(container_client.list_blobs(params["blob"]))
+ blob_service_client = self._blob_service_client(params['account'], params['sas_token'])
+ container_client = blob_service_client.get_container_client(params['container'])
+ return list(container_client.list_blobs(params['blob']))
def write(self, buf, url):
"""Write buffer to storage at a given url"""
params = self._split_url(url)
- blob_service_client = self._blob_service_client(
- params["account"], params["sas_token"]
- )
- blob_client = blob_service_client.get_blob_client(
- params["container"], params["blob"]
- )
+ blob_service_client = self._blob_service_client(params['account'], params['sas_token'])
+ blob_client = blob_service_client.get_blob_client(params['container'], params['blob'])
blob_client.upload_blob(data=buf, overwrite=True)
diff --git a/papermill/adl.py b/papermill/adl.py
index de7b64cb..4ad0f62a 100644
--- a/papermill/adl.py
+++ b/papermill/adl.py
@@ -21,7 +21,7 @@ def __init__(self):
@classmethod
def _split_url(cls, url):
- match = re.match(r"adl://(.*)\.azuredatalakestore\.net\/(.*)$", url)
+ match = re.match(r'adl://(.*)\.azuredatalakestore\.net\/(.*)$', url)
if not match:
raise Exception(f"Invalid ADL url '{url}'")
else:
@@ -39,12 +39,7 @@ def listdir(self, url):
"""Returns a list of the files under the specified path"""
(store_name, path) = self._split_url(url)
adapter = self._create_adapter(store_name)
- return [
- "adl://{store_name}.azuredatalakestore.net/{path_to_child}".format(
- store_name=store_name, path_to_child=path_to_child
- )
- for path_to_child in adapter.ls(path)
- ]
+ return [f'adl://{store_name}.azuredatalakestore.net/{path_to_child}' for path_to_child in adapter.ls(path)]
def read(self, url):
"""Read storage at a given url"""
@@ -60,5 +55,5 @@ def write(self, buf, url):
"""Write buffer to storage at a given url"""
(store_name, path) = self._split_url(url)
adapter = self._create_adapter(store_name)
- with adapter.open(path, "wb") as f:
+ with adapter.open(path, 'wb') as f:
f.write(buf.encode())
diff --git a/papermill/cli.py b/papermill/cli.py
index 3b76b00e..e80867df 100755
--- a/papermill/cli.py
+++ b/papermill/cli.py
@@ -1,23 +1,21 @@
"""Main `papermill` interface."""
+import base64
+import logging
import os
+import platform
import sys
-from stat import S_ISFIFO
-import nbclient
import traceback
-
-import base64
-import logging
+from stat import S_ISFIFO
import click
-
+import nbclient
import yaml
-import platform
+from . import __version__ as papermill_version
from .execute import execute_notebook
-from .iorw import read_yaml_file, NoDatesSafeLoader
from .inspection import display_notebook_help
-from . import __version__ as papermill_version
+from .iorw import NoDatesSafeLoader, read_yaml_file
click.disable_unicode_literals_warning = True
@@ -28,155 +26,147 @@
def print_papermill_version(ctx, param, value):
if not value:
return
- print(
- "{version} from {path} ({pyver})".format(
- version=papermill_version, path=__file__, pyver=platform.python_version()
- )
- )
+ print(f'{papermill_version} from {__file__} ({platform.python_version()})')
ctx.exit()
-@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.command(context_settings=dict(help_option_names=['-h', '--help']))
@click.pass_context
-@click.argument("notebook_path", required=not INPUT_PIPED)
-@click.argument("output_path", default="")
+@click.argument('notebook_path', required=not INPUT_PIPED)
+@click.argument('output_path', default='')
@click.option(
- "--help-notebook",
+ '--help-notebook',
is_flag=True,
default=False,
- help="Display parameters information for the given notebook path.",
+ help='Display parameters information for the given notebook path.',
)
@click.option(
- "--parameters",
- "-p",
+ '--parameters',
+ '-p',
nargs=2,
multiple=True,
- help="Parameters to pass to the parameters cell.",
+ help='Parameters to pass to the parameters cell.',
)
@click.option(
- "--parameters_raw",
- "-r",
+ '--parameters_raw',
+ '-r',
nargs=2,
multiple=True,
- help="Parameters to be read as raw string.",
+ help='Parameters to be read as raw string.',
)
@click.option(
- "--parameters_file",
- "-f",
+ '--parameters_file',
+ '-f',
multiple=True,
- help="Path to YAML file containing parameters.",
+ help='Path to YAML file containing parameters.',
)
@click.option(
- "--parameters_yaml",
- "-y",
+ '--parameters_yaml',
+ '-y',
multiple=True,
- help="YAML string to be used as parameters.",
+ help='YAML string to be used as parameters.',
)
@click.option(
- "--parameters_base64",
- "-b",
+ '--parameters_base64',
+ '-b',
multiple=True,
- help="Base64 encoded YAML string as parameters.",
+ help='Base64 encoded YAML string as parameters.',
)
@click.option(
- "--inject-input-path",
+ '--inject-input-path',
is_flag=True,
default=False,
- help="Insert the path of the input notebook as PAPERMILL_INPUT_PATH as a notebook parameter.",
+ help='Insert the path of the input notebook as PAPERMILL_INPUT_PATH as a notebook parameter.',
)
@click.option(
- "--inject-output-path",
+ '--inject-output-path',
is_flag=True,
default=False,
- help="Insert the path of the output notebook as PAPERMILL_OUTPUT_PATH as a notebook parameter.",
+ help='Insert the path of the output notebook as PAPERMILL_OUTPUT_PATH as a notebook parameter.',
)
@click.option(
- "--inject-paths",
+ '--inject-paths',
is_flag=True,
default=False,
help=(
- "Insert the paths of input/output notebooks as PAPERMILL_INPUT_PATH/PAPERMILL_OUTPUT_PATH"
- " as notebook parameters."
+ 'Insert the paths of input/output notebooks as PAPERMILL_INPUT_PATH/PAPERMILL_OUTPUT_PATH'
+ ' as notebook parameters.'
),
)
+@click.option('--engine', help='The execution engine name to use in evaluating the notebook.')
@click.option(
- "--engine", help="The execution engine name to use in evaluating the notebook."
-)
-@click.option(
- "--request-save-on-cell-execute/--no-request-save-on-cell-execute",
+ '--request-save-on-cell-execute/--no-request-save-on-cell-execute',
default=True,
- help="Request save notebook after each cell execution",
+ help='Request save notebook after each cell execution',
)
@click.option(
- "--autosave-cell-every",
+ '--autosave-cell-every',
default=30,
type=int,
- help="How often in seconds to autosave the notebook during long cell executions (0 to disable)",
+ help='How often in seconds to autosave the notebook during long cell executions (0 to disable)',
)
@click.option(
- "--prepare-only/--prepare-execute",
+ '--prepare-only/--prepare-execute',
default=False,
- help="Flag for outputting the notebook without execution, but with parameters applied.",
+ help='Flag for outputting the notebook without execution, but with parameters applied.',
)
@click.option(
- "--kernel",
- "-k",
- help="Name of kernel to run. Ignores kernel name in the notebook document metadata.",
+ '--kernel',
+ '-k',
+ help='Name of kernel to run. Ignores kernel name in the notebook document metadata.',
)
@click.option(
- "--language",
- "-l",
- help="Language for notebook execution. Ignores language in the notebook document metadata.",
+ '--language',
+ '-l',
+ help='Language for notebook execution. Ignores language in the notebook document metadata.',
)
-@click.option("--cwd", default=None, help="Working directory to run notebook in.")
+@click.option('--cwd', default=None, help='Working directory to run notebook in.')
@click.option(
- "--progress-bar/--no-progress-bar",
+ '--progress-bar/--no-progress-bar',
default=None,
- help="Flag for turning on the progress bar.",
+ help='Flag for turning on the progress bar.',
)
@click.option(
- "--log-output/--no-log-output",
+ '--log-output/--no-log-output',
default=False,
- help="Flag for writing notebook output to the configured logger.",
+ help='Flag for writing notebook output to the configured logger.',
)
@click.option(
- "--stdout-file",
- type=click.File(mode="w", encoding="utf-8"),
- help="File to write notebook stdout output to.",
+ '--stdout-file',
+ type=click.File(mode='w', encoding='utf-8'),
+ help='File to write notebook stdout output to.',
)
@click.option(
- "--stderr-file",
- type=click.File(mode="w", encoding="utf-8"),
- help="File to write notebook stderr output to.",
+ '--stderr-file',
+ type=click.File(mode='w', encoding='utf-8'),
+ help='File to write notebook stderr output to.',
)
@click.option(
- "--log-level",
- type=click.Choice(["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
- default="INFO",
- help="Set log level",
+ '--log-level',
+ type=click.Choice(['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']),
+ default='INFO',
+ help='Set log level',
)
@click.option(
- "--start-timeout",
- "--start_timeout", # Backwards compatible naming
+ '--start-timeout',
+ '--start_timeout', # Backwards compatible naming
type=int,
default=60,
- help="Time in seconds to wait for kernel to start.",
+ help='Time in seconds to wait for kernel to start.',
)
@click.option(
- "--execution-timeout",
+ '--execution-timeout',
type=int,
- help="Time in seconds to wait for each cell before failing execution (default: forever)",
+ help='Time in seconds to wait for each cell before failing execution (default: forever)',
)
+@click.option('--report-mode/--no-report-mode', default=False, help='Flag for hiding input.')
@click.option(
- "--report-mode/--no-report-mode", default=False, help="Flag for hiding input."
-)
-@click.option(
- "--version",
+ '--version',
is_flag=True,
callback=print_papermill_version,
expose_value=False,
is_eager=True,
- help="Flag for displaying the version.",
+ help='Flag for displaying the version.',
)
def papermill(
click_ctx,
@@ -224,8 +214,8 @@ def papermill(
"""
# Jupyter deps use frozen modules, so we disable the python 3.11+ warning about debugger if running the CLI
- if "PYDEVD_DISABLE_FILE_VALIDATION" not in os.environ:
- os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"
+ if 'PYDEVD_DISABLE_FILE_VALIDATION' not in os.environ:
+ os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'
if not help_notebook:
required_output_path = not (INPUT_PIPED or OUTPUT_PIPED)
@@ -233,35 +223,33 @@ def papermill(
raise click.UsageError("Missing argument 'OUTPUT_PATH'")
if INPUT_PIPED and notebook_path and not output_path:
- input_path = "-"
+ input_path = '-'
output_path = notebook_path
else:
- input_path = notebook_path or "-"
- output_path = output_path or "-"
+ input_path = notebook_path or '-'
+ output_path = output_path or '-'
- if output_path == "-":
+ if output_path == '-':
# Save notebook to stdout just once
request_save_on_cell_execute = False
# Reduce default log level if we pipe to stdout
- if log_level == "INFO":
- log_level = "ERROR"
+ if log_level == 'INFO':
+ log_level = 'ERROR'
elif progress_bar is None:
progress_bar = not log_output
- logging.basicConfig(level=log_level, format="%(message)s")
+ logging.basicConfig(level=log_level, format='%(message)s')
# Read in Parameters
parameters_final = {}
if inject_input_path or inject_paths:
- parameters_final["PAPERMILL_INPUT_PATH"] = input_path
+ parameters_final['PAPERMILL_INPUT_PATH'] = input_path
if inject_output_path or inject_paths:
- parameters_final["PAPERMILL_OUTPUT_PATH"] = output_path
+ parameters_final['PAPERMILL_OUTPUT_PATH'] = output_path
for params in parameters_base64 or []:
- parameters_final.update(
- yaml.load(base64.b64decode(params), Loader=NoDatesSafeLoader) or {}
- )
+ parameters_final.update(yaml.load(base64.b64decode(params), Loader=NoDatesSafeLoader) or {})
for files in parameters_file or []:
parameters_final.update(read_yaml_file(files) or {})
for params in parameters_yaml or []:
@@ -301,11 +289,11 @@ def papermill(
def _resolve_type(value):
- if value == "True":
+ if value == 'True':
return True
- elif value == "False":
+ elif value == 'False':
return False
- elif value == "None":
+ elif value == 'None':
return None
elif _is_int(value):
return int(value)
diff --git a/papermill/clientwrap.py b/papermill/clientwrap.py
index b6718a2f..f4d4a8b2 100644
--- a/papermill/clientwrap.py
+++ b/papermill/clientwrap.py
@@ -1,5 +1,5 @@
-import sys
import asyncio
+import sys
from nbclient import NotebookClient
from nbclient.exceptions import CellExecutionError
@@ -27,9 +27,7 @@ def __init__(self, nb_man, km=None, raise_on_iopub_timeout=True, **kw):
Optional kernel manager. If none is provided, a kernel manager will
be created.
"""
- super().__init__(
- nb_man.nb, km=km, raise_on_iopub_timeout=raise_on_iopub_timeout, **kw
- )
+ super().__init__(nb_man.nb, km=km, raise_on_iopub_timeout=raise_on_iopub_timeout, **kw)
self.nb_man = nb_man
def execute(self, **kwargs):
@@ -39,18 +37,14 @@ def execute(self, **kwargs):
self.reset_execution_trackers()
# See https://bugs.python.org/issue37373 :(
- if (
- sys.version_info[0] == 3
- and sys.version_info[1] >= 8
- and sys.platform.startswith("win")
- ):
+ if sys.version_info[0] == 3 and sys.version_info[1] >= 8 and sys.platform.startswith('win'):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
with self.setup_kernel(**kwargs):
- self.log.info("Executing notebook with kernel: %s" % self.kernel_name)
+ self.log.info('Executing notebook with kernel: %s' % self.kernel_name)
self.papermill_execute_cells()
info_msg = self.wait_for_reply(self.kc.kernel_info())
- self.nb.metadata["language_info"] = info_msg["content"]["language_info"]
+ self.nb.metadata['language_info'] = info_msg['content']['language_info']
self.set_widgets_metadata()
return self.nb
@@ -77,9 +71,7 @@ def papermill_execute_cells(self):
self.nb_man.cell_start(cell, index)
self.execute_cell(cell, index)
except CellExecutionError as ex:
- self.nb_man.cell_exception(
- self.nb.cells[index], cell_index=index, exception=ex
- )
+ self.nb_man.cell_exception(self.nb.cells[index], cell_index=index, exception=ex)
break
finally:
self.nb_man.cell_complete(self.nb.cells[index], cell_index=index)
@@ -92,23 +84,23 @@ def log_output_message(self, output):
:param output: nbformat.notebooknode.NotebookNode
:return:
"""
- if output.output_type == "stream":
- content = "".join(output.text)
- if output.name == "stdout":
+ if output.output_type == 'stream':
+ content = ''.join(output.text)
+ if output.name == 'stdout':
if self.log_output:
self.log.info(content)
if self.stdout_file:
self.stdout_file.write(content)
self.stdout_file.flush()
- elif output.name == "stderr":
+ elif output.name == 'stderr':
if self.log_output:
# In case users want to redirect stderr differently, pipe to warning
self.log.warning(content)
if self.stderr_file:
self.stderr_file.write(content)
self.stderr_file.flush()
- elif self.log_output and ("data" in output and "text/plain" in output.data):
- self.log.info("".join(output.data["text/plain"]))
+ elif self.log_output and ('data' in output and 'text/plain' in output.data):
+ self.log.info(''.join(output.data['text/plain']))
def process_message(self, *arg, **kwargs):
output = super().process_message(*arg, **kwargs)
diff --git a/papermill/engines.py b/papermill/engines.py
index 5200ff7d..3e87f52b 100644
--- a/papermill/engines.py
+++ b/papermill/engines.py
@@ -1,16 +1,16 @@
"""Engines to perform different roles"""
-import sys
import datetime
-import dateutil
-
+import sys
from functools import wraps
+
+import dateutil
import entrypoints
-from .log import logger
-from .exceptions import PapermillException
from .clientwrap import PapermillNotebookClient
+from .exceptions import PapermillException
from .iorw import write_ipynb
-from .utils import merge_kwargs, remove_args, nb_kernel_name, nb_language
+from .log import logger
+from .utils import merge_kwargs, nb_kernel_name, nb_language, remove_args
class PapermillEngines:
@@ -33,7 +33,7 @@ def register_entry_points(self):
Load handlers provided by other packages
"""
- for entrypoint in entrypoints.get_group_all("papermill.engine"):
+ for entrypoint in entrypoints.get_group_all('papermill.engine'):
self.register(entrypoint.name, entrypoint.load())
def get_engine(self, name=None):
@@ -69,7 +69,7 @@ def catch_nb_assignment(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
- nb = kwargs.get("nb")
+ nb = kwargs.get('nb')
if nb:
# Reassign if executing notebook object was replaced
self.nb = nb
@@ -90,10 +90,10 @@ class NotebookExecutionManager:
shared manner.
"""
- PENDING = "pending"
- RUNNING = "running"
- COMPLETED = "completed"
- FAILED = "failed"
+ PENDING = 'pending'
+ RUNNING = 'running'
+ COMPLETED = 'completed'
+ FAILED = 'failed'
def __init__(
self,
@@ -110,15 +110,13 @@ def __init__(
self.end_time = None
self.autosave_cell_every = autosave_cell_every
self.max_autosave_pct = 25
- self.last_save_time = (
- self.now()
- ) # Not exactly true, but simplifies testing logic
+ self.last_save_time = self.now() # Not exactly true, but simplifies testing logic
self.pbar = None
if progress_bar:
# lazy import due to implict slow ipython import
from tqdm.auto import tqdm
- self.pbar = tqdm(total=len(self.nb.cells), unit="cell", desc="Executing")
+ self.pbar = tqdm(total=len(self.nb.cells), unit='cell', desc='Executing')
def now(self):
"""Helper to return current UTC time"""
@@ -169,7 +167,7 @@ def autosave_cell(self):
# Autosave is taking too long, so exponentially back off.
self.autosave_cell_every *= 2
logger.warning(
- "Autosave too slow: {:.2f} sec, over {}% limit. Backing off to {} sec".format(
+ 'Autosave too slow: {:.2f} sec, over {}% limit. Backing off to {} sec'.format(
save_elapsed, self.max_autosave_pct, self.autosave_cell_every
)
)
@@ -187,14 +185,14 @@ def notebook_start(self, **kwargs):
"""
self.set_timer()
- self.nb.metadata.papermill["start_time"] = self.start_time.isoformat()
- self.nb.metadata.papermill["end_time"] = None
- self.nb.metadata.papermill["duration"] = None
- self.nb.metadata.papermill["exception"] = None
+ self.nb.metadata.papermill['start_time'] = self.start_time.isoformat()
+ self.nb.metadata.papermill['end_time'] = None
+ self.nb.metadata.papermill['duration'] = None
+ self.nb.metadata.papermill['exception'] = None
for cell in self.nb.cells:
# Reset the cell execution counts.
- if cell.get("cell_type") == "code":
+ if cell.get('cell_type') == 'code':
cell.execution_count = None
# Clear out the papermill metadata for each cell.
@@ -205,7 +203,7 @@ def notebook_start(self, **kwargs):
duration=None,
status=self.PENDING, # pending, running, completed
)
- if cell.get("cell_type") == "code":
+ if cell.get('cell_type') == 'code':
cell.outputs = []
self.save()
@@ -219,17 +217,17 @@ def cell_start(self, cell, cell_index=None, **kwargs):
metadata for a cell and save the notebook to the output path.
"""
if self.log_output:
- ceel_num = cell_index + 1 if cell_index is not None else ""
- logger.info(f"Executing Cell {ceel_num:-<40}")
+ ceel_num = cell_index + 1 if cell_index is not None else ''
+ logger.info(f'Executing Cell {ceel_num:-<40}')
- cell.metadata.papermill["start_time"] = self.now().isoformat()
- cell.metadata.papermill["status"] = self.RUNNING
- cell.metadata.papermill["exception"] = False
+ cell.metadata.papermill['start_time'] = self.now().isoformat()
+ cell.metadata.papermill['status'] = self.RUNNING
+ cell.metadata.papermill['exception'] = False
# injects optional description of the current cell directly in the tqdm
cell_description = self.get_cell_description(cell)
- if cell_description is not None and hasattr(self, "pbar") and self.pbar:
- self.pbar.set_description(f"Executing {cell_description}")
+ if cell_description is not None and hasattr(self, 'pbar') and self.pbar:
+ self.pbar.set_description(f'Executing {cell_description}')
self.save()
@@ -242,9 +240,9 @@ def cell_exception(self, cell, cell_index=None, **kwargs):
set the metadata on the notebook indicating the location of the
failure.
"""
- cell.metadata.papermill["exception"] = True
- cell.metadata.papermill["status"] = self.FAILED
- self.nb.metadata.papermill["exception"] = True
+ cell.metadata.papermill['exception'] = True
+ cell.metadata.papermill['status'] = self.FAILED
+ self.nb.metadata.papermill['exception'] = True
@catch_nb_assignment
def cell_complete(self, cell, cell_index=None, **kwargs):
@@ -257,20 +255,18 @@ def cell_complete(self, cell, cell_index=None, **kwargs):
end_time = self.now()
if self.log_output:
- ceel_num = cell_index + 1 if cell_index is not None else ""
- logger.info(f"Ending Cell {ceel_num:-<43}")
+ ceel_num = cell_index + 1 if cell_index is not None else ''
+ logger.info(f'Ending Cell {ceel_num:-<43}')
# Ensure our last cell messages are not buffered by python
sys.stdout.flush()
sys.stderr.flush()
- cell.metadata.papermill["end_time"] = end_time.isoformat()
- if cell.metadata.papermill.get("start_time"):
- start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"])
- cell.metadata.papermill["duration"] = (
- end_time - start_time
- ).total_seconds()
- if cell.metadata.papermill["status"] != self.FAILED:
- cell.metadata.papermill["status"] = self.COMPLETED
+ cell.metadata.papermill['end_time'] = end_time.isoformat()
+ if cell.metadata.papermill.get('start_time'):
+ start_time = dateutil.parser.parse(cell.metadata.papermill['start_time'])
+ cell.metadata.papermill['duration'] = (end_time - start_time).total_seconds()
+ if cell.metadata.papermill['status'] != self.FAILED:
+ cell.metadata.papermill['status'] = self.COMPLETED
self.save()
if self.pbar:
@@ -285,18 +281,16 @@ def notebook_complete(self, **kwargs):
Called by Engine when execution concludes, regardless of exceptions.
"""
self.end_time = self.now()
- self.nb.metadata.papermill["end_time"] = self.end_time.isoformat()
- if self.nb.metadata.papermill.get("start_time"):
- self.nb.metadata.papermill["duration"] = (
- self.end_time - self.start_time
- ).total_seconds()
+ self.nb.metadata.papermill['end_time'] = self.end_time.isoformat()
+ if self.nb.metadata.papermill.get('start_time'):
+ self.nb.metadata.papermill['duration'] = (self.end_time - self.start_time).total_seconds()
# Cleanup cell statuses in case callbacks were never called
for cell in self.nb.cells:
- if cell.metadata.papermill["status"] == self.FAILED:
+ if cell.metadata.papermill['status'] == self.FAILED:
break
- elif cell.metadata.papermill["status"] == self.PENDING:
- cell.metadata.papermill["status"] = self.COMPLETED
+ elif cell.metadata.papermill['status'] == self.PENDING:
+ cell.metadata.papermill['status'] = self.COMPLETED
self.complete_pbar()
self.cleanup_pbar()
@@ -304,12 +298,12 @@ def notebook_complete(self, **kwargs):
# Force a final sync
self.save()
- def get_cell_description(self, cell, escape_str="papermill_description="):
+ def get_cell_description(self, cell, escape_str='papermill_description='):
"""Fetches cell description if present"""
if cell is None:
return None
- cell_code = cell["source"]
+ cell_code = cell['source']
if cell_code is None or escape_str not in cell_code:
return None
@@ -317,13 +311,13 @@ def get_cell_description(self, cell, escape_str="papermill_description="):
def complete_pbar(self):
"""Refresh progress bar"""
- if hasattr(self, "pbar") and self.pbar:
+ if hasattr(self, 'pbar') and self.pbar:
self.pbar.n = len(self.nb.cells)
self.pbar.refresh()
def cleanup_pbar(self):
"""Clean up a progress bar"""
- if hasattr(self, "pbar") and self.pbar:
+ if hasattr(self, 'pbar') and self.pbar:
self.pbar.close()
self.pbar = None
@@ -371,9 +365,7 @@ def execute_notebook(
nb_man.notebook_start()
try:
- cls.execute_managed_notebook(
- nb_man, kernel_name, log_output=log_output, **kwargs
- )
+ cls.execute_managed_notebook(nb_man, kernel_name, log_output=log_output, **kwargs)
finally:
nb_man.cleanup_pbar()
nb_man.notebook_complete()
@@ -383,9 +375,7 @@ def execute_notebook(
@classmethod
def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs):
"""An abstract method where implementation will be defined in a subclass."""
- raise NotImplementedError(
- "'execute_managed_notebook' is not implemented for this engine"
- )
+ raise NotImplementedError("'execute_managed_notebook' is not implemented for this engine")
@classmethod
def nb_kernel_name(cls, nb, name=None):
@@ -431,12 +421,12 @@ def execute_managed_notebook(
"""
# Exclude parameters that named differently downstream
- safe_kwargs = remove_args(["timeout", "startup_timeout"], **kwargs)
+ safe_kwargs = remove_args(['timeout', 'startup_timeout'], **kwargs)
# Nicely handle preprocessor arguments prioritizing values set by engine
final_kwargs = merge_kwargs(
safe_kwargs,
- timeout=execution_timeout if execution_timeout else kwargs.get("timeout"),
+ timeout=execution_timeout if execution_timeout else kwargs.get('timeout'),
startup_timeout=start_timeout,
kernel_name=kernel_name,
log=logger,
@@ -450,5 +440,5 @@ def execute_managed_notebook(
# Instantiate a PapermillEngines instance, register Handlers and entrypoints
papermill_engines = PapermillEngines()
papermill_engines.register(None, NBClientEngine)
-papermill_engines.register("nbclient", NBClientEngine)
+papermill_engines.register('nbclient', NBClientEngine)
papermill_engines.register_entry_points()
diff --git a/papermill/exceptions.py b/papermill/exceptions.py
index 38aab7e8..f78f95f7 100644
--- a/papermill/exceptions.py
+++ b/papermill/exceptions.py
@@ -33,10 +33,10 @@ def __str__(self):
# when called with str(). In order to maintain compatability with previous versions which
# passed only the message to the superclass constructor, __str__ method is implemented to
# provide the same result as was produced in the past.
- message = "\n" + 75 * "-" + "\n"
+ message = '\n' + 75 * '-' + '\n'
message += 'Exception encountered at "In [%s]":\n' % str(self.exec_count)
- message += "\n".join(self.traceback)
- message += "\n"
+ message += '\n'.join(self.traceback)
+ message += '\n'
return message
@@ -59,10 +59,8 @@ class PapermillParameterOverwriteWarning(PapermillWarning):
def missing_dependency_generator(package, dep):
def missing_dep():
raise PapermillOptionalDependencyException(
- "The {package} optional dependency is missing. "
- "Please run pip install papermill[{dep}] to install this dependency".format(
- package=package, dep=dep
- )
+ f'The {package} optional dependency is missing. '
+ f'Please run pip install papermill[{dep}] to install this dependency'
)
return missing_dep
@@ -71,11 +69,9 @@ def missing_dep():
def missing_environment_variable_generator(package, env_key):
def missing_dep():
raise PapermillOptionalDependencyException(
- "The {package} optional dependency is present, but the environment "
- "variable {env_key} is not set. Please set this variable as "
- "required by {package} on your platform.".format(
- package=package, env_key=env_key
- )
+ f'The {package} optional dependency is present, but the environment '
+ f'variable {env_key} is not set. Please set this variable as '
+ f'required by {package} on your platform.'
)
return missing_dep
diff --git a/papermill/execute.py b/papermill/execute.py
index 3d0d23ae..1b683918 100644
--- a/papermill/execute.py
+++ b/papermill/execute.py
@@ -1,17 +1,18 @@
-import nbformat
from pathlib import Path
-from .log import logger
-from .exceptions import PapermillExecutionError
-from .iorw import get_pretty_path, local_file_io_cwd, load_notebook_node, write_ipynb
+import nbformat
+
from .engines import papermill_engines
-from .utils import chdir
+from .exceptions import PapermillExecutionError
+from .inspection import _infer_parameters
+from .iorw import get_pretty_path, load_notebook_node, local_file_io_cwd, write_ipynb
+from .log import logger
from .parameterize import (
add_builtin_parameters,
parameterize_notebook,
parameterize_path,
)
-from .inspection import _infer_parameters
+from .utils import chdir
def execute_notebook(
@@ -83,23 +84,21 @@ def execute_notebook(
input_path = parameterize_path(input_path, path_parameters)
output_path = parameterize_path(output_path, path_parameters)
- logger.info("Input Notebook: %s" % get_pretty_path(input_path))
- logger.info("Output Notebook: %s" % get_pretty_path(output_path))
+ logger.info('Input Notebook: %s' % get_pretty_path(input_path))
+ logger.info('Output Notebook: %s' % get_pretty_path(output_path))
with local_file_io_cwd():
if cwd is not None:
- logger.info(f"Working directory: {get_pretty_path(cwd)}")
+ logger.info(f'Working directory: {get_pretty_path(cwd)}')
nb = load_notebook_node(input_path)
# Parameterize the Notebook.
if parameters:
- parameter_predefined = _infer_parameters(
- nb, name=kernel_name, language=language
- )
+ parameter_predefined = _infer_parameters(nb, name=kernel_name, language=language)
parameter_predefined = {p.name for p in parameter_predefined}
for p in parameters:
if p not in parameter_predefined:
- logger.warning(f"Passed unknown parameter: {p}")
+ logger.warning(f'Passed unknown parameter: {p}')
nb = parameterize_notebook(
nb,
parameters,
@@ -115,9 +114,7 @@ def execute_notebook(
if not prepare_only:
# Dropdown to the engine to fetch the kernel name from the notebook document
- kernel_name = papermill_engines.nb_kernel_name(
- engine_name=engine_name, nb=nb, name=kernel_name
- )
+ kernel_name = papermill_engines.nb_kernel_name(engine_name=engine_name, nb=nb, name=kernel_name)
# Execute the Notebook in `cwd` if it is set
with chdir(cwd):
nb = papermill_engines.execute_notebook_with_engine(
@@ -160,40 +157,36 @@ def prepare_notebook_metadata(nb, input_path, output_path, report_mode=False):
# Hide input if report-mode is set to True.
if report_mode:
for cell in nb.cells:
- if cell.cell_type == "code":
- cell.metadata["jupyter"] = cell.get("jupyter", {})
- cell.metadata["jupyter"]["source_hidden"] = True
+ if cell.cell_type == 'code':
+ cell.metadata['jupyter'] = cell.get('jupyter', {})
+ cell.metadata['jupyter']['source_hidden'] = True
# Record specified environment variable values.
- nb.metadata.papermill["input_path"] = input_path
- nb.metadata.papermill["output_path"] = output_path
+ nb.metadata.papermill['input_path'] = input_path
+ nb.metadata.papermill['output_path'] = output_path
return nb
-ERROR_MARKER_TAG = "papermill-error-cell-tag"
+ERROR_MARKER_TAG = 'papermill-error-cell-tag'
ERROR_STYLE = 'style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;"'
ERROR_MESSAGE_TEMPLATE = (
- ""
- "An Exception was encountered at 'In [%s]'."
- ""
+ ''
+ 'An Exception was encountered at \'In [%s]\'.'
+ ''
)
ERROR_ANCHOR_MSG = (
- '"
- "Execution using papermill encountered an exception here and stopped:"
- ""
+ ''
+ 'Execution using papermill encountered an exception here and stopped:'
+ ''
)
def remove_error_markers(nb):
- nb.cells = [
- cell
- for cell in nb.cells
- if ERROR_MARKER_TAG not in cell.metadata.get("tags", [])
- ]
+ nb.cells = [cell for cell in nb.cells if ERROR_MARKER_TAG not in cell.metadata.get('tags', [])]
return nb
@@ -209,14 +202,12 @@ def raise_for_execution_errors(nb, output_path):
"""
error = None
for index, cell in enumerate(nb.cells):
- if cell.get("outputs") is None:
+ if cell.get('outputs') is None:
continue
for output in cell.outputs:
- if output.output_type == "error":
- if output.ename == "SystemExit" and (
- output.evalue == "" or output.evalue == "0"
- ):
+ if output.output_type == 'error':
+ if output.ename == 'SystemExit' and (output.evalue == '' or output.evalue == '0'):
continue
error = PapermillExecutionError(
cell_index=index,
@@ -233,9 +224,9 @@ def raise_for_execution_errors(nb, output_path):
# the relevant cell (by adding a note just before the failure with an HTML anchor)
error_msg = ERROR_MESSAGE_TEMPLATE % str(error.exec_count)
error_msg_cell = nbformat.v4.new_markdown_cell(error_msg)
- error_msg_cell.metadata["tags"] = [ERROR_MARKER_TAG]
+ error_msg_cell.metadata['tags'] = [ERROR_MARKER_TAG]
error_anchor_cell = nbformat.v4.new_markdown_cell(ERROR_ANCHOR_MSG)
- error_anchor_cell.metadata["tags"] = [ERROR_MARKER_TAG]
+ error_anchor_cell.metadata['tags'] = [ERROR_MARKER_TAG]
# Upgrade the Notebook to the latest v4 before writing into it
nb = nbformat.v4.upgrade(nb)
diff --git a/papermill/inspection.py b/papermill/inspection.py
index b1ec68f7..db5a6136 100644
--- a/papermill/inspection.py
+++ b/papermill/inspection.py
@@ -1,7 +1,8 @@
"""Deduce parameters of a notebook from the parameters cell."""
-import click
from pathlib import Path
+import click
+
from .iorw import get_pretty_path, load_notebook_node, local_file_io_cwd
from .log import logger
from .parameterize import add_builtin_parameters, parameterize_path
@@ -17,7 +18,7 @@
def _open_notebook(notebook_path, parameters):
path_parameters = add_builtin_parameters(parameters)
input_path = parameterize_path(notebook_path, path_parameters)
- logger.info("Input Notebook: %s" % get_pretty_path(input_path))
+ logger.info('Input Notebook: %s' % get_pretty_path(input_path))
with local_file_io_cwd():
return load_notebook_node(input_path)
@@ -38,7 +39,7 @@ def _infer_parameters(nb, name=None, language=None):
"""
params = []
- parameter_cell_idx = find_first_tagged_cell_index(nb, "parameters")
+ parameter_cell_idx = find_first_tagged_cell_index(nb, 'parameters')
if parameter_cell_idx < 0:
return params
parameter_cell = nb.cells[parameter_cell_idx]
@@ -50,11 +51,7 @@ def _infer_parameters(nb, name=None, language=None):
try:
params = translator.inspect(parameter_cell)
except NotImplementedError:
- logger.warning(
- "Translator for '{}' language does not support parameter introspection.".format(
- language
- )
- )
+ logger.warning(f"Translator for '{language}' language does not support parameter introspection.")
return params
@@ -74,7 +71,7 @@ def display_notebook_help(ctx, notebook_path, parameters):
pretty_path = get_pretty_path(notebook_path)
click.echo(f"\nParameters inferred for notebook '{pretty_path}':")
- if not any_tagged_cell(nb, "parameters"):
+ if not any_tagged_cell(nb, 'parameters'):
click.echo("\n No cell tagged 'parameters'")
return 1
@@ -82,25 +79,22 @@ def display_notebook_help(ctx, notebook_path, parameters):
if params:
for param in params:
p = param._asdict()
- type_repr = p["inferred_type_name"]
- if type_repr == "None":
- type_repr = "Unknown type"
+ type_repr = p['inferred_type_name']
+ if type_repr == 'None':
+ type_repr = 'Unknown type'
- definition = " {}: {} (default {})".format(
- p["name"], type_repr, p["default"]
- )
+ definition = ' {}: {} (default {})'.format(p['name'], type_repr, p['default'])
if len(definition) > 30:
- if len(p["help"]):
- param_help = "".join((definition, "\n", 34 * " ", p["help"]))
+ if len(p['help']):
+ param_help = ''.join((definition, '\n', 34 * ' ', p['help']))
else:
param_help = definition
else:
- param_help = "{:<34}{}".format(definition, p["help"])
+ param_help = '{:<34}{}'.format(definition, p['help'])
click.echo(param_help)
else:
click.echo(
- "\n Can't infer anything about this notebook's parameters. "
- "It may not have any parameter defined."
+ "\n Can't infer anything about this notebook's parameters. " 'It may not have any parameter defined.'
)
return 0
diff --git a/papermill/iorw.py b/papermill/iorw.py
index 961ee207..ecca680f 100644
--- a/papermill/iorw.py
+++ b/papermill/iorw.py
@@ -1,15 +1,14 @@
+import fnmatch
+import json
import os
import sys
-import json
-import yaml
-import fnmatch
-import nbformat
-import requests
import warnings
-import entrypoints
-
from contextlib import contextmanager
+import entrypoints
+import nbformat
+import requests
+import yaml
from tenacity import (
retry,
retry_if_exception_type,
@@ -30,37 +29,37 @@
try:
from .s3 import S3
except ImportError:
- S3 = missing_dependency_generator("boto3", "s3")
+ S3 = missing_dependency_generator('boto3', 's3')
try:
from .adl import ADL
except ImportError:
- ADL = missing_dependency_generator("azure.datalake.store", "azure")
+ ADL = missing_dependency_generator('azure.datalake.store', 'azure')
except KeyError as exc:
- if exc.args[0] == "APPDATA":
- ADL = missing_environment_variable_generator("azure.datalake.store", "APPDATA")
+ if exc.args[0] == 'APPDATA':
+ ADL = missing_environment_variable_generator('azure.datalake.store', 'APPDATA')
else:
raise
try:
from .abs import AzureBlobStore
except ImportError:
- AzureBlobStore = missing_dependency_generator("azure.storage.blob", "azure")
+ AzureBlobStore = missing_dependency_generator('azure.storage.blob', 'azure')
try:
from gcsfs import GCSFileSystem
except ImportError:
- GCSFileSystem = missing_dependency_generator("gcsfs", "gcs")
+ GCSFileSystem = missing_dependency_generator('gcsfs', 'gcs')
try:
- from pyarrow.fs import HadoopFileSystem, FileSelector
+ from pyarrow.fs import FileSelector, HadoopFileSystem
except ImportError:
- HadoopFileSystem = missing_dependency_generator("pyarrow", "hdfs")
+ HadoopFileSystem = missing_dependency_generator('pyarrow', 'hdfs')
try:
from github import Github
except ImportError:
- Github = missing_dependency_generator("pygithub", "github")
+ Github = missing_dependency_generator('pygithub', 'github')
def fallback_gs_is_retriable(e):
@@ -97,14 +96,14 @@ class PapermillIO:
def __init__(self):
self.reset()
- def read(self, path, extensions=[".ipynb", ".json"]):
+ def read(self, path, extensions=['.ipynb', '.json']):
# Handle https://github.com/nteract/papermill/issues/317
notebook_metadata = self.get_handler(path, extensions).read(path)
if isinstance(notebook_metadata, (bytes, bytearray)):
- return notebook_metadata.decode("utf-8")
+ return notebook_metadata.decode('utf-8')
return notebook_metadata
- def write(self, buf, path, extensions=[".ipynb", ".json"]):
+ def write(self, buf, path, extensions=['.ipynb', '.json']):
return self.get_handler(path, extensions).write(buf, path)
def listdir(self, path):
@@ -122,7 +121,7 @@ def register(self, scheme, handler):
def register_entry_points(self):
# Load handlers provided by other packages
- for entrypoint in entrypoints.get_group_all("papermill.io"):
+ for entrypoint in entrypoints.get_group_all('papermill.io'):
self.register(entrypoint.name, entrypoint.load())
def get_handler(self, path, extensions=None):
@@ -151,31 +150,21 @@ def get_handler(self, path, extensions=None):
return NotebookNodeHandler()
if extensions:
- if not fnmatch.fnmatch(os.path.basename(path).split("?")[0], "*.*"):
- warnings.warn(
- "the file is not specified with any extension : "
- + os.path.basename(path)
- )
- elif not any(
- fnmatch.fnmatch(os.path.basename(path).split("?")[0], "*" + ext)
- for ext in extensions
- ):
- warnings.warn(
- f"The specified file ({path}) does not end in one of {extensions}"
- )
+ if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'):
+ warnings.warn('the file is not specified with any extension : ' + os.path.basename(path))
+ elif not any(fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext) for ext in extensions):
+ warnings.warn(f'The specified file ({path}) does not end in one of {extensions}')
local_handler = None
for scheme, handler in self._handlers:
- if scheme == "local":
+ if scheme == 'local':
local_handler = handler
if path.startswith(scheme):
return handler
if local_handler is None:
- raise PapermillException(
- f"Could not find a registered schema handler for: {path}"
- )
+ raise PapermillException(f'Could not find a registered schema handler for: {path}')
return local_handler
@@ -183,11 +172,11 @@ def get_handler(self, path, extensions=None):
class HttpHandler:
@classmethod
def read(cls, path):
- return requests.get(path, headers={"Accept": "application/json"}).text
+ return requests.get(path, headers={'Accept': 'application/json'}).text
@classmethod
def listdir(cls, path):
- raise PapermillException("listdir is not supported by HttpHandler")
+ raise PapermillException('listdir is not supported by HttpHandler')
@classmethod
def write(cls, buf, path):
@@ -206,7 +195,7 @@ def __init__(self):
def read(self, path):
try:
with chdir(self._cwd):
- with open(path, encoding="utf-8") as f:
+ with open(path, encoding='utf-8') as f:
return f.read()
except OSError as e:
try:
@@ -227,7 +216,7 @@ def write(self, buf, path):
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
raise FileNotFoundError(f"output folder {dirname} doesn't exist.")
- with open(path, "w", encoding="utf-8") as f:
+ with open(path, 'w', encoding='utf-8') as f:
f.write(buf)
def pretty_path(self, path):
@@ -243,7 +232,7 @@ def cwd(self, new_path):
class S3Handler:
@classmethod
def read(cls, path):
- return "\n".join(S3().read(path))
+ return '\n'.join(S3().read(path))
@classmethod
def listdir(cls, path):
@@ -269,7 +258,7 @@ def _get_client(self):
def read(self, path):
lines = self._get_client().read(path)
- return "\n".join(lines)
+ return '\n'.join(lines)
def listdir(self, path):
return self._get_client().listdir(path)
@@ -292,7 +281,7 @@ def _get_client(self):
def read(self, path):
lines = self._get_client().read(path)
- return "\n".join(lines)
+ return '\n'.join(lines)
def listdir(self, path):
return self._get_client().listdir(path)
@@ -339,13 +328,13 @@ def write(self, buf, path):
)
def retry_write():
try:
- with self._get_client().open(path, "w") as f:
+ with self._get_client().open(path, 'w') as f:
return f.write(buf)
except Exception as e:
try:
message = e.message
except AttributeError:
- message = f"Generic exception {type(e)} raised"
+ message = f'Generic exception {type(e)} raised'
if gs_is_retriable(e):
raise PapermillRateLimitException(message)
# Reraise the original exception without retries
@@ -363,7 +352,7 @@ def __init__(self):
def _get_client(self):
if self._client is None:
- self._client = HadoopFileSystem(host="default")
+ self._client = HadoopFileSystem(host='default')
return self._client
def read(self, path):
@@ -387,7 +376,7 @@ def __init__(self):
def _get_client(self):
if self._client is None:
- token = os.environ.get("GITHUB_ACCESS_TOKEN", None)
+ token = os.environ.get('GITHUB_ACCESS_TOKEN', None)
if token:
self._client = Github(token)
else:
@@ -395,20 +384,20 @@ def _get_client(self):
return self._client
def read(self, path):
- splits = path.split("/")
+ splits = path.split('/')
org_id = splits[3]
repo_id = splits[4]
ref_id = splits[6]
- sub_path = "/".join(splits[7:])
- repo = self._get_client().get_repo(org_id + "/" + repo_id)
+ sub_path = '/'.join(splits[7:])
+ repo = self._get_client().get_repo(org_id + '/' + repo_id)
content = repo.get_contents(sub_path, ref=ref_id)
return content.decoded_content
def listdir(self, path):
- raise PapermillException("listdir is not supported by GithubHandler")
+ raise PapermillException('listdir is not supported by GithubHandler')
def write(self, buf, path):
- raise PapermillException("write is not supported by GithubHandler")
+ raise PapermillException('write is not supported by GithubHandler')
def pretty_path(self, path):
return path
@@ -421,15 +410,15 @@ def read(self, path):
return sys.stdin.read()
def listdir(self, path):
- raise PapermillException("listdir is not supported by Stream Handler")
+ raise PapermillException('listdir is not supported by Stream Handler')
def write(self, buf, path):
try:
- return sys.stdout.buffer.write(buf.encode("utf-8"))
+ return sys.stdout.buffer.write(buf.encode('utf-8'))
except AttributeError:
# Originally required by https://github.com/nteract/papermill/issues/420
# Support Buffer.io objects
- return sys.stdout.write(buf.encode("utf-8"))
+ return sys.stdout.write(buf.encode('utf-8'))
def pretty_path(self, path):
return path
@@ -442,61 +431,59 @@ def read(self, path):
return nbformat.writes(path)
def listdir(self, path):
- raise PapermillException("listdir is not supported by NotebookNode Handler")
+ raise PapermillException('listdir is not supported by NotebookNode Handler')
def write(self, buf, path):
- raise PapermillException("write is not supported by NotebookNode Handler")
+ raise PapermillException('write is not supported by NotebookNode Handler')
def pretty_path(self, path):
- return "NotebookNode object"
+ return 'NotebookNode object'
class NoIOHandler:
"""Handler for output_path of None - intended to not write anything"""
def read(self, path):
- raise PapermillException("read is not supported by NoIOHandler")
+ raise PapermillException('read is not supported by NoIOHandler')
def listdir(self, path):
- raise PapermillException("listdir is not supported by NoIOHandler")
+ raise PapermillException('listdir is not supported by NoIOHandler')
def write(self, buf, path):
return
def pretty_path(self, path):
- return "Notebook will not be saved"
+ return 'Notebook will not be saved'
# Hack to make YAML loader not auto-convert datetimes
# https://stackoverflow.com/a/52312810
class NoDatesSafeLoader(yaml.SafeLoader):
yaml_implicit_resolvers = {
- k: [r for r in v if r[0] != "tag:yaml.org,2002:timestamp"]
+ k: [r for r in v if r[0] != 'tag:yaml.org,2002:timestamp']
for k, v in yaml.SafeLoader.yaml_implicit_resolvers.items()
}
# Instantiate a PapermillIO instance and register Handlers.
papermill_io = PapermillIO()
-papermill_io.register("local", LocalHandler())
-papermill_io.register("s3://", S3Handler)
-papermill_io.register("adl://", ADLHandler())
-papermill_io.register("abs://", ABSHandler())
-papermill_io.register("http://", HttpHandler)
-papermill_io.register("https://", HttpHandler)
-papermill_io.register("gs://", GCSHandler())
-papermill_io.register("hdfs://", HDFSHandler())
-papermill_io.register("http://github.com/", GithubHandler())
-papermill_io.register("https://github.com/", GithubHandler())
-papermill_io.register("-", StreamHandler())
+papermill_io.register('local', LocalHandler())
+papermill_io.register('s3://', S3Handler)
+papermill_io.register('adl://', ADLHandler())
+papermill_io.register('abs://', ABSHandler())
+papermill_io.register('http://', HttpHandler)
+papermill_io.register('https://', HttpHandler)
+papermill_io.register('gs://', GCSHandler())
+papermill_io.register('hdfs://', HDFSHandler())
+papermill_io.register('http://github.com/', GithubHandler())
+papermill_io.register('https://github.com/', GithubHandler())
+papermill_io.register('-', StreamHandler())
papermill_io.register_entry_points()
def read_yaml_file(path):
"""Reads a YAML file from the location specified at 'path'."""
- return yaml.load(
- papermill_io.read(path, [".json", ".yaml", ".yml"]), Loader=NoDatesSafeLoader
- )
+ return yaml.load(papermill_io.read(path, ['.json', '.yaml', '.yml']), Loader=NoDatesSafeLoader)
def write_ipynb(nb, path):
@@ -523,27 +510,27 @@ def load_notebook_node(notebook_path):
if nb_upgraded is not None:
nb = nb_upgraded
- if not hasattr(nb.metadata, "papermill"):
- nb.metadata["papermill"] = {
- "default_parameters": dict(),
- "parameters": dict(),
- "environment_variables": dict(),
- "version": __version__,
+ if not hasattr(nb.metadata, 'papermill'):
+ nb.metadata['papermill'] = {
+ 'default_parameters': dict(),
+ 'parameters': dict(),
+ 'environment_variables': dict(),
+ 'version': __version__,
}
for cell in nb.cells:
- if not hasattr(cell.metadata, "tags"):
- cell.metadata["tags"] = [] # Create tags attr if one doesn't exist.
+ if not hasattr(cell.metadata, 'tags'):
+ cell.metadata['tags'] = [] # Create tags attr if one doesn't exist.
- if not hasattr(cell.metadata, "papermill"):
- cell.metadata["papermill"] = dict()
+ if not hasattr(cell.metadata, 'papermill'):
+ cell.metadata['papermill'] = dict()
return nb
def list_notebook_files(path):
"""Returns a list of all the notebook files in a directory."""
- return [p for p in papermill_io.listdir(path) if p.endswith(".ipynb")]
+ return [p for p in papermill_io.listdir(path) if p.endswith('.ipynb')]
def get_pretty_path(path):
@@ -553,14 +540,14 @@ def get_pretty_path(path):
@contextmanager
def local_file_io_cwd(path=None):
try:
- local_handler = papermill_io.get_handler("local")
+ local_handler = papermill_io.get_handler('local')
except PapermillException:
- logger.warning("No local file handler detected")
+ logger.warning('No local file handler detected')
else:
try:
old_cwd = local_handler.cwd(path or os.getcwd())
except AttributeError:
- logger.warning("Local file handler does not support cwd assignment")
+ logger.warning('Local file handler does not support cwd assignment')
else:
try:
yield
diff --git a/papermill/log.py b/papermill/log.py
index 273bc8f3..b90225d2 100644
--- a/papermill/log.py
+++ b/papermill/log.py
@@ -1,4 +1,4 @@
"""Sets up a logger"""
import logging
-logger = logging.getLogger("papermill")
+logger = logging.getLogger('papermill')
diff --git a/papermill/models.py b/papermill/models.py
index fcbb627f..35c077e5 100644
--- a/papermill/models.py
+++ b/papermill/models.py
@@ -2,11 +2,11 @@
from collections import namedtuple
Parameter = namedtuple(
- "Parameter",
+ 'Parameter',
[
- "name",
- "inferred_type_name", # string of type
- "default", # string representing the default value
- "help",
+ 'name',
+ 'inferred_type_name', # string of type
+ 'default', # string representing the default value
+ 'help',
],
)
diff --git a/papermill/parameterize.py b/papermill/parameterize.py
index db3ac837..a210f26e 100644
--- a/papermill/parameterize.py
+++ b/papermill/parameterize.py
@@ -1,15 +1,15 @@
+from datetime import datetime
+from uuid import uuid4
+
import nbformat
from .engines import papermill_engines
-from .log import logger
from .exceptions import PapermillMissingParameterException
from .iorw import read_yaml_file
+from .log import logger
from .translators import translate_parameters
from .utils import find_first_tagged_cell_index
-from uuid import uuid4
-from datetime import datetime
-
def add_builtin_parameters(parameters):
"""Add built-in parameters to a dictionary of parameters
@@ -20,10 +20,10 @@ def add_builtin_parameters(parameters):
Dictionary of parameters provided by the user
"""
with_builtin_parameters = {
- "pm": {
- "run_uuid": str(uuid4()),
- "current_datetime_local": datetime.now(),
- "current_datetime_utc": datetime.utcnow(),
+ 'pm': {
+ 'run_uuid': str(uuid4()),
+ 'current_datetime_local': datetime.now(),
+ 'current_datetime_utc': datetime.utcnow(),
}
}
@@ -53,14 +53,14 @@ def parameterize_path(path, parameters):
try:
return path.format(**parameters)
except KeyError as key_error:
- raise PapermillMissingParameterException(f"Missing parameter {key_error}")
+ raise PapermillMissingParameterException(f'Missing parameter {key_error}')
def parameterize_notebook(
nb,
parameters,
report_mode=False,
- comment="Parameters",
+ comment='Parameters',
kernel_name=None,
language=None,
engine_name=None,
@@ -93,14 +93,14 @@ def parameterize_notebook(
nb = nbformat.v4.upgrade(nb)
newcell = nbformat.v4.new_code_cell(source=param_content)
- newcell.metadata["tags"] = ["injected-parameters"]
+ newcell.metadata['tags'] = ['injected-parameters']
if report_mode:
- newcell.metadata["jupyter"] = newcell.get("jupyter", {})
- newcell.metadata["jupyter"]["source_hidden"] = True
+ newcell.metadata['jupyter'] = newcell.get('jupyter', {})
+ newcell.metadata['jupyter']['source_hidden'] = True
- param_cell_index = find_first_tagged_cell_index(nb, "parameters")
- injected_cell_index = find_first_tagged_cell_index(nb, "injected-parameters")
+ param_cell_index = find_first_tagged_cell_index(nb, 'parameters')
+ injected_cell_index = find_first_tagged_cell_index(nb, 'injected-parameters')
if injected_cell_index >= 0:
# Replace the injected cell with a new version
before = nb.cells[:injected_cell_index]
@@ -116,6 +116,6 @@ def parameterize_notebook(
after = nb.cells
nb.cells = before + [newcell] + after
- nb.metadata.papermill["parameters"] = parameters
+ nb.metadata.papermill['parameters'] = parameters
return nb
diff --git a/papermill/s3.py b/papermill/s3.py
index ccd2141a..06ac9aff 100644
--- a/papermill/s3.py
+++ b/papermill/s3.py
@@ -1,8 +1,7 @@
"""Utilities for working with S3."""
-import os
-
import logging
+import os
import threading
import zlib
@@ -11,8 +10,7 @@
from .exceptions import AwsError
from .utils import retry
-
-logger = logging.getLogger("papermill.s3")
+logger = logging.getLogger('papermill.s3')
class Bucket:
@@ -32,11 +30,9 @@ def __init__(self, name, service=None):
self.name = name
self.service = service
- def list(self, prefix="", delimiter=None):
+ def list(self, prefix='', delimiter=None):
"""Limits a list of Bucket's objects based on prefix and delimiter."""
- return self.service._list(
- bucket=self.name, prefix=prefix, delimiter=delimiter, objects=True
- )
+ return self.service._list(bucket=self.name, prefix=prefix, delimiter=delimiter, objects=True)
class Prefix:
@@ -61,7 +57,7 @@ def __init__(self, bucket, name, service=None):
self.service = service
def __str__(self):
- return f"s3://{self.bucket.name}/{self.name}"
+ return f's3://{self.bucket.name}/{self.name}'
def __repr__(self):
return self.__str__()
@@ -106,7 +102,7 @@ def __init__(
self.etag = etag
if last_modified:
try:
- self.last_modified = last_modified.isoformat().split("+")[0] + ".000Z"
+ self.last_modified = last_modified.isoformat().split('+')[0] + '.000Z'
except ValueError:
self.last_modified = last_modified
self.storage_class = storage_class
@@ -114,7 +110,7 @@ def __init__(
self.service = service
def __str__(self):
- return f"s3://{self.bucket.name}/{self.name}"
+ return f's3://{self.bucket.name}/{self.name}'
def __repr__(self):
return self.__str__()
@@ -146,47 +142,45 @@ def __init__(self, keyname=None, *args, **kwargs):
with self.lock:
if not all(S3.s3_session):
session = Session()
- client = session.client("s3")
+ client = session.client('s3')
session_params = {}
- endpoint_url = os.environ.get("BOTO3_ENDPOINT_URL", None)
+ endpoint_url = os.environ.get('BOTO3_ENDPOINT_URL', None)
if endpoint_url:
- session_params["endpoint_url"] = endpoint_url
+ session_params['endpoint_url'] = endpoint_url
- s3 = session.resource("s3", **session_params)
+ s3 = session.resource('s3', **session_params)
S3.s3_session = (session, client, s3)
(self.session, self.client, self.s3) = S3.s3_session
def _bucket_name(self, bucket):
- return self._clean(bucket).split("/", 1)[0]
+ return self._clean(bucket).split('/', 1)[0]
def _clean(self, name):
- if name.startswith("s3n:"):
- name = "s3:" + name[4:]
+ if name.startswith('s3n:'):
+ name = 's3:' + name[4:]
if self._is_s3(name):
return name[5:]
return name
def _clean_s3(self, name):
- return "s3:" + name[4:] if name.startswith("s3n:") else name
+ return 's3:' + name[4:] if name.startswith('s3n:') else name
def _get_key(self, name):
if isinstance(name, Key):
return name
- return Key(
- bucket=self._bucket_name(name), name=self._key_name(name), service=self
- )
+ return Key(bucket=self._bucket_name(name), name=self._key_name(name), service=self)
def _key_name(self, name):
- cleaned = self._clean(name).split("/", 1)
+ cleaned = self._clean(name).split('/', 1)
return cleaned[1] if len(cleaned) > 1 else None
@retry(3)
def _list(
self,
- prefix="",
+ prefix='',
bucket=None,
delimiter=None,
keys=False,
@@ -194,55 +188,55 @@ def _list(
page_size=1000,
**kwargs,
):
- assert bucket is not None, "You must specify a bucket to list"
+ assert bucket is not None, 'You must specify a bucket to list'
bucket = self._bucket_name(bucket)
- paginator = self.client.get_paginator("list_objects_v2")
+ paginator = self.client.get_paginator('list_objects_v2')
operation_parameters = {
- "Bucket": bucket,
- "Prefix": prefix,
- "PaginationConfig": {"PageSize": page_size},
+ 'Bucket': bucket,
+ 'Prefix': prefix,
+ 'PaginationConfig': {'PageSize': page_size},
}
if delimiter:
- operation_parameters["Delimiter"] = delimiter
+ operation_parameters['Delimiter'] = delimiter
page_iterator = paginator.paginate(**operation_parameters)
def sort(item):
- if "Key" in item:
- return item["Key"]
- return item["Prefix"]
+ if 'Key' in item:
+ return item['Key']
+ return item['Prefix']
for page in page_iterator:
locations = sorted(
- [i for i in page.get("Contents", []) + page.get("CommonPrefixes", [])],
+ [i for i in page.get('Contents', []) + page.get('CommonPrefixes', [])],
key=sort,
)
for item in locations:
if objects or keys:
- if "Key" in item:
+ if 'Key' in item:
yield Key(
bucket,
- item["Key"],
- size=item.get("Size"),
- etag=item.get("ETag"),
- last_modified=item.get("LastModified"),
- storage_class=item.get("StorageClass"),
+ item['Key'],
+ size=item.get('Size'),
+ etag=item.get('ETag'),
+ last_modified=item.get('LastModified'),
+ storage_class=item.get('StorageClass'),
service=self,
)
elif objects:
- yield Prefix(bucket, item["Prefix"], service=self)
+ yield Prefix(bucket, item['Prefix'], service=self)
else:
- prefix = item["Key"] if "Key" in item else item["Prefix"]
- yield f"s3://{bucket}/{prefix}"
+ prefix = item['Key'] if 'Key' in item else item['Prefix']
+ yield f's3://{bucket}/{prefix}'
def _put(
self,
source,
dest,
num_callbacks=10,
- policy="bucket-owner-full-control",
+ policy='bucket-owner-full-control',
**kwargs,
):
key = self._get_key(dest)
@@ -251,9 +245,9 @@ def _put(
# support passing in open file obj. Why did we do this in the past?
if not isinstance(source, str):
- obj.upload_fileobj(source, ExtraArgs={"ACL": policy})
+ obj.upload_fileobj(source, ExtraArgs={'ACL': policy})
else:
- obj.upload_file(source, ExtraArgs={"ACL": policy})
+ obj.upload_file(source, ExtraArgs={'ACL': policy})
return key
def _put_string(
@@ -261,14 +255,14 @@ def _put_string(
source,
dest,
num_callbacks=10,
- policy="bucket-owner-full-control",
+ policy='bucket-owner-full-control',
**kwargs,
):
key = self._get_key(dest)
obj = self.s3.Object(key.bucket.name, key.name)
if isinstance(source, str):
- source = source.encode("utf-8")
+ source = source.encode('utf-8')
obj.put(Body=source, ACL=policy)
return key
@@ -278,7 +272,7 @@ def _is_s3(self, name):
return False
name = self._clean_s3(name)
- return "s3://" in name
+ return 's3://' in name
def cat(
self,
@@ -286,7 +280,7 @@ def cat(
buffersize=None,
memsize=2**24,
compressed=False,
- encoding="UTF-8",
+ encoding='UTF-8',
raw=False,
):
"""
@@ -296,19 +290,17 @@ def cat(
skip encoding.
"""
- assert self._is_s3(source) or isinstance(
- source, Key
- ), "source must be a valid s3 path"
+ assert self._is_s3(source) or isinstance(source, Key), 'source must be a valid s3 path'
key = self._get_key(source) if not isinstance(source, Key) else source
- compressed = (compressed or key.name.endswith(".gz")) and not raw
+ compressed = (compressed or key.name.endswith('.gz')) and not raw
if compressed:
decompress = zlib.decompressobj(16 + zlib.MAX_WBITS)
size = 0
bytes_read = 0
err = None
- undecoded = ""
+ undecoded = ''
if key:
# try to read the file multiple times
for i in range(100):
@@ -318,7 +310,7 @@ def cat(
if not size:
size = obj.content_length
elif size != obj.content_length:
- raise AwsError("key size unexpectedly changed while reading")
+ raise AwsError('key size unexpectedly changed while reading')
# For an empty file, 0 (first-bytes-pos) is equal to the length of the object
# hence the range is "unsatisfiable", and botocore correctly handles it by
@@ -326,16 +318,16 @@ def cat(
if size == 0:
break
- r = obj.get(Range=f"bytes={bytes_read}-")
+ r = obj.get(Range=f'bytes={bytes_read}-')
try:
while bytes_read < size:
# this making this weird check because this call is
# about 100 times slower if the amt is too high
if size - bytes_read > buffersize:
- bytes = r["Body"].read(amt=buffersize)
+ bytes = r['Body'].read(amt=buffersize)
else:
- bytes = r["Body"].read()
+ bytes = r['Body'].read()
if compressed:
s = decompress.decompress(bytes)
else:
@@ -344,7 +336,7 @@ def cat(
if encoding and not raw:
try:
decoded = undecoded + s.decode(encoding)
- undecoded = ""
+ undecoded = ''
yield decoded
except UnicodeDecodeError:
undecoded += s
@@ -356,7 +348,7 @@ def cat(
bytes_read += len(bytes)
except zlib.error:
- logger.error("Error while decompressing [%s]", key.name)
+ logger.error('Error while decompressing [%s]', key.name)
raise
except UnicodeDecodeError:
raise
@@ -371,7 +363,7 @@ def cat(
if err:
raise Exception
else:
- raise AwsError("Failed to fully read [%s]" % source.name)
+ raise AwsError('Failed to fully read [%s]' % source.name)
if undecoded:
assert encoding is not None # only time undecoded is set
@@ -392,8 +384,8 @@ def cp_string(self, source, dest, **kwargs):
the s3 location
"""
- assert isinstance(source, str), "source must be a string"
- assert self._is_s3(dest), "Destination must be s3 location"
+ assert isinstance(source, str), 'source must be a string'
+ assert self._is_s3(dest), 'Destination must be s3 location'
return self._put_string(source, dest, **kwargs)
@@ -416,11 +408,9 @@ def list(self, name, iterator=False, **kwargs):
if True return iterator rather than converting to list object
"""
- assert self._is_s3(name), "name must be in form s3://bucket/key"
+ assert self._is_s3(name), 'name must be in form s3://bucket/key'
- it = self._list(
- bucket=self._bucket_name(name), prefix=self._key_name(name), **kwargs
- )
+ it = self._list(bucket=self._bucket_name(name), prefix=self._key_name(name), **kwargs)
return iter(it) if iterator else list(it)
def listdir(self, name, **kwargs):
@@ -442,27 +432,27 @@ def listdir(self, name, **kwargs):
files or prefixes that are encountered
"""
- assert self._is_s3(name), "name must be in form s3://bucket/prefix/"
+ assert self._is_s3(name), 'name must be in form s3://bucket/prefix/'
- if not name.endswith("/"):
- name += "/"
- return self.list(name, delimiter="/", **kwargs)
+ if not name.endswith('/'):
+ name += '/'
+ return self.list(name, delimiter='/', **kwargs)
- def read(self, source, compressed=False, encoding="UTF-8"):
+ def read(self, source, compressed=False, encoding='UTF-8'):
"""
Iterates over a file in s3 split on newline.
Yields a line in file.
"""
- buf = ""
+ buf = ''
for block in self.cat(source, compressed=compressed, encoding=encoding):
buf += block
- if "\n" in buf:
- ret, buf = buf.rsplit("\n", 1)
- yield from ret.split("\n")
+ if '\n' in buf:
+ ret, buf = buf.rsplit('\n', 1)
+ yield from ret.split('\n')
- lines = buf.split("\n")
+ lines = buf.split('\n')
yield from lines[:-1]
# only yield the last line if the line has content in it
diff --git a/papermill/tests/__init__.py b/papermill/tests/__init__.py
index 9843f37e..6ef2067e 100644
--- a/papermill/tests/__init__.py
+++ b/papermill/tests/__init__.py
@@ -1,13 +1,11 @@
import os
-
from io import StringIO
-
-kernel_name = "python3"
+kernel_name = 'python3'
def get_notebook_path(*args):
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), "notebooks", *args)
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'notebooks', *args)
def get_notebook_dir(*args):
diff --git a/papermill/tests/test_abs.py b/papermill/tests/test_abs.py
index 7793f4bd..580828b9 100644
--- a/papermill/tests/test_abs.py
+++ b/papermill/tests/test_abs.py
@@ -1,14 +1,15 @@
import os
import unittest
-
from unittest.mock import Mock, patch
+
from azure.identity import EnvironmentCredential
+
from ..abs import AzureBlobStore
class MockBytesIO:
def __init__(self):
- self.list = [b"hello", b"world!"]
+ self.list = [b'hello', b'world!']
def __getitem__(self, index):
return self.list[index]
@@ -23,106 +24,86 @@ class ABSTest(unittest.TestCase):
"""
def setUp(self):
- self.list_blobs = Mock(return_value=["foo", "bar", "baz"])
+ self.list_blobs = Mock(return_value=['foo', 'bar', 'baz'])
self.upload_blob = Mock()
self.download_blob = Mock()
self._container_client = Mock(list_blobs=self.list_blobs)
- self._blob_client = Mock(
- upload_blob=self.upload_blob, download_blob=self.download_blob
- )
+ self._blob_client = Mock(upload_blob=self.upload_blob, download_blob=self.download_blob)
self._blob_service_client = Mock(
get_blob_client=Mock(return_value=self._blob_client),
get_container_client=Mock(return_value=self._container_client),
)
self.abs = AzureBlobStore()
self.abs._blob_service_client = Mock(return_value=self._blob_service_client)
- os.environ["AZURE_TENANT_ID"] = "mytenantid"
- os.environ["AZURE_CLIENT_ID"] = "myclientid"
- os.environ["AZURE_CLIENT_SECRET"] = "myclientsecret"
+ os.environ['AZURE_TENANT_ID'] = 'mytenantid'
+ os.environ['AZURE_CLIENT_ID'] = 'myclientid'
+ os.environ['AZURE_CLIENT_SECRET'] = 'myclientsecret'
def test_split_url_raises_exception_on_invalid_url(self):
with self.assertRaises(Exception) as context:
- AzureBlobStore._split_url("this_is_not_a_valid_url")
- self.assertTrue(
- "Invalid azure blob url 'this_is_not_a_valid_url'" in str(context.exception)
- )
+ AzureBlobStore._split_url('this_is_not_a_valid_url')
+ self.assertTrue("Invalid azure blob url 'this_is_not_a_valid_url'" in str(context.exception))
def test_split_url_splits_valid_url(self):
- params = AzureBlobStore._split_url(
- "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken"
- )
- self.assertEqual(params["account"], "myaccount")
- self.assertEqual(params["container"], "sascontainer")
- self.assertEqual(params["blob"], "sasblob.txt")
- self.assertEqual(params["sas_token"], "sastoken")
+ params = AzureBlobStore._split_url('abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken')
+ self.assertEqual(params['account'], 'myaccount')
+ self.assertEqual(params['container'], 'sascontainer')
+ self.assertEqual(params['blob'], 'sasblob.txt')
+ self.assertEqual(params['sas_token'], 'sastoken')
def test_split_url_splits_valid_url_no_sas(self):
- params = AzureBlobStore._split_url(
- "abs://myaccount.blob.core.windows.net/container/blob.txt"
- )
- self.assertEqual(params["account"], "myaccount")
- self.assertEqual(params["container"], "container")
- self.assertEqual(params["blob"], "blob.txt")
- self.assertEqual(params["sas_token"], "")
+ params = AzureBlobStore._split_url('abs://myaccount.blob.core.windows.net/container/blob.txt')
+ self.assertEqual(params['account'], 'myaccount')
+ self.assertEqual(params['container'], 'container')
+ self.assertEqual(params['blob'], 'blob.txt')
+ self.assertEqual(params['sas_token'], '')
def test_split_url_splits_valid_url_with_prefix(self):
params = AzureBlobStore._split_url(
- "abs://myaccount.blob.core.windows.net/sascontainer/A/B/sasblob.txt?sastoken"
+ 'abs://myaccount.blob.core.windows.net/sascontainer/A/B/sasblob.txt?sastoken'
)
- self.assertEqual(params["account"], "myaccount")
- self.assertEqual(params["container"], "sascontainer")
- self.assertEqual(params["blob"], "A/B/sasblob.txt")
- self.assertEqual(params["sas_token"], "sastoken")
+ self.assertEqual(params['account'], 'myaccount')
+ self.assertEqual(params['container'], 'sascontainer')
+ self.assertEqual(params['blob'], 'A/B/sasblob.txt')
+ self.assertEqual(params['sas_token'], 'sastoken')
def test_listdir_calls(self):
self.assertEqual(
- self.abs.listdir(
- "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken"
- ),
- ["foo", "bar", "baz"],
+ self.abs.listdir('abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken'),
+ ['foo', 'bar', 'baz'],
)
- self._blob_service_client.get_container_client.assert_called_once_with(
- "sascontainer"
- )
- self.list_blobs.assert_called_once_with("sasblob.txt")
+ self._blob_service_client.get_container_client.assert_called_once_with('sascontainer')
+ self.list_blobs.assert_called_once_with('sasblob.txt')
- @patch("papermill.abs.io.BytesIO", side_effect=MockBytesIO)
+ @patch('papermill.abs.io.BytesIO', side_effect=MockBytesIO)
def test_reads_file(self, mockBytesIO):
self.assertEqual(
- self.abs.read(
- "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken"
- ),
- ["hello", "world!"],
- )
- self._blob_service_client.get_blob_client.assert_called_once_with(
- "sascontainer", "sasblob.txt"
+ self.abs.read('abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken'),
+ ['hello', 'world!'],
)
+ self._blob_service_client.get_blob_client.assert_called_once_with('sascontainer', 'sasblob.txt')
self.download_blob.assert_called_once_with()
def test_write_file(self):
self.abs.write(
- "hello world",
- "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken",
+ 'hello world',
+ 'abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken',
)
- self._blob_service_client.get_blob_client.assert_called_once_with(
- "sascontainer", "sasblob.txt"
- )
- self.upload_blob.assert_called_once_with(data="hello world", overwrite=True)
+ self._blob_service_client.get_blob_client.assert_called_once_with('sascontainer', 'sasblob.txt')
+ self.upload_blob.assert_called_once_with(data='hello world', overwrite=True)
def test_blob_service_client(self):
abs = AzureBlobStore()
- blob = abs._blob_service_client(account_name="myaccount", sas_token="sastoken")
- self.assertEqual(blob.account_name, "myaccount")
+ blob = abs._blob_service_client(account_name='myaccount', sas_token='sastoken')
+ self.assertEqual(blob.account_name, 'myaccount')
# Credentials gets funky with v12.0.0, so I comment this out
# self.assertEqual(blob.credential, "sastoken")
def test_blob_service_client_environment_credentials(self):
abs = AzureBlobStore()
- blob = abs._blob_service_client(account_name="myaccount", sas_token="")
- self.assertEqual(blob.account_name, "myaccount")
+ blob = abs._blob_service_client(account_name='myaccount', sas_token='')
+ self.assertEqual(blob.account_name, 'myaccount')
self.assertIsInstance(blob.credential, EnvironmentCredential)
- self.assertEqual(blob.credential._credential._tenant_id, "mytenantid")
- self.assertEqual(blob.credential._credential._client_id, "myclientid")
- self.assertEqual(
- blob.credential._credential._client_credential, "myclientsecret"
- )
+ self.assertEqual(blob.credential._credential._tenant_id, 'mytenantid')
+ self.assertEqual(blob.credential._credential._client_id, 'myclientid')
+ self.assertEqual(blob.credential._credential._client_credential, 'myclientsecret')
diff --git a/papermill/tests/test_adl.py b/papermill/tests/test_adl.py
index 6db76be3..952c7a19 100644
--- a/papermill/tests/test_adl.py
+++ b/papermill/tests/test_adl.py
@@ -1,8 +1,9 @@
import unittest
+from unittest.mock import MagicMock, Mock, patch
-from unittest.mock import Mock, MagicMock, patch
-
-from ..adl import ADL, core as adl_core, lib as adl_lib
+from ..adl import ADL
+from ..adl import core as adl_core
+from ..adl import lib as adl_lib
class ADLTest(unittest.TestCase):
@@ -13,13 +14,13 @@ class ADLTest(unittest.TestCase):
def setUp(self):
self.ls = Mock(
return_value=[
- "path/to/directory/foo",
- "path/to/directory/bar",
- "path/to/directory/baz",
+ 'path/to/directory/foo',
+ 'path/to/directory/bar',
+ 'path/to/directory/baz',
]
)
self.fakeFile = MagicMock()
- self.fakeFile.__iter__.return_value = [b"a", b"b", b"c"]
+ self.fakeFile.__iter__.return_value = [b'a', b'b', b'c']
self.fakeFile.__enter__.return_value = self.fakeFile
self.open = Mock(return_value=self.fakeFile)
self.fakeAdapter = Mock(open=self.open, ls=self.ls)
@@ -28,49 +29,41 @@ def setUp(self):
def test_split_url_raises_exception_on_invalid_url(self):
with self.assertRaises(Exception) as context:
- ADL._split_url("this_is_not_a_valid_url")
- self.assertTrue(
- "Invalid ADL url 'this_is_not_a_valid_url'" in str(context.exception)
- )
+ ADL._split_url('this_is_not_a_valid_url')
+ self.assertTrue("Invalid ADL url 'this_is_not_a_valid_url'" in str(context.exception))
def test_split_url_splits_valid_url(self):
- (store_name, path) = ADL._split_url("adl://foo.azuredatalakestore.net/bar/baz")
- self.assertEqual(store_name, "foo")
- self.assertEqual(path, "bar/baz")
+ (store_name, path) = ADL._split_url('adl://foo.azuredatalakestore.net/bar/baz')
+ self.assertEqual(store_name, 'foo')
+ self.assertEqual(path, 'bar/baz')
def test_listdir_calls_ls_on_adl_adapter(self):
self.assertEqual(
- self.adl.listdir(
- "adl://foo_store.azuredatalakestore.net/path/to/directory"
- ),
+ self.adl.listdir('adl://foo_store.azuredatalakestore.net/path/to/directory'),
[
- "adl://foo_store.azuredatalakestore.net/path/to/directory/foo",
- "adl://foo_store.azuredatalakestore.net/path/to/directory/bar",
- "adl://foo_store.azuredatalakestore.net/path/to/directory/baz",
+ 'adl://foo_store.azuredatalakestore.net/path/to/directory/foo',
+ 'adl://foo_store.azuredatalakestore.net/path/to/directory/bar',
+ 'adl://foo_store.azuredatalakestore.net/path/to/directory/baz',
],
)
- self.ls.assert_called_once_with("path/to/directory")
+ self.ls.assert_called_once_with('path/to/directory')
def test_read_opens_and_reads_file(self):
self.assertEqual(
- self.adl.read("adl://foo_store.azuredatalakestore.net/path/to/file"),
- ["a", "b", "c"],
+ self.adl.read('adl://foo_store.azuredatalakestore.net/path/to/file'),
+ ['a', 'b', 'c'],
)
self.fakeFile.__iter__.assert_called_once_with()
def test_write_opens_file_and_writes_to_it(self):
- self.adl.write(
- "hello world", "adl://foo_store.azuredatalakestore.net/path/to/file"
- )
- self.fakeFile.write.assert_called_once_with(b"hello world")
+ self.adl.write('hello world', 'adl://foo_store.azuredatalakestore.net/path/to/file')
+ self.fakeFile.write.assert_called_once_with(b'hello world')
- @patch.object(adl_lib, "auth", return_value="my_token")
- @patch.object(adl_core, "AzureDLFileSystem", return_value="my_adapter")
+ @patch.object(adl_lib, 'auth', return_value='my_token')
+ @patch.object(adl_core, 'AzureDLFileSystem', return_value='my_adapter')
def test_create_adapter(self, azure_dl_filesystem_mock, auth_mock):
sut = ADL()
- actual = sut._create_adapter("my_store_name")
- assert actual == "my_adapter"
+ actual = sut._create_adapter('my_store_name')
+ assert actual == 'my_adapter'
auth_mock.assert_called_once_with()
- azure_dl_filesystem_mock.assert_called_once_with(
- "my_token", store_name="my_store_name"
- )
+ azure_dl_filesystem_mock.assert_called_once_with('my_token', store_name='my_store_name')
diff --git a/papermill/tests/test_autosave.py b/papermill/tests/test_autosave.py
index b234c29a..74ae06e8 100644
--- a/papermill/tests/test_autosave.py
+++ b/papermill/tests/test_autosave.py
@@ -1,28 +1,26 @@
-import nbformat
import os
import tempfile
import time
import unittest
from unittest.mock import patch
-from . import get_notebook_path
+import nbformat
from .. import engines
from ..engines import NotebookExecutionManager
from ..execute import execute_notebook
+from . import get_notebook_path
class TestMidCellAutosave(unittest.TestCase):
def setUp(self):
- self.notebook_name = "test_autosave.ipynb"
+ self.notebook_name = 'test_autosave.ipynb'
self.notebook_path = get_notebook_path(self.notebook_name)
self.nb = nbformat.read(self.notebook_path, as_version=4)
def test_autosave_not_too_fast(self):
- nb_man = NotebookExecutionManager(
- self.nb, output_path="test.ipynb", autosave_cell_every=0.5
- )
- with patch.object(engines, "write_ipynb") as write_mock:
+ nb_man = NotebookExecutionManager(self.nb, output_path='test.ipynb', autosave_cell_every=0.5)
+ with patch.object(engines, 'write_ipynb') as write_mock:
write_mock.reset_mock()
assert write_mock.call_count == 0 # check that the mock is sane
nb_man.autosave_cell() # First call to autosave shouldn't trigger save
@@ -34,38 +32,30 @@ def test_autosave_not_too_fast(self):
assert write_mock.call_count == 1
def test_autosave_disable(self):
- nb_man = NotebookExecutionManager(
- self.nb, output_path="test.ipynb", autosave_cell_every=0
- )
- with patch.object(engines, "write_ipynb") as write_mock:
+ nb_man = NotebookExecutionManager(self.nb, output_path='test.ipynb', autosave_cell_every=0)
+ with patch.object(engines, 'write_ipynb') as write_mock:
write_mock.reset_mock()
assert write_mock.call_count == 0 # check that the mock is sane
nb_man.autosave_cell() # First call to autosave shouldn't trigger save
assert write_mock.call_count == 0
nb_man.autosave_cell() # Call again right away. Still shouldn't save.
assert write_mock.call_count == 0
- time.sleep(
- 0.55
- ) # Sleep for long enough that autosave should work, if enabled
+ time.sleep(0.55) # Sleep for long enough that autosave should work, if enabled
nb_man.autosave_cell()
assert write_mock.call_count == 0 # but it's disabled.
def test_end2end_autosave_slow_notebook(self):
test_dir = tempfile.mkdtemp()
- nb_test_executed_fname = os.path.join(test_dir, f"output_{self.notebook_name}")
+ nb_test_executed_fname = os.path.join(test_dir, f'output_{self.notebook_name}')
# Count how many times it writes the file w/o autosave
- with patch.object(engines, "write_ipynb") as write_mock:
- execute_notebook(
- self.notebook_path, nb_test_executed_fname, autosave_cell_every=0
- )
+ with patch.object(engines, 'write_ipynb') as write_mock:
+ execute_notebook(self.notebook_path, nb_test_executed_fname, autosave_cell_every=0)
default_write_count = write_mock.call_count
# Turn on autosave and see how many more times it gets saved.
- with patch.object(engines, "write_ipynb") as write_mock:
- execute_notebook(
- self.notebook_path, nb_test_executed_fname, autosave_cell_every=1
- )
+ with patch.object(engines, 'write_ipynb') as write_mock:
+ execute_notebook(self.notebook_path, nb_test_executed_fname, autosave_cell_every=1)
# This notebook has a cell which takes 2.5 seconds to run.
# Autosave every 1 sec should add two more saves.
assert write_mock.call_count == default_write_count + 2
diff --git a/papermill/tests/test_cli.py b/papermill/tests/test_cli.py
index 7381fd24..ad6ddbed 100755
--- a/papermill/tests/test_cli.py
+++ b/papermill/tests/test_cli.py
@@ -2,35 +2,34 @@
""" Test the command line interface """
import os
-from pathlib import Path
-import sys
import subprocess
+import sys
import tempfile
-import uuid
-import nbclient
-
-import nbformat
import unittest
+import uuid
+from pathlib import Path
from unittest.mock import patch
+import nbclient
+import nbformat
import pytest
from click.testing import CliRunner
-from . import get_notebook_path, kernel_name
from .. import cli
-from ..cli import papermill, _is_int, _is_float, _resolve_type
+from ..cli import _is_float, _is_int, _resolve_type, papermill
+from . import get_notebook_path, kernel_name
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("True", True),
- ("False", False),
- ("None", None),
- ("12.51", 12.51),
- ("10", 10),
- ("hello world", "hello world"),
- ("😍", "😍"),
+ ('True', True),
+ ('False', False),
+ ('None', None),
+ ('12.51', 12.51),
+ ('10', 10),
+ ('hello world', 'hello world'),
+ ('😍', '😍'),
],
)
def test_resolve_type(test_input, expected):
@@ -38,17 +37,17 @@ def test_resolve_type(test_input, expected):
@pytest.mark.parametrize(
- "value,expected",
+ 'value,expected',
[
(13.71, True),
- ("False", False),
- ("None", False),
+ ('False', False),
+ ('None', False),
(-8.2, True),
(10, True),
- ("10", True),
- ("12.31", True),
- ("hello world", False),
- ("😍", False),
+ ('10', True),
+ ('12.31', True),
+ ('hello world', False),
+ ('😍', False),
],
)
def test_is_float(value, expected):
@@ -56,17 +55,17 @@ def test_is_float(value, expected):
@pytest.mark.parametrize(
- "value,expected",
+ 'value,expected',
[
(13.71, True),
- ("False", False),
- ("None", False),
+ ('False', False),
+ ('None', False),
(-8.2, True),
- ("-23.2", False),
+ ('-23.2', False),
(10, True),
- ("13", True),
- ("hello world", False),
- ("😍", False),
+ ('13', True),
+ ('hello world', False),
+ ('😍', False),
],
)
def test_is_int(value, expected):
@@ -75,8 +74,8 @@ def test_is_int(value, expected):
class TestCLI(unittest.TestCase):
default_execute_kwargs = dict(
- input_path="input.ipynb",
- output_path="output.ipynb",
+ input_path='input.ipynb',
+ output_path='output.ipynb',
parameters={},
engine_name=None,
request_save_on_cell_execute=True,
@@ -97,47 +96,39 @@ class TestCLI(unittest.TestCase):
def setUp(self):
self.runner = CliRunner()
self.default_args = [
- self.default_execute_kwargs["input_path"],
- self.default_execute_kwargs["output_path"],
+ self.default_execute_kwargs['input_path'],
+ self.default_execute_kwargs['output_path'],
]
- self.sample_yaml_file = os.path.join(
- os.path.dirname(__file__), "parameters", "example.yaml"
- )
- self.sample_json_file = os.path.join(
- os.path.dirname(__file__), "parameters", "example.json"
- )
+ self.sample_yaml_file = os.path.join(os.path.dirname(__file__), 'parameters', 'example.yaml')
+ self.sample_json_file = os.path.join(os.path.dirname(__file__), 'parameters', 'example.json')
def augment_execute_kwargs(self, **new_kwargs):
kwargs = self.default_execute_kwargs.copy()
kwargs.update(new_kwargs)
return kwargs
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters(self, execute_patch):
self.runner.invoke(
papermill,
- self.default_args + ["-p", "foo", "bar", "--parameters", "baz", "42"],
- )
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(parameters={"foo": "bar", "baz": 42})
+ self.default_args + ['-p', 'foo', 'bar', '--parameters', 'baz', '42'],
)
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'foo': 'bar', 'baz': 42}))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters_raw(self, execute_patch):
self.runner.invoke(
papermill,
- self.default_args + ["-r", "foo", "bar", "--parameters_raw", "baz", "42"],
- )
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(parameters={"foo": "bar", "baz": "42"})
+ self.default_args + ['-r', 'foo', 'bar', '--parameters_raw', 'baz', '42'],
)
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'foo': 'bar', 'baz': '42'}))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters_file(self, execute_patch):
extra_args = [
- "-f",
+ '-f',
self.sample_yaml_file,
- "--parameters_file",
+ '--parameters_file',
self.sample_json_file,
]
self.runner.invoke(papermill, self.default_args + extra_args)
@@ -145,45 +136,40 @@ def test_parameters_file(self, execute_patch):
**self.augment_execute_kwargs(
# Last input wins dict update
parameters={
- "foo": 54321,
- "bar": "value",
- "baz": {"k2": "v2", "k1": "v1"},
- "a_date": "2019-01-01",
+ 'foo': 54321,
+ 'bar': 'value',
+ 'baz': {'k2': 'v2', 'k1': 'v1'},
+ 'a_date': '2019-01-01',
}
)
)
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters_yaml(self, execute_patch):
self.runner.invoke(
papermill,
- self.default_args
- + ["-y", '{"foo": "bar"}', "--parameters_yaml", '{"foo2": ["baz"]}'],
- )
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(parameters={"foo": "bar", "foo2": ["baz"]})
+ self.default_args + ['-y', '{"foo": "bar"}', '--parameters_yaml', '{"foo2": ["baz"]}'],
)
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'foo': 'bar', 'foo2': ['baz']}))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters_yaml_date(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["-y", "a_date: 2019-01-01"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(parameters={"a_date": "2019-01-01"})
- )
+ self.runner.invoke(papermill, self.default_args + ['-y', 'a_date: 2019-01-01'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'a_date': '2019-01-01'}))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters_empty(self, execute_patch):
# "#empty" ---base64--> "I2VtcHR5"
with tempfile.TemporaryDirectory() as tmpdir:
- empty_yaml = Path(tmpdir) / "empty.yaml"
- empty_yaml.write_text("#empty")
+ empty_yaml = Path(tmpdir) / 'empty.yaml'
+ empty_yaml.write_text('#empty')
extra_args = [
- "--parameters_file",
+ '--parameters_file',
str(empty_yaml),
- "--parameters_yaml",
- "#empty",
- "--parameters_base64",
- "I2VtcHR5",
+ '--parameters_yaml',
+ '#empty',
+ '--parameters_base64',
+ 'I2VtcHR5',
]
self.runner.invoke(
papermill,
@@ -196,139 +182,113 @@ def test_parameters_empty(self, execute_patch):
)
)
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters_yaml_override(self, execute_patch):
self.runner.invoke(
papermill,
- self.default_args
- + ["--parameters_yaml", '{"foo": "bar"}', "-y", '{"foo": ["baz"]}'],
+ self.default_args + ['--parameters_yaml', '{"foo": "bar"}', '-y', '{"foo": ["baz"]}'],
)
execute_patch.assert_called_with(
**self.augment_execute_kwargs(
# Last input wins dict update
- parameters={"foo": ["baz"]}
+ parameters={'foo': ['baz']}
)
)
@patch(
- cli.__name__ + ".execute_notebook",
- side_effect=nbclient.exceptions.DeadKernelError("Fake"),
+ cli.__name__ + '.execute_notebook',
+ side_effect=nbclient.exceptions.DeadKernelError('Fake'),
)
def test_parameters_dead_kernel(self, execute_patch):
result = self.runner.invoke(
papermill,
- self.default_args
- + ["--parameters_yaml", '{"foo": "bar"}', "-y", '{"foo": ["baz"]}'],
+ self.default_args + ['--parameters_yaml', '{"foo": "bar"}', '-y', '{"foo": ["baz"]}'],
)
assert result.exit_code == 138
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters_base64(self, execute_patch):
extra_args = [
- "--parameters_base64",
- "eyJmb28iOiAicmVwbGFjZWQiLCAiYmFyIjogMn0=",
- "-b",
- "eydmb28nOiAxfQ==",
+ '--parameters_base64',
+ 'eyJmb28iOiAicmVwbGFjZWQiLCAiYmFyIjogMn0=',
+ '-b',
+ 'eydmb28nOiAxfQ==',
]
self.runner.invoke(papermill, self.default_args + extra_args)
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(parameters={"foo": 1, "bar": 2})
- )
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'foo': 1, 'bar': 2}))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_parameters_base64_date(self, execute_patch):
self.runner.invoke(
papermill,
- self.default_args + ["--parameters_base64", "YV9kYXRlOiAyMDE5LTAxLTAx"],
- )
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(parameters={"a_date": "2019-01-01"})
+ self.default_args + ['--parameters_base64', 'YV9kYXRlOiAyMDE5LTAxLTAx'],
)
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'a_date': '2019-01-01'}))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_inject_input_path(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--inject-input-path"])
+ self.runner.invoke(papermill, self.default_args + ['--inject-input-path'])
execute_patch.assert_called_with(
- **self.augment_execute_kwargs(
- parameters={"PAPERMILL_INPUT_PATH": "input.ipynb"}
- )
+ **self.augment_execute_kwargs(parameters={'PAPERMILL_INPUT_PATH': 'input.ipynb'})
)
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_inject_output_path(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--inject-output-path"])
+ self.runner.invoke(papermill, self.default_args + ['--inject-output-path'])
execute_patch.assert_called_with(
- **self.augment_execute_kwargs(
- parameters={"PAPERMILL_OUTPUT_PATH": "output.ipynb"}
- )
+ **self.augment_execute_kwargs(parameters={'PAPERMILL_OUTPUT_PATH': 'output.ipynb'})
)
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_inject_paths(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--inject-paths"])
+ self.runner.invoke(papermill, self.default_args + ['--inject-paths'])
execute_patch.assert_called_with(
**self.augment_execute_kwargs(
parameters={
- "PAPERMILL_INPUT_PATH": "input.ipynb",
- "PAPERMILL_OUTPUT_PATH": "output.ipynb",
+ 'PAPERMILL_INPUT_PATH': 'input.ipynb',
+ 'PAPERMILL_OUTPUT_PATH': 'output.ipynb',
}
)
)
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_engine(self, execute_patch):
- self.runner.invoke(
- papermill, self.default_args + ["--engine", "engine-that-could"]
- )
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(engine_name="engine-that-could")
- )
+ self.runner.invoke(papermill, self.default_args + ['--engine', 'engine-that-could'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(engine_name='engine-that-could'))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_prepare_only(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--prepare-only"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(prepare_only=True)
- )
+ self.runner.invoke(papermill, self.default_args + ['--prepare-only'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(prepare_only=True))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_kernel(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["-k", "python3"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(kernel_name="python3")
- )
+ self.runner.invoke(papermill, self.default_args + ['-k', 'python3'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(kernel_name='python3'))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_language(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["-l", "python"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(language="python")
- )
+ self.runner.invoke(papermill, self.default_args + ['-l', 'python'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(language='python'))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_set_cwd(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--cwd", "a/path/here"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(cwd="a/path/here")
- )
+ self.runner.invoke(papermill, self.default_args + ['--cwd', 'a/path/here'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(cwd='a/path/here'))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_progress_bar(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--progress-bar"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(progress_bar=True)
- )
+ self.runner.invoke(papermill, self.default_args + ['--progress-bar'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(progress_bar=True))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_no_progress_bar(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--no-progress-bar"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(progress_bar=False)
- )
+ self.runner.invoke(papermill, self.default_args + ['--no-progress-bar'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(progress_bar=False))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_log_output(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--log-output"])
+ self.runner.invoke(papermill, self.default_args + ['--log-output'])
execute_patch.assert_called_with(
**self.augment_execute_kwargs(
log_output=True,
@@ -336,107 +296,89 @@ def test_log_output(self, execute_patch):
)
)
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_log_output_plus_progress(self, execute_patch):
- self.runner.invoke(
- papermill, self.default_args + ["--log-output", "--progress-bar"]
- )
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(log_output=True, progress_bar=True)
- )
+ self.runner.invoke(papermill, self.default_args + ['--log-output', '--progress-bar'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(log_output=True, progress_bar=True))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_no_log_output(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--no-log-output"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(log_output=False)
- )
+ self.runner.invoke(papermill, self.default_args + ['--no-log-output'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(log_output=False))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_log_level(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--log-level", "WARNING"])
+ self.runner.invoke(papermill, self.default_args + ['--log-level', 'WARNING'])
# TODO: this does not actually test log-level being set
execute_patch.assert_called_with(**self.augment_execute_kwargs())
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_start_timeout(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--start-timeout", "123"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(start_timeout=123)
- )
+ self.runner.invoke(papermill, self.default_args + ['--start-timeout', '123'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(start_timeout=123))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_start_timeout_backwards_compatibility(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--start_timeout", "123"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(start_timeout=123)
- )
+ self.runner.invoke(papermill, self.default_args + ['--start_timeout', '123'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(start_timeout=123))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_execution_timeout(self, execute_patch):
- self.runner.invoke(
- papermill, self.default_args + ["--execution-timeout", "123"]
- )
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(execution_timeout=123)
- )
+ self.runner.invoke(papermill, self.default_args + ['--execution-timeout', '123'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(execution_timeout=123))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_report_mode(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--report-mode"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(report_mode=True)
- )
+ self.runner.invoke(papermill, self.default_args + ['--report-mode'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(report_mode=True))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_no_report_mode(self, execute_patch):
- self.runner.invoke(papermill, self.default_args + ["--no-report-mode"])
- execute_patch.assert_called_with(
- **self.augment_execute_kwargs(report_mode=False)
- )
+ self.runner.invoke(papermill, self.default_args + ['--no-report-mode'])
+ execute_patch.assert_called_with(**self.augment_execute_kwargs(report_mode=False))
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_version(self, execute_patch):
- self.runner.invoke(papermill, ["--version"])
+ self.runner.invoke(papermill, ['--version'])
execute_patch.assert_not_called()
- @patch(cli.__name__ + ".execute_notebook")
- @patch(cli.__name__ + ".display_notebook_help")
+ @patch(cli.__name__ + '.execute_notebook')
+ @patch(cli.__name__ + '.display_notebook_help')
def test_help_notebook(self, display_notebook_help, execute_path):
- self.runner.invoke(papermill, ["--help-notebook", "input_path.ipynb"])
+ self.runner.invoke(papermill, ['--help-notebook', 'input_path.ipynb'])
execute_path.assert_not_called()
assert display_notebook_help.call_count == 1
- assert display_notebook_help.call_args[0][1] == "input_path.ipynb"
+ assert display_notebook_help.call_args[0][1] == 'input_path.ipynb'
- @patch(cli.__name__ + ".execute_notebook")
+ @patch(cli.__name__ + '.execute_notebook')
def test_many_args(self, execute_patch):
extra_args = [
- "-f",
+ '-f',
self.sample_yaml_file,
- "-y",
+ '-y',
'{"yaml_foo": {"yaml_bar": "yaml_baz"}}',
- "-b",
- "eyJiYXNlNjRfZm9vIjogImJhc2U2NF9iYXIifQ==",
- "-p",
- "baz",
- "replace",
- "-r",
- "foo",
- "54321",
- "--kernel",
- "R",
- "--engine",
- "engine-that-could",
- "--prepare-only",
- "--log-output",
- "--autosave-cell-every",
- "17",
- "--no-progress-bar",
- "--start-timeout",
- "321",
- "--execution-timeout",
- "654",
- "--report-mode",
+ '-b',
+ 'eyJiYXNlNjRfZm9vIjogImJhc2U2NF9iYXIifQ==',
+ '-p',
+ 'baz',
+ 'replace',
+ '-r',
+ 'foo',
+ '54321',
+ '--kernel',
+ 'R',
+ '--engine',
+ 'engine-that-could',
+ '--prepare-only',
+ '--log-output',
+ '--autosave-cell-every',
+ '17',
+ '--no-progress-bar',
+ '--start-timeout',
+ '321',
+ '--execution-timeout',
+ '654',
+ '--report-mode',
]
self.runner.invoke(
papermill,
@@ -445,18 +387,18 @@ def test_many_args(self, execute_patch):
execute_patch.assert_called_with(
**self.augment_execute_kwargs(
parameters={
- "foo": "54321",
- "bar": "value",
- "baz": "replace",
- "yaml_foo": {"yaml_bar": "yaml_baz"},
- "base64_foo": "base64_bar",
- "a_date": "2019-01-01",
+ 'foo': '54321',
+ 'bar': 'value',
+ 'baz': 'replace',
+ 'yaml_foo': {'yaml_bar': 'yaml_baz'},
+ 'base64_foo': 'base64_bar',
+ 'a_date': '2019-01-01',
},
- engine_name="engine-that-could",
+ engine_name='engine-that-could',
request_save_on_cell_execute=True,
autosave_cell_every=17,
prepare_only=True,
- kernel_name="R",
+ kernel_name='R',
log_output=True,
progress_bar=False,
start_timeout=321,
@@ -468,7 +410,7 @@ def test_many_args(self, execute_patch):
def papermill_cli(papermill_args=None, **kwargs):
- cmd = [sys.executable, "-m", "papermill"]
+ cmd = [sys.executable, '-m', 'papermill']
if papermill_args:
cmd.extend(papermill_args)
return subprocess.Popen(cmd, **kwargs)
@@ -476,11 +418,11 @@ def papermill_cli(papermill_args=None, **kwargs):
def papermill_version():
try:
- proc = papermill_cli(["--version"], stdout=subprocess.PIPE)
+ proc = papermill_cli(['--version'], stdout=subprocess.PIPE)
out, _ = proc.communicate()
if proc.returncode:
return None
- return out.decode("utf-8")
+ return out.decode('utf-8')
except (OSError, SystemExit): # pragma: no cover
return None
@@ -488,54 +430,50 @@ def papermill_version():
@pytest.fixture()
def notebook():
metadata = {
- "kernelspec": {
- "name": "python3",
- "language": "python",
- "display_name": "python3",
+ 'kernelspec': {
+ 'name': 'python3',
+ 'language': 'python',
+ 'display_name': 'python3',
}
}
return nbformat.v4.new_notebook(
metadata=metadata,
- cells=[
- nbformat.v4.new_markdown_cell("This is a notebook with kernel: python3")
- ],
+ cells=[nbformat.v4.new_markdown_cell('This is a notebook with kernel: python3')],
)
-require_papermill_installed = pytest.mark.skipif(
- not papermill_version(), reason="papermill is not installed"
-)
+require_papermill_installed = pytest.mark.skipif(not papermill_version(), reason='papermill is not installed')
@require_papermill_installed
def test_pipe_in_out_auto(notebook):
process = papermill_cli(stdout=subprocess.PIPE, stdin=subprocess.PIPE)
text = nbformat.writes(notebook)
- out, err = process.communicate(input=text.encode("utf-8"))
+ out, err = process.communicate(input=text.encode('utf-8'))
# Test no message on std error
assert not err
# Test that output is a valid notebook
- nbformat.reads(out.decode("utf-8"), as_version=4)
+ nbformat.reads(out.decode('utf-8'), as_version=4)
@require_papermill_installed
def test_pipe_in_out_explicit(notebook):
- process = papermill_cli(["-", "-"], stdout=subprocess.PIPE, stdin=subprocess.PIPE)
+ process = papermill_cli(['-', '-'], stdout=subprocess.PIPE, stdin=subprocess.PIPE)
text = nbformat.writes(notebook)
- out, err = process.communicate(input=text.encode("utf-8"))
+ out, err = process.communicate(input=text.encode('utf-8'))
# Test no message on std error
assert not err
# Test that output is a valid notebook
- nbformat.reads(out.decode("utf-8"), as_version=4)
+ nbformat.reads(out.decode('utf-8'), as_version=4)
@require_papermill_installed
def test_pipe_out_auto(tmpdir, notebook):
- nb_file = tmpdir.join("notebook.ipynb")
+ nb_file = tmpdir.join('notebook.ipynb')
nb_file.write(nbformat.writes(notebook))
process = papermill_cli([str(nb_file)], stdout=subprocess.PIPE)
@@ -545,31 +483,31 @@ def test_pipe_out_auto(tmpdir, notebook):
assert not err
# Test that output is a valid notebook
- nbformat.reads(out.decode("utf-8"), as_version=4)
+ nbformat.reads(out.decode('utf-8'), as_version=4)
@require_papermill_installed
def test_pipe_out_explicit(tmpdir, notebook):
- nb_file = tmpdir.join("notebook.ipynb")
+ nb_file = tmpdir.join('notebook.ipynb')
nb_file.write(nbformat.writes(notebook))
- process = papermill_cli([str(nb_file), "-"], stdout=subprocess.PIPE)
+ process = papermill_cli([str(nb_file), '-'], stdout=subprocess.PIPE)
out, err = process.communicate()
# Test no message on std error
assert not err
# Test that output is a valid notebook
- nbformat.reads(out.decode("utf-8"), as_version=4)
+ nbformat.reads(out.decode('utf-8'), as_version=4)
@require_papermill_installed
def test_pipe_in_auto(tmpdir, notebook):
- nb_file = tmpdir.join("notebook.ipynb")
+ nb_file = tmpdir.join('notebook.ipynb')
process = papermill_cli([str(nb_file)], stdin=subprocess.PIPE)
text = nbformat.writes(notebook)
- out, _ = process.communicate(input=text.encode("utf-8"))
+ out, _ = process.communicate(input=text.encode('utf-8'))
# Nothing on stdout
assert not out
@@ -581,11 +519,11 @@ def test_pipe_in_auto(tmpdir, notebook):
@require_papermill_installed
def test_pipe_in_explicit(tmpdir, notebook):
- nb_file = tmpdir.join("notebook.ipynb")
+ nb_file = tmpdir.join('notebook.ipynb')
- process = papermill_cli(["-", str(nb_file)], stdin=subprocess.PIPE)
+ process = papermill_cli(['-', str(nb_file)], stdin=subprocess.PIPE)
text = nbformat.writes(notebook)
- out, _ = process.communicate(input=text.encode("utf-8"))
+ out, _ = process.communicate(input=text.encode('utf-8'))
# Nothing on stdout
assert not out
@@ -597,20 +535,20 @@ def test_pipe_in_explicit(tmpdir, notebook):
@require_papermill_installed
def test_stdout_file(tmpdir):
- nb_file = tmpdir.join("notebook.ipynb")
- stdout_file = tmpdir.join("notebook.stdout")
+ nb_file = tmpdir.join('notebook.ipynb')
+ stdout_file = tmpdir.join('notebook.stdout')
secret = str(uuid.uuid4())
process = papermill_cli(
[
- get_notebook_path("simple_execute.ipynb"),
+ get_notebook_path('simple_execute.ipynb'),
str(nb_file),
- "-k",
+ '-k',
kernel_name,
- "-p",
- "msg",
+ '-p',
+ 'msg',
secret,
- "--stdout-file",
+ '--stdout-file',
str(stdout_file),
]
)
@@ -620,4 +558,4 @@ def test_stdout_file(tmpdir):
assert not err
with open(str(stdout_file)) as fp:
- assert fp.read() == secret + "\n"
+ assert fp.read() == secret + '\n'
diff --git a/papermill/tests/test_clientwrap.py b/papermill/tests/test_clientwrap.py
index deeb29a1..32309cf6 100644
--- a/papermill/tests/test_clientwrap.py
+++ b/papermill/tests/test_clientwrap.py
@@ -1,40 +1,39 @@
-import nbformat
import unittest
-
from unittest.mock import call, patch
-from . import get_notebook_path
+import nbformat
-from ..log import logger
-from ..engines import NotebookExecutionManager
from ..clientwrap import PapermillNotebookClient
+from ..engines import NotebookExecutionManager
+from ..log import logger
+from . import get_notebook_path
class TestPapermillClientWrapper(unittest.TestCase):
def setUp(self):
- self.nb = nbformat.read(get_notebook_path("test_logging.ipynb"), as_version=4)
+ self.nb = nbformat.read(get_notebook_path('test_logging.ipynb'), as_version=4)
self.nb_man = NotebookExecutionManager(self.nb)
self.client = PapermillNotebookClient(self.nb_man, log=logger, log_output=True)
def test_logging_stderr_msg(self):
- with patch.object(logger, "warning") as warning_mock:
- for output in self.nb.cells[0].get("outputs", []):
+ with patch.object(logger, 'warning') as warning_mock:
+ for output in self.nb.cells[0].get('outputs', []):
self.client.log_output_message(output)
- warning_mock.assert_called_once_with("INFO:test:test text\n")
+ warning_mock.assert_called_once_with('INFO:test:test text\n')
def test_logging_stdout_msg(self):
- with patch.object(logger, "info") as info_mock:
- for output in self.nb.cells[1].get("outputs", []):
+ with patch.object(logger, 'info') as info_mock:
+ for output in self.nb.cells[1].get('outputs', []):
self.client.log_output_message(output)
- info_mock.assert_called_once_with("hello world\n")
+ info_mock.assert_called_once_with('hello world\n')
def test_logging_data_msg(self):
- with patch.object(logger, "info") as info_mock:
- for output in self.nb.cells[2].get("outputs", []):
+ with patch.object(logger, 'info') as info_mock:
+ for output in self.nb.cells[2].get('outputs', []):
self.client.log_output_message(output)
info_mock.assert_has_calls(
[
- call(""),
- call(""),
+ call(''),
+ call(''),
]
)
diff --git a/papermill/tests/test_engines.py b/papermill/tests/test_engines.py
index e635a6f9..b750a01e 100644
--- a/papermill/tests/test_engines.py
+++ b/papermill/tests/test_engines.py
@@ -1,17 +1,16 @@
import copy
-import dateutil
import unittest
-
from abc import ABCMeta
-from unittest.mock import Mock, patch, call
-from nbformat.notebooknode import NotebookNode
+from unittest.mock import Mock, call, patch
-from . import get_notebook_path
+import dateutil
+from nbformat.notebooknode import NotebookNode
from .. import engines, exceptions
-from ..log import logger
+from ..engines import Engine, NBClientEngine, NotebookExecutionManager
from ..iorw import load_notebook_node
-from ..engines import NotebookExecutionManager, Engine, NBClientEngine
+from ..log import logger
+from . import get_notebook_path
def AnyMock(cls):
@@ -30,11 +29,11 @@ def __eq__(self, other):
class TestNotebookExecutionManager(unittest.TestCase):
def setUp(self):
- self.notebook_name = "simple_execute.ipynb"
+ self.notebook_name = 'simple_execute.ipynb'
self.notebook_path = get_notebook_path(self.notebook_name)
self.nb = load_notebook_node(self.notebook_path)
self.foo_nb = copy.deepcopy(self.nb)
- self.foo_nb.metadata["foo"] = "bar"
+ self.foo_nb.metadata['foo'] = 'bar'
def test_basic_pbar(self):
nb_man = NotebookExecutionManager(self.nb)
@@ -51,73 +50,69 @@ def test_set_timer(self):
nb_man = NotebookExecutionManager(self.nb)
now = nb_man.now()
- with patch.object(nb_man, "now", return_value=now):
+ with patch.object(nb_man, 'now', return_value=now):
nb_man.set_timer()
self.assertEqual(nb_man.start_time, now)
self.assertIsNone(nb_man.end_time)
def test_save(self):
- nb_man = NotebookExecutionManager(self.nb, output_path="test.ipynb")
- with patch.object(engines, "write_ipynb") as write_mock:
+ nb_man = NotebookExecutionManager(self.nb, output_path='test.ipynb')
+ with patch.object(engines, 'write_ipynb') as write_mock:
nb_man.save()
- write_mock.assert_called_with(self.nb, "test.ipynb")
+ write_mock.assert_called_with(self.nb, 'test.ipynb')
def test_save_no_output(self):
nb_man = NotebookExecutionManager(self.nb)
- with patch.object(engines, "write_ipynb") as write_mock:
+ with patch.object(engines, 'write_ipynb') as write_mock:
nb_man.save()
write_mock.assert_not_called()
def test_save_new_nb(self):
nb_man = NotebookExecutionManager(self.nb)
nb_man.save(nb=self.foo_nb)
- self.assertEqual(nb_man.nb.metadata["foo"], "bar")
+ self.assertEqual(nb_man.nb.metadata['foo'], 'bar')
def test_get_cell_description(self):
nb_man = NotebookExecutionManager(self.nb)
self.assertIsNone(nb_man.get_cell_description(nb_man.nb.cells[0]))
- self.assertEqual(nb_man.get_cell_description(nb_man.nb.cells[1]), "DESC")
+ self.assertEqual(nb_man.get_cell_description(nb_man.nb.cells[1]), 'DESC')
def test_notebook_start(self):
nb_man = NotebookExecutionManager(self.nb)
- nb_man.nb.metadata["foo"] = "bar"
+ nb_man.nb.metadata['foo'] = 'bar'
nb_man.save = Mock()
nb_man.notebook_start()
- self.assertEqual(
- nb_man.nb.metadata.papermill["start_time"], nb_man.start_time.isoformat()
- )
- self.assertIsNone(nb_man.nb.metadata.papermill["end_time"])
- self.assertIsNone(nb_man.nb.metadata.papermill["duration"])
- self.assertIsNone(nb_man.nb.metadata.papermill["exception"])
+ self.assertEqual(nb_man.nb.metadata.papermill['start_time'], nb_man.start_time.isoformat())
+ self.assertIsNone(nb_man.nb.metadata.papermill['end_time'])
+ self.assertIsNone(nb_man.nb.metadata.papermill['duration'])
+ self.assertIsNone(nb_man.nb.metadata.papermill['exception'])
for cell in nb_man.nb.cells:
- self.assertIsNone(cell.metadata.papermill["start_time"])
- self.assertIsNone(cell.metadata.papermill["end_time"])
- self.assertIsNone(cell.metadata.papermill["duration"])
- self.assertIsNone(cell.metadata.papermill["exception"])
- self.assertEqual(
- cell.metadata.papermill["status"], NotebookExecutionManager.PENDING
- )
- self.assertIsNone(cell.get("execution_count"))
- if cell.cell_type == "code":
- self.assertEqual(cell.get("outputs"), [])
+ self.assertIsNone(cell.metadata.papermill['start_time'])
+ self.assertIsNone(cell.metadata.papermill['end_time'])
+ self.assertIsNone(cell.metadata.papermill['duration'])
+ self.assertIsNone(cell.metadata.papermill['exception'])
+ self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.PENDING)
+ self.assertIsNone(cell.get('execution_count'))
+ if cell.cell_type == 'code':
+ self.assertEqual(cell.get('outputs'), [])
else:
- self.assertIsNone(cell.get("outputs"))
+ self.assertIsNone(cell.get('outputs'))
nb_man.save.assert_called_once()
def test_notebook_start_new_nb(self):
nb_man = NotebookExecutionManager(self.nb)
nb_man.notebook_start(nb=self.foo_nb)
- self.assertEqual(nb_man.nb.metadata["foo"], "bar")
+ self.assertEqual(nb_man.nb.metadata['foo'], 'bar')
def test_notebook_start_markdown_code(self):
nb_man = NotebookExecutionManager(self.nb)
nb_man.notebook_start(nb=self.foo_nb)
- self.assertNotIn("execution_count", nb_man.nb.cells[-1])
- self.assertNotIn("outputs", nb_man.nb.cells[-1])
+ self.assertNotIn('execution_count', nb_man.nb.cells[-1])
+ self.assertNotIn('outputs', nb_man.nb.cells[-1])
def test_cell_start(self):
nb_man = NotebookExecutionManager(self.nb)
@@ -129,18 +124,16 @@ def test_cell_start(self):
nb_man.save = Mock()
nb_man.cell_start(cell)
- self.assertEqual(cell.metadata.papermill["start_time"], fixed_now.isoformat())
- self.assertFalse(cell.metadata.papermill["exception"])
- self.assertEqual(
- cell.metadata.papermill["status"], NotebookExecutionManager.RUNNING
- )
+ self.assertEqual(cell.metadata.papermill['start_time'], fixed_now.isoformat())
+ self.assertFalse(cell.metadata.papermill['exception'])
+ self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.RUNNING)
nb_man.save.assert_called_once()
def test_cell_start_new_nb(self):
nb_man = NotebookExecutionManager(self.nb)
nb_man.cell_start(self.foo_nb.cells[0], nb=self.foo_nb)
- self.assertEqual(nb_man.nb.metadata["foo"], "bar")
+ self.assertEqual(nb_man.nb.metadata['foo'], 'bar')
def test_cell_exception(self):
nb_man = NotebookExecutionManager(self.nb)
@@ -148,16 +141,14 @@ def test_cell_exception(self):
cell = nb_man.nb.cells[0]
nb_man.cell_exception(cell)
- self.assertEqual(nb_man.nb.metadata.papermill["exception"], True)
- self.assertEqual(cell.metadata.papermill["exception"], True)
- self.assertEqual(
- cell.metadata.papermill["status"], NotebookExecutionManager.FAILED
- )
+ self.assertEqual(nb_man.nb.metadata.papermill['exception'], True)
+ self.assertEqual(cell.metadata.papermill['exception'], True)
+ self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.FAILED)
def test_cell_exception_new_nb(self):
nb_man = NotebookExecutionManager(self.nb)
nb_man.cell_exception(self.foo_nb.cells[0], nb=self.foo_nb)
- self.assertEqual(nb_man.nb.metadata["foo"], "bar")
+ self.assertEqual(nb_man.nb.metadata['foo'], 'bar')
def test_cell_complete_after_cell_start(self):
nb_man = NotebookExecutionManager(self.nb)
@@ -173,18 +164,16 @@ def test_cell_complete_after_cell_start(self):
nb_man.pbar = Mock()
nb_man.cell_complete(cell)
- self.assertIsNotNone(cell.metadata.papermill["start_time"])
- start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"])
+ self.assertIsNotNone(cell.metadata.papermill['start_time'])
+ start_time = dateutil.parser.parse(cell.metadata.papermill['start_time'])
- self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat())
+ self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat())
self.assertEqual(
- cell.metadata.papermill["duration"],
+ cell.metadata.papermill['duration'],
(fixed_now - start_time).total_seconds(),
)
- self.assertFalse(cell.metadata.papermill["exception"])
- self.assertEqual(
- cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED
- )
+ self.assertFalse(cell.metadata.papermill['exception'])
+ self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED)
nb_man.save.assert_called_once()
nb_man.pbar.update.assert_called_once()
@@ -202,12 +191,10 @@ def test_cell_complete_without_cell_start(self):
nb_man.pbar = Mock()
nb_man.cell_complete(cell)
- self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat())
- self.assertIsNone(cell.metadata.papermill["duration"])
- self.assertFalse(cell.metadata.papermill["exception"])
- self.assertEqual(
- cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED
- )
+ self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat())
+ self.assertIsNone(cell.metadata.papermill['duration'])
+ self.assertFalse(cell.metadata.papermill['exception'])
+ self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED)
nb_man.save.assert_called_once()
nb_man.pbar.update.assert_called_once()
@@ -227,18 +214,16 @@ def test_cell_complete_after_cell_exception(self):
nb_man.pbar = Mock()
nb_man.cell_complete(cell)
- self.assertIsNotNone(cell.metadata.papermill["start_time"])
- start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"])
+ self.assertIsNotNone(cell.metadata.papermill['start_time'])
+ start_time = dateutil.parser.parse(cell.metadata.papermill['start_time'])
- self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat())
+ self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat())
self.assertEqual(
- cell.metadata.papermill["duration"],
+ cell.metadata.papermill['duration'],
(fixed_now - start_time).total_seconds(),
)
- self.assertTrue(cell.metadata.papermill["exception"])
- self.assertEqual(
- cell.metadata.papermill["status"], NotebookExecutionManager.FAILED
- )
+ self.assertTrue(cell.metadata.papermill['exception'])
+ self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.FAILED)
nb_man.save.assert_called_once()
nb_man.pbar.update.assert_called_once()
@@ -247,9 +232,9 @@ def test_cell_complete_new_nb(self):
nb_man = NotebookExecutionManager(self.nb)
nb_man.notebook_start()
baz_nb = copy.deepcopy(nb_man.nb)
- baz_nb.metadata["baz"] = "buz"
+ baz_nb.metadata['baz'] = 'buz'
nb_man.cell_complete(baz_nb.cells[0], nb=baz_nb)
- self.assertEqual(nb_man.nb.metadata["baz"], "buz")
+ self.assertEqual(nb_man.nb.metadata['baz'], 'buz')
def test_notebook_complete(self):
nb_man = NotebookExecutionManager(self.nb)
@@ -264,17 +249,15 @@ def test_notebook_complete(self):
nb_man.notebook_complete()
- self.assertIsNotNone(nb_man.nb.metadata.papermill["start_time"])
- start_time = dateutil.parser.parse(nb_man.nb.metadata.papermill["start_time"])
+ self.assertIsNotNone(nb_man.nb.metadata.papermill['start_time'])
+ start_time = dateutil.parser.parse(nb_man.nb.metadata.papermill['start_time'])
+ self.assertEqual(nb_man.nb.metadata.papermill['end_time'], fixed_now.isoformat())
self.assertEqual(
- nb_man.nb.metadata.papermill["end_time"], fixed_now.isoformat()
- )
- self.assertEqual(
- nb_man.nb.metadata.papermill["duration"],
+ nb_man.nb.metadata.papermill['duration'],
(fixed_now - start_time).total_seconds(),
)
- self.assertFalse(nb_man.nb.metadata.papermill["exception"])
+ self.assertFalse(nb_man.nb.metadata.papermill['exception'])
nb_man.save.assert_called_once()
nb_man.cleanup_pbar.assert_called_once()
@@ -283,18 +266,16 @@ def test_notebook_complete_new_nb(self):
nb_man = NotebookExecutionManager(self.nb)
nb_man.notebook_start()
baz_nb = copy.deepcopy(nb_man.nb)
- baz_nb.metadata["baz"] = "buz"
+ baz_nb.metadata['baz'] = 'buz'
nb_man.notebook_complete(nb=baz_nb)
- self.assertEqual(nb_man.nb.metadata["baz"], "buz")
+ self.assertEqual(nb_man.nb.metadata['baz'], 'buz')
def test_notebook_complete_cell_status_completed(self):
nb_man = NotebookExecutionManager(self.nb)
nb_man.notebook_start()
nb_man.notebook_complete()
for cell in nb_man.nb.cells:
- self.assertEqual(
- cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED
- )
+ self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED)
def test_notebook_complete_cell_status_with_failed(self):
nb_man = NotebookExecutionManager(self.nb)
@@ -302,22 +283,20 @@ def test_notebook_complete_cell_status_with_failed(self):
nb_man.cell_exception(nb_man.nb.cells[1])
nb_man.notebook_complete()
self.assertEqual(
- nb_man.nb.cells[0].metadata.papermill["status"],
+ nb_man.nb.cells[0].metadata.papermill['status'],
NotebookExecutionManager.COMPLETED,
)
self.assertEqual(
- nb_man.nb.cells[1].metadata.papermill["status"],
+ nb_man.nb.cells[1].metadata.papermill['status'],
NotebookExecutionManager.FAILED,
)
for cell in nb_man.nb.cells[2:]:
- self.assertEqual(
- cell.metadata.papermill["status"], NotebookExecutionManager.PENDING
- )
+ self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.PENDING)
class TestEngineBase(unittest.TestCase):
def setUp(self):
- self.notebook_name = "simple_execute.ipynb"
+ self.notebook_name = 'simple_execute.ipynb'
self.notebook_path = get_notebook_path(self.notebook_name)
self.nb = load_notebook_node(self.notebook_path)
@@ -326,28 +305,26 @@ def test_wrap_and_execute_notebook(self):
Mocks each wrapped call and proves the correct inputs get applied to
the correct underlying calls for execute_notebook.
"""
- with patch.object(Engine, "execute_managed_notebook") as exec_mock:
- with patch.object(engines, "NotebookExecutionManager") as wrap_mock:
+ with patch.object(Engine, 'execute_managed_notebook') as exec_mock:
+ with patch.object(engines, 'NotebookExecutionManager') as wrap_mock:
Engine.execute_notebook(
self.nb,
- "python",
- output_path="foo.ipynb",
+ 'python',
+ output_path='foo.ipynb',
progress_bar=False,
log_output=True,
- bar="baz",
+ bar='baz',
)
wrap_mock.assert_called_once_with(
self.nb,
- output_path="foo.ipynb",
+ output_path='foo.ipynb',
progress_bar=False,
log_output=True,
autosave_cell_every=30,
)
wrap_mock.return_value.notebook_start.assert_called_once()
- exec_mock.assert_called_once_with(
- wrap_mock.return_value, "python", log_output=True, bar="baz"
- )
+ exec_mock.assert_called_once_with(wrap_mock.return_value, 'python', log_output=True, bar='baz')
wrap_mock.return_value.notebook_complete.assert_called_once()
wrap_mock.return_value.cleanup_pbar.assert_called_once()
@@ -359,28 +336,26 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs):
nb_man.cell_start(cell)
nb_man.cell_complete(cell)
- with patch.object(NotebookExecutionManager, "save") as save_mock:
- nb = CellCallbackEngine.execute_notebook(
- copy.deepcopy(self.nb), "python", output_path="foo.ipynb"
- )
+ with patch.object(NotebookExecutionManager, 'save') as save_mock:
+ nb = CellCallbackEngine.execute_notebook(copy.deepcopy(self.nb), 'python', output_path='foo.ipynb')
self.assertEqual(nb, AnyMock(NotebookNode))
self.assertNotEqual(self.nb, nb)
self.assertEqual(save_mock.call_count, 8)
- self.assertIsNotNone(nb.metadata.papermill["start_time"])
- self.assertIsNotNone(nb.metadata.papermill["end_time"])
- self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float))
- self.assertFalse(nb.metadata.papermill["exception"])
+ self.assertIsNotNone(nb.metadata.papermill['start_time'])
+ self.assertIsNotNone(nb.metadata.papermill['end_time'])
+ self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float))
+ self.assertFalse(nb.metadata.papermill['exception'])
for cell in nb.cells:
- self.assertIsNotNone(cell.metadata.papermill["start_time"])
- self.assertIsNotNone(cell.metadata.papermill["end_time"])
- self.assertEqual(cell.metadata.papermill["duration"], AnyMock(float))
- self.assertFalse(cell.metadata.papermill["exception"])
+ self.assertIsNotNone(cell.metadata.papermill['start_time'])
+ self.assertIsNotNone(cell.metadata.papermill['end_time'])
+ self.assertEqual(cell.metadata.papermill['duration'], AnyMock(float))
+ self.assertFalse(cell.metadata.papermill['exception'])
self.assertEqual(
- cell.metadata.papermill["status"],
+ cell.metadata.papermill['status'],
NotebookExecutionManager.COMPLETED,
)
@@ -390,13 +365,9 @@ class NoCellCallbackEngine(Engine):
def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs):
pass
- with patch.object(NotebookExecutionManager, "save") as save_mock:
- with patch.object(
- NotebookExecutionManager, "complete_pbar"
- ) as pbar_comp_mock:
- nb = NoCellCallbackEngine.execute_notebook(
- copy.deepcopy(self.nb), "python", output_path="foo.ipynb"
- )
+ with patch.object(NotebookExecutionManager, 'save') as save_mock:
+ with patch.object(NotebookExecutionManager, 'complete_pbar') as pbar_comp_mock:
+ nb = NoCellCallbackEngine.execute_notebook(copy.deepcopy(self.nb), 'python', output_path='foo.ipynb')
self.assertEqual(nb, AnyMock(NotebookNode))
self.assertNotEqual(self.nb, nb)
@@ -404,38 +375,38 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs):
self.assertEqual(save_mock.call_count, 2)
pbar_comp_mock.assert_called_once()
- self.assertIsNotNone(nb.metadata.papermill["start_time"])
- self.assertIsNotNone(nb.metadata.papermill["end_time"])
- self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float))
- self.assertFalse(nb.metadata.papermill["exception"])
+ self.assertIsNotNone(nb.metadata.papermill['start_time'])
+ self.assertIsNotNone(nb.metadata.papermill['end_time'])
+ self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float))
+ self.assertFalse(nb.metadata.papermill['exception'])
for cell in nb.cells:
- self.assertIsNone(cell.metadata.papermill["start_time"])
- self.assertIsNone(cell.metadata.papermill["end_time"])
- self.assertIsNone(cell.metadata.papermill["duration"])
- self.assertIsNone(cell.metadata.papermill["exception"])
+ self.assertIsNone(cell.metadata.papermill['start_time'])
+ self.assertIsNone(cell.metadata.papermill['end_time'])
+ self.assertIsNone(cell.metadata.papermill['duration'])
+ self.assertIsNone(cell.metadata.papermill['exception'])
self.assertEqual(
- cell.metadata.papermill["status"],
+ cell.metadata.papermill['status'],
NotebookExecutionManager.COMPLETED,
)
class TestNBClientEngine(unittest.TestCase):
def setUp(self):
- self.notebook_name = "simple_execute.ipynb"
+ self.notebook_name = 'simple_execute.ipynb'
self.notebook_path = get_notebook_path(self.notebook_name)
self.nb = load_notebook_node(self.notebook_path)
def test_nb_convert_engine(self):
- with patch.object(engines, "PapermillNotebookClient") as client_mock:
- with patch.object(NotebookExecutionManager, "save") as save_mock:
+ with patch.object(engines, 'PapermillNotebookClient') as client_mock:
+ with patch.object(NotebookExecutionManager, 'save') as save_mock:
nb = NBClientEngine.execute_notebook(
copy.deepcopy(self.nb),
- "python",
- output_path="foo.ipynb",
+ 'python',
+ output_path='foo.ipynb',
progress_bar=False,
log_output=True,
- bar="baz",
+ bar='baz',
start_timeout=30,
execution_timeout=1000,
)
@@ -447,16 +418,14 @@ def test_nb_convert_engine(self):
args, kwargs = client_mock.call_args
expected = [
- ("timeout", 1000),
- ("startup_timeout", 30),
- ("kernel_name", "python"),
- ("log", logger),
- ("log_output", True),
+ ('timeout', 1000),
+ ('startup_timeout', 30),
+ ('kernel_name', 'python'),
+ ('log', logger),
+ ('log_output', True),
]
actual = {(key, kwargs[key]) for key in kwargs}
- msg = (
- f"Expected arguments {expected} are not a subset of actual {actual}"
- )
+ msg = f'Expected arguments {expected} are not a subset of actual {actual}'
self.assertTrue(set(expected).issubset(actual), msg=msg)
client_mock.return_value.execute.assert_called_once_with()
@@ -464,71 +433,63 @@ def test_nb_convert_engine(self):
self.assertEqual(save_mock.call_count, 2)
def test_nb_convert_engine_execute(self):
- with patch.object(NotebookExecutionManager, "save") as save_mock:
+ with patch.object(NotebookExecutionManager, 'save') as save_mock:
nb = NBClientEngine.execute_notebook(
self.nb,
- "python",
- output_path="foo.ipynb",
+ 'python',
+ output_path='foo.ipynb',
progress_bar=False,
log_output=True,
)
self.assertEqual(save_mock.call_count, 8)
self.assertEqual(nb, AnyMock(NotebookNode))
- self.assertIsNotNone(nb.metadata.papermill["start_time"])
- self.assertIsNotNone(nb.metadata.papermill["end_time"])
- self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float))
- self.assertFalse(nb.metadata.papermill["exception"])
+ self.assertIsNotNone(nb.metadata.papermill['start_time'])
+ self.assertIsNotNone(nb.metadata.papermill['end_time'])
+ self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float))
+ self.assertFalse(nb.metadata.papermill['exception'])
for cell in nb.cells:
- self.assertIsNotNone(cell.metadata.papermill["start_time"])
- self.assertIsNotNone(cell.metadata.papermill["end_time"])
- self.assertEqual(cell.metadata.papermill["duration"], AnyMock(float))
- self.assertFalse(cell.metadata.papermill["exception"])
+ self.assertIsNotNone(cell.metadata.papermill['start_time'])
+ self.assertIsNotNone(cell.metadata.papermill['end_time'])
+ self.assertEqual(cell.metadata.papermill['duration'], AnyMock(float))
+ self.assertFalse(cell.metadata.papermill['exception'])
self.assertEqual(
- cell.metadata.papermill["status"],
+ cell.metadata.papermill['status'],
NotebookExecutionManager.COMPLETED,
)
def test_nb_convert_log_outputs(self):
- with patch.object(logger, "info") as info_mock:
- with patch.object(logger, "warning") as warning_mock:
- with patch.object(NotebookExecutionManager, "save"):
+ with patch.object(logger, 'info') as info_mock:
+ with patch.object(logger, 'warning') as warning_mock:
+ with patch.object(NotebookExecutionManager, 'save'):
NBClientEngine.execute_notebook(
self.nb,
- "python",
- output_path="foo.ipynb",
+ 'python',
+ output_path='foo.ipynb',
progress_bar=False,
log_output=True,
)
info_mock.assert_has_calls(
[
- call("Executing notebook with kernel: python"),
- call(
- "Executing Cell 1---------------------------------------"
- ),
- call(
- "Ending Cell 1------------------------------------------"
- ),
- call(
- "Executing Cell 2---------------------------------------"
- ),
- call("None\n"),
- call(
- "Ending Cell 2------------------------------------------"
- ),
+ call('Executing notebook with kernel: python'),
+ call('Executing Cell 1---------------------------------------'),
+ call('Ending Cell 1------------------------------------------'),
+ call('Executing Cell 2---------------------------------------'),
+ call('None\n'),
+ call('Ending Cell 2------------------------------------------'),
]
)
warning_mock.is_not_called()
def test_nb_convert_no_log_outputs(self):
- with patch.object(logger, "info") as info_mock:
- with patch.object(logger, "warning") as warning_mock:
- with patch.object(NotebookExecutionManager, "save"):
+ with patch.object(logger, 'info') as info_mock:
+ with patch.object(logger, 'warning') as warning_mock:
+ with patch.object(NotebookExecutionManager, 'save'):
NBClientEngine.execute_notebook(
self.nb,
- "python",
- output_path="foo.ipynb",
+ 'python',
+ output_path='foo.ipynb',
progress_bar=False,
log_output=False,
)
@@ -542,33 +503,31 @@ def setUp(self):
def test_registration(self):
mock_engine = Mock()
- self.papermill_engines.register("mock_engine", mock_engine)
- self.assertIn("mock_engine", self.papermill_engines._engines)
- self.assertIs(mock_engine, self.papermill_engines._engines["mock_engine"])
+ self.papermill_engines.register('mock_engine', mock_engine)
+ self.assertIn('mock_engine', self.papermill_engines._engines)
+ self.assertIs(mock_engine, self.papermill_engines._engines['mock_engine'])
def test_getting(self):
mock_engine = Mock()
- self.papermill_engines.register("mock_engine", mock_engine)
+ self.papermill_engines.register('mock_engine', mock_engine)
# test retrieving an engine works
- retrieved_engine = self.papermill_engines.get_engine("mock_engine")
+ retrieved_engine = self.papermill_engines.get_engine('mock_engine')
self.assertIs(mock_engine, retrieved_engine)
# test you can't retrieve a non-registered engine
self.assertRaises(
exceptions.PapermillException,
self.papermill_engines.get_engine,
- "non-existent",
+ 'non-existent',
)
def test_registering_entry_points(self):
fake_entrypoint = Mock(load=Mock())
- fake_entrypoint.name = "fake-engine"
+ fake_entrypoint.name = 'fake-engine'
- with patch(
- "entrypoints.get_group_all", return_value=[fake_entrypoint]
- ) as mock_get_group_all:
+ with patch('entrypoints.get_group_all', return_value=[fake_entrypoint]) as mock_get_group_all:
self.papermill_engines.register_entry_points()
- mock_get_group_all.assert_called_once_with("papermill.engine")
+ mock_get_group_all.assert_called_once_with('papermill.engine')
self.assertEqual(
- self.papermill_engines.get_engine("fake-engine"),
+ self.papermill_engines.get_engine('fake-engine'),
fake_entrypoint.load.return_value,
)
diff --git a/papermill/tests/test_exceptions.py b/papermill/tests/test_exceptions.py
index 9c555942..191767fb 100644
--- a/papermill/tests/test_exceptions.py
+++ b/papermill/tests/test_exceptions.py
@@ -12,29 +12,29 @@ def temp_file():
"""NamedTemporaryFile must be set in wb mode, closed without delete, opened with open(file, "rb"),
then manually deleted. Otherwise, file fails to be read due to permission error on Windows.
"""
- with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode='wb', delete=False) as f:
yield f
os.unlink(f.name)
@pytest.mark.parametrize(
- "exc,args",
+ 'exc,args',
[
(
exceptions.PapermillExecutionError,
- (1, 2, "TestSource", "Exception", Exception(), ["Traceback", "Message"]),
+ (1, 2, 'TestSource', 'Exception', Exception(), ['Traceback', 'Message']),
),
(
exceptions.PapermillMissingParameterException,
- ("PapermillMissingParameterException",),
+ ('PapermillMissingParameterException',),
),
- (exceptions.AwsError, ("AwsError",)),
- (exceptions.FileExistsError, ("FileExistsError",)),
- (exceptions.PapermillException, ("PapermillException",)),
- (exceptions.PapermillRateLimitException, ("PapermillRateLimitException",)),
+ (exceptions.AwsError, ('AwsError',)),
+ (exceptions.FileExistsError, ('FileExistsError',)),
+ (exceptions.PapermillException, ('PapermillException',)),
+ (exceptions.PapermillRateLimitException, ('PapermillRateLimitException',)),
(
exceptions.PapermillOptionalDependencyException,
- ("PapermillOptionalDependencyException",),
+ ('PapermillOptionalDependencyException',),
),
],
)
@@ -45,7 +45,7 @@ def test_exceptions_are_unpickleable(temp_file, exc, args):
temp_file.close() # close to re-open for reading
# Read the Pickled File
- with open(temp_file.name, "rb") as read_file:
+ with open(temp_file.name, 'rb') as read_file:
read_file.seek(0)
data = read_file.read()
pickled_err = pickle.loads(data)
diff --git a/papermill/tests/test_execute.py b/papermill/tests/test_execute.py
index 350d9b0f..6396de35 100644
--- a/papermill/tests/test_execute.py
+++ b/papermill/tests/test_execute.py
@@ -3,20 +3,19 @@
import tempfile
import unittest
from copy import deepcopy
-from unittest.mock import patch, ANY
-
from functools import partial
from pathlib import Path
+from unittest.mock import ANY, patch
import nbformat
from nbformat import validate
from .. import engines, translators
-from ..log import logger
+from ..exceptions import PapermillExecutionError
+from ..execute import execute_notebook
from ..iorw import load_notebook_node
+from ..log import logger
from ..utils import chdir
-from ..execute import execute_notebook
-from ..exceptions import PapermillExecutionError
from . import get_notebook_path, kernel_name
execute_notebook = partial(execute_notebook, kernel_name=kernel_name)
@@ -25,132 +24,112 @@
class TestNotebookHelpers(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
- self.notebook_name = "simple_execute.ipynb"
+ self.notebook_name = 'simple_execute.ipynb'
self.notebook_path = get_notebook_path(self.notebook_name)
- self.nb_test_executed_fname = os.path.join(
- self.test_dir, f"output_{self.notebook_name}"
- )
+ self.nb_test_executed_fname = os.path.join(self.test_dir, f'output_{self.notebook_name}')
def tearDown(self):
shutil.rmtree(self.test_dir)
- @patch(engines.__name__ + ".PapermillNotebookClient")
+ @patch(engines.__name__ + '.PapermillNotebookClient')
def test_start_timeout(self, preproc_mock):
- execute_notebook(
- self.notebook_path, self.nb_test_executed_fname, start_timeout=123
- )
+ execute_notebook(self.notebook_path, self.nb_test_executed_fname, start_timeout=123)
args, kwargs = preproc_mock.call_args
expected = [
- ("timeout", None),
- ("startup_timeout", 123),
- ("kernel_name", kernel_name),
- ("log", logger),
+ ('timeout', None),
+ ('startup_timeout', 123),
+ ('kernel_name', kernel_name),
+ ('log', logger),
]
actual = {(key, kwargs[key]) for key in kwargs}
self.assertTrue(
set(expected).issubset(actual),
- msg=f"Expected arguments {expected} are not a subset of actual {actual}",
+ msg=f'Expected arguments {expected} are not a subset of actual {actual}',
)
- @patch(engines.__name__ + ".PapermillNotebookClient")
+ @patch(engines.__name__ + '.PapermillNotebookClient')
def test_default_start_timeout(self, preproc_mock):
execute_notebook(self.notebook_path, self.nb_test_executed_fname)
args, kwargs = preproc_mock.call_args
expected = [
- ("timeout", None),
- ("startup_timeout", 60),
- ("kernel_name", kernel_name),
- ("log", logger),
+ ('timeout', None),
+ ('startup_timeout', 60),
+ ('kernel_name', kernel_name),
+ ('log', logger),
]
actual = {(key, kwargs[key]) for key in kwargs}
self.assertTrue(
set(expected).issubset(actual),
- msg=f"Expected arguments {expected} are not a subset of actual {actual}",
+ msg=f'Expected arguments {expected} are not a subset of actual {actual}',
)
def test_cell_insertion(self):
- execute_notebook(
- self.notebook_path, self.nb_test_executed_fname, {"msg": "Hello"}
- )
+ execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'msg': 'Hello'})
test_nb = load_notebook_node(self.nb_test_executed_fname)
self.assertListEqual(
- test_nb.cells[1].get("source").split("\n"),
- ["# Parameters", 'msg = "Hello"', ""],
+ test_nb.cells[1].get('source').split('\n'),
+ ['# Parameters', 'msg = "Hello"', ''],
)
- self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"})
+ self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'})
def test_no_tags(self):
- notebook_name = "no_parameters.ipynb"
- nb_test_executed_fname = os.path.join(self.test_dir, f"output_{notebook_name}")
- execute_notebook(
- get_notebook_path(notebook_name), nb_test_executed_fname, {"msg": "Hello"}
- )
+ notebook_name = 'no_parameters.ipynb'
+ nb_test_executed_fname = os.path.join(self.test_dir, f'output_{notebook_name}')
+ execute_notebook(get_notebook_path(notebook_name), nb_test_executed_fname, {'msg': 'Hello'})
test_nb = load_notebook_node(nb_test_executed_fname)
self.assertListEqual(
- test_nb.cells[0].get("source").split("\n"),
- ["# Parameters", 'msg = "Hello"', ""],
+ test_nb.cells[0].get('source').split('\n'),
+ ['# Parameters', 'msg = "Hello"', ''],
)
- self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"})
+ self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'})
def test_quoted_params(self):
- execute_notebook(
- self.notebook_path, self.nb_test_executed_fname, {"msg": '"Hello"'}
- )
+ execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'msg': '"Hello"'})
test_nb = load_notebook_node(self.nb_test_executed_fname)
self.assertListEqual(
- test_nb.cells[1].get("source").split("\n"),
- ["# Parameters", r'msg = "\"Hello\""', ""],
+ test_nb.cells[1].get('source').split('\n'),
+ ['# Parameters', r'msg = "\"Hello\""', ''],
)
- self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": '"Hello"'})
+ self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': '"Hello"'})
def test_backslash_params(self):
- execute_notebook(
- self.notebook_path, self.nb_test_executed_fname, {"foo": r"do\ not\ crash"}
- )
+ execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'foo': r'do\ not\ crash'})
test_nb = load_notebook_node(self.nb_test_executed_fname)
self.assertListEqual(
- test_nb.cells[1].get("source").split("\n"),
- ["# Parameters", r'foo = "do\\ not\\ crash"', ""],
- )
- self.assertEqual(
- test_nb.metadata.papermill.parameters, {"foo": r"do\ not\ crash"}
+ test_nb.cells[1].get('source').split('\n'),
+ ['# Parameters', r'foo = "do\\ not\\ crash"', ''],
)
+ self.assertEqual(test_nb.metadata.papermill.parameters, {'foo': r'do\ not\ crash'})
def test_backslash_quote_params(self):
- execute_notebook(
- self.notebook_path, self.nb_test_executed_fname, {"foo": r"bar=\"baz\""}
- )
+ execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'foo': r'bar=\"baz\"'})
test_nb = load_notebook_node(self.nb_test_executed_fname)
self.assertListEqual(
- test_nb.cells[1].get("source").split("\n"),
- ["# Parameters", r'foo = "bar=\\\"baz\\\""', ""],
+ test_nb.cells[1].get('source').split('\n'),
+ ['# Parameters', r'foo = "bar=\\\"baz\\\""', ''],
)
- self.assertEqual(test_nb.metadata.papermill.parameters, {"foo": r"bar=\"baz\""})
+ self.assertEqual(test_nb.metadata.papermill.parameters, {'foo': r'bar=\"baz\"'})
def test_double_backslash_quote_params(self):
- execute_notebook(
- self.notebook_path, self.nb_test_executed_fname, {"foo": r'\\"bar\\"'}
- )
+ execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'foo': r'\\"bar\\"'})
test_nb = load_notebook_node(self.nb_test_executed_fname)
self.assertListEqual(
- test_nb.cells[1].get("source").split("\n"),
- ["# Parameters", r'foo = "\\\\\"bar\\\\\""', ""],
+ test_nb.cells[1].get('source').split('\n'),
+ ['# Parameters', r'foo = "\\\\\"bar\\\\\""', ''],
)
- self.assertEqual(test_nb.metadata.papermill.parameters, {"foo": r'\\"bar\\"'})
+ self.assertEqual(test_nb.metadata.papermill.parameters, {'foo': r'\\"bar\\"'})
def test_prepare_only(self):
- for example in ["broken1.ipynb", "keyboard_interrupt.ipynb"]:
+ for example in ['broken1.ipynb', 'keyboard_interrupt.ipynb']:
path = get_notebook_path(example)
result_path = os.path.join(self.test_dir, example)
# Should not raise as we don't execute the notebook at all
- execute_notebook(
- path, result_path, {"foo": r"do\ not\ crash"}, prepare_only=True
- )
+ execute_notebook(path, result_path, {'foo': r'do\ not\ crash'}, prepare_only=True)
nb = load_notebook_node(result_path)
- self.assertEqual(nb.cells[0].cell_type, "code")
+ self.assertEqual(nb.cells[0].cell_type, 'code')
self.assertEqual(
- nb.cells[0].get("source").split("\n"),
- ["# Parameters", r'foo = "do\\ not\\ crash"', ""],
+ nb.cells[0].get('source').split('\n'),
+ ['# Parameters', r'foo = "do\\ not\\ crash"', ''],
)
@@ -162,52 +141,43 @@ def tearDown(self):
shutil.rmtree(self.test_dir)
def test(self):
- path = get_notebook_path("broken1.ipynb")
+ path = get_notebook_path('broken1.ipynb')
# check that the notebook has two existing marker cells, so that this test is sure to be
# validating the removal logic (the markers are simulatin an error in the first code cell
# that has since been fixed)
original_nb = load_notebook_node(path)
- self.assertEqual(
- original_nb.cells[0].metadata["tags"], ["papermill-error-cell-tag"]
- )
- self.assertIn("In [1]", original_nb.cells[0].source)
- self.assertEqual(
- original_nb.cells[2].metadata["tags"], ["papermill-error-cell-tag"]
- )
+ self.assertEqual(original_nb.cells[0].metadata['tags'], ['papermill-error-cell-tag'])
+ self.assertIn('In [1]', original_nb.cells[0].source)
+ self.assertEqual(original_nb.cells[2].metadata['tags'], ['papermill-error-cell-tag'])
- result_path = os.path.join(self.test_dir, "broken1.ipynb")
+ result_path = os.path.join(self.test_dir, 'broken1.ipynb')
with self.assertRaises(PapermillExecutionError):
execute_notebook(path, result_path)
nb = load_notebook_node(result_path)
- self.assertEqual(nb.cells[0].cell_type, "markdown")
+ self.assertEqual(nb.cells[0].cell_type, 'markdown')
self.assertRegex(
nb.cells[0].source,
r'^$',
)
- self.assertEqual(nb.cells[0].metadata["tags"], ["papermill-error-cell-tag"])
+ self.assertEqual(nb.cells[0].metadata['tags'], ['papermill-error-cell-tag'])
- self.assertEqual(nb.cells[1].cell_type, "markdown")
+ self.assertEqual(nb.cells[1].cell_type, 'markdown')
self.assertEqual(nb.cells[2].execution_count, 1)
- self.assertEqual(nb.cells[3].cell_type, "markdown")
- self.assertEqual(nb.cells[4].cell_type, "markdown")
+ self.assertEqual(nb.cells[3].cell_type, 'markdown')
+ self.assertEqual(nb.cells[4].cell_type, 'markdown')
- self.assertEqual(nb.cells[5].cell_type, "markdown")
- self.assertRegex(
- nb.cells[5].source, ''
- )
- self.assertEqual(nb.cells[5].metadata["tags"], ["papermill-error-cell-tag"])
+ self.assertEqual(nb.cells[5].cell_type, 'markdown')
+ self.assertRegex(nb.cells[5].source, '')
+ self.assertEqual(nb.cells[5].metadata['tags'], ['papermill-error-cell-tag'])
self.assertEqual(nb.cells[6].execution_count, 2)
- self.assertEqual(nb.cells[6].outputs[0].output_type, "error")
+ self.assertEqual(nb.cells[6].outputs[0].output_type, 'error')
self.assertEqual(nb.cells[7].execution_count, None)
# double check the removal (the new cells above should be the only two tagged ones)
self.assertEqual(
- sum(
- "papermill-error-cell-tag" in cell.metadata.get("tags", [])
- for cell in nb.cells
- ),
+ sum('papermill-error-cell-tag' in cell.metadata.get('tags', []) for cell in nb.cells),
2,
)
@@ -220,25 +190,23 @@ def tearDown(self):
shutil.rmtree(self.test_dir)
def test(self):
- path = get_notebook_path("broken2.ipynb")
- result_path = os.path.join(self.test_dir, "broken2.ipynb")
+ path = get_notebook_path('broken2.ipynb')
+ result_path = os.path.join(self.test_dir, 'broken2.ipynb')
with self.assertRaises(PapermillExecutionError):
execute_notebook(path, result_path)
nb = load_notebook_node(result_path)
- self.assertEqual(nb.cells[0].cell_type, "markdown")
+ self.assertEqual(nb.cells[0].cell_type, 'markdown')
self.assertRegex(
nb.cells[0].source,
r'^.*In \[2\].*$',
)
self.assertEqual(nb.cells[1].execution_count, 1)
- self.assertEqual(nb.cells[2].cell_type, "markdown")
- self.assertRegex(
- nb.cells[2].source, ''
- )
+ self.assertEqual(nb.cells[2].cell_type, 'markdown')
+ self.assertRegex(nb.cells[2].source, '')
self.assertEqual(nb.cells[3].execution_count, 2)
- self.assertEqual(nb.cells[3].outputs[0].output_type, "display_data")
- self.assertEqual(nb.cells[3].outputs[1].output_type, "error")
+ self.assertEqual(nb.cells[3].outputs[0].output_type, 'display_data')
+ self.assertEqual(nb.cells[3].outputs[1].output_type, 'error')
self.assertEqual(nb.cells[4].execution_count, None)
@@ -246,33 +214,25 @@ def test(self):
class TestReportMode(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
- self.notebook_name = "report_mode_test.ipynb"
+ self.notebook_name = 'report_mode_test.ipynb'
self.notebook_path = get_notebook_path(self.notebook_name)
- self.nb_test_executed_fname = os.path.join(
- self.test_dir, f"output_{self.notebook_name}"
- )
+ self.nb_test_executed_fname = os.path.join(self.test_dir, f'output_{self.notebook_name}')
def tearDown(self):
shutil.rmtree(self.test_dir)
def test_report_mode(self):
- nb = execute_notebook(
- self.notebook_path, self.nb_test_executed_fname, {"a": 0}, report_mode=True
- )
+ nb = execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'a': 0}, report_mode=True)
for cell in nb.cells:
- if cell.cell_type == "code":
- self.assertEqual(
- cell.metadata.get("jupyter", {}).get("source_hidden"), True
- )
+ if cell.cell_type == 'code':
+ self.assertEqual(cell.metadata.get('jupyter', {}).get('source_hidden'), True)
class TestOutputPathNone(unittest.TestCase):
def test_output_path_of_none(self):
"""Output path of None should return notebook node obj but not write an ipynb"""
- nb = execute_notebook(
- get_notebook_path("simple_execute.ipynb"), None, {"msg": "Hello"}
- )
- self.assertEqual(nb.metadata.papermill.parameters, {"msg": "Hello"})
+ nb = execute_notebook(get_notebook_path('simple_execute.ipynb'), None, {'msg': 'Hello'})
+ self.assertEqual(nb.metadata.papermill.parameters, {'msg': 'Hello'})
class TestCWD(unittest.TestCase):
@@ -280,26 +240,20 @@ def setUp(self):
self.test_dir = tempfile.mkdtemp()
self.base_test_dir = tempfile.mkdtemp()
- self.check_notebook_name = "read_check.ipynb"
- self.check_notebook_path = os.path.join(self.base_test_dir, "read_check.ipynb")
+ self.check_notebook_name = 'read_check.ipynb'
+ self.check_notebook_path = os.path.join(self.base_test_dir, 'read_check.ipynb')
# Setup read paths so base_test_dir has check_notebook_name
- shutil.copyfile(
- get_notebook_path(self.check_notebook_name), self.check_notebook_path
- )
- with open(os.path.join(self.test_dir, "check.txt"), "w", encoding="utf-8") as f:
+ shutil.copyfile(get_notebook_path(self.check_notebook_name), self.check_notebook_path)
+ with open(os.path.join(self.test_dir, 'check.txt'), 'w', encoding='utf-8') as f:
# Needed for read_check to pass
- f.write("exists")
+ f.write('exists')
- self.simple_notebook_name = "simple_execute.ipynb"
- self.simple_notebook_path = os.path.join(
- self.base_test_dir, "simple_execute.ipynb"
- )
+ self.simple_notebook_name = 'simple_execute.ipynb'
+ self.simple_notebook_path = os.path.join(self.base_test_dir, 'simple_execute.ipynb')
# Setup read paths so base_test_dir has simple_notebook_name
- shutil.copyfile(
- get_notebook_path(self.simple_notebook_name), self.simple_notebook_path
- )
+ shutil.copyfile(get_notebook_path(self.simple_notebook_name), self.simple_notebook_path)
- self.nb_test_executed_fname = "test_output.ipynb"
+ self.nb_test_executed_fname = 'test_output.ipynb'
def tearDown(self):
shutil.rmtree(self.test_dir)
@@ -313,23 +267,13 @@ def test_local_save_ignores_cwd_assignment(self):
self.nb_test_executed_fname,
cwd=self.test_dir,
)
- self.assertTrue(
- os.path.isfile(
- os.path.join(self.base_test_dir, self.nb_test_executed_fname)
- )
- )
+ self.assertTrue(os.path.isfile(os.path.join(self.base_test_dir, self.nb_test_executed_fname)))
def test_execution_respects_cwd_assignment(self):
with chdir(self.base_test_dir):
# Both paths are relative
- execute_notebook(
- self.check_notebook_name, self.nb_test_executed_fname, cwd=self.test_dir
- )
- self.assertTrue(
- os.path.isfile(
- os.path.join(self.base_test_dir, self.nb_test_executed_fname)
- )
- )
+ execute_notebook(self.check_notebook_name, self.nb_test_executed_fname, cwd=self.test_dir)
+ self.assertTrue(os.path.isfile(os.path.join(self.base_test_dir, self.nb_test_executed_fname)))
def test_pathlib_paths(self):
# Copy of test_execution_respects_cwd_assignment but with `Path`s
@@ -339,9 +283,7 @@ def test_pathlib_paths(self):
Path(self.nb_test_executed_fname),
cwd=Path(self.test_dir),
)
- self.assertTrue(
- Path(self.base_test_dir).joinpath(self.nb_test_executed_fname).exists()
- )
+ self.assertTrue(Path(self.base_test_dir).joinpath(self.nb_test_executed_fname).exists())
class TestSysExit(unittest.TestCase):
@@ -352,64 +294,62 @@ def tearDown(self):
shutil.rmtree(self.test_dir)
def test_sys_exit(self):
- notebook_name = "sysexit.ipynb"
- result_path = os.path.join(self.test_dir, f"output_{notebook_name}")
+ notebook_name = 'sysexit.ipynb'
+ result_path = os.path.join(self.test_dir, f'output_{notebook_name}')
execute_notebook(get_notebook_path(notebook_name), result_path)
nb = load_notebook_node(result_path)
- self.assertEqual(nb.cells[0].cell_type, "code")
+ self.assertEqual(nb.cells[0].cell_type, 'code')
self.assertEqual(nb.cells[0].execution_count, 1)
self.assertEqual(nb.cells[1].execution_count, 2)
- self.assertEqual(nb.cells[1].outputs[0].output_type, "error")
- self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit")
- self.assertEqual(nb.cells[1].outputs[0].evalue, "")
+ self.assertEqual(nb.cells[1].outputs[0].output_type, 'error')
+ self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit')
+ self.assertEqual(nb.cells[1].outputs[0].evalue, '')
self.assertEqual(nb.cells[2].execution_count, None)
def test_sys_exit0(self):
- notebook_name = "sysexit0.ipynb"
- result_path = os.path.join(self.test_dir, f"output_{notebook_name}")
+ notebook_name = 'sysexit0.ipynb'
+ result_path = os.path.join(self.test_dir, f'output_{notebook_name}')
execute_notebook(get_notebook_path(notebook_name), result_path)
nb = load_notebook_node(result_path)
- self.assertEqual(nb.cells[0].cell_type, "code")
+ self.assertEqual(nb.cells[0].cell_type, 'code')
self.assertEqual(nb.cells[0].execution_count, 1)
self.assertEqual(nb.cells[1].execution_count, 2)
- self.assertEqual(nb.cells[1].outputs[0].output_type, "error")
- self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit")
- self.assertEqual(nb.cells[1].outputs[0].evalue, "0")
+ self.assertEqual(nb.cells[1].outputs[0].output_type, 'error')
+ self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit')
+ self.assertEqual(nb.cells[1].outputs[0].evalue, '0')
self.assertEqual(nb.cells[2].execution_count, None)
def test_sys_exit1(self):
- notebook_name = "sysexit1.ipynb"
- result_path = os.path.join(self.test_dir, f"output_{notebook_name}")
+ notebook_name = 'sysexit1.ipynb'
+ result_path = os.path.join(self.test_dir, f'output_{notebook_name}')
with self.assertRaises(PapermillExecutionError):
execute_notebook(get_notebook_path(notebook_name), result_path)
nb = load_notebook_node(result_path)
- self.assertEqual(nb.cells[0].cell_type, "markdown")
+ self.assertEqual(nb.cells[0].cell_type, 'markdown')
self.assertRegex(
nb.cells[0].source,
r'^$',
)
self.assertEqual(nb.cells[1].execution_count, 1)
- self.assertEqual(nb.cells[2].cell_type, "markdown")
- self.assertRegex(
- nb.cells[2].source, ''
- )
+ self.assertEqual(nb.cells[2].cell_type, 'markdown')
+ self.assertRegex(nb.cells[2].source, '')
self.assertEqual(nb.cells[3].execution_count, 2)
- self.assertEqual(nb.cells[3].outputs[0].output_type, "error")
+ self.assertEqual(nb.cells[3].outputs[0].output_type, 'error')
self.assertEqual(nb.cells[4].execution_count, None)
def test_system_exit(self):
- notebook_name = "systemexit.ipynb"
- result_path = os.path.join(self.test_dir, f"output_{notebook_name}")
+ notebook_name = 'systemexit.ipynb'
+ result_path = os.path.join(self.test_dir, f'output_{notebook_name}')
execute_notebook(get_notebook_path(notebook_name), result_path)
nb = load_notebook_node(result_path)
- self.assertEqual(nb.cells[0].cell_type, "code")
+ self.assertEqual(nb.cells[0].cell_type, 'code')
self.assertEqual(nb.cells[0].execution_count, 1)
self.assertEqual(nb.cells[1].execution_count, 2)
- self.assertEqual(nb.cells[1].outputs[0].output_type, "error")
- self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit")
- self.assertEqual(nb.cells[1].outputs[0].evalue, "")
+ self.assertEqual(nb.cells[1].outputs[0].output_type, 'error')
+ self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit')
+ self.assertEqual(nb.cells[1].outputs[0].evalue, '')
self.assertEqual(nb.cells[2].execution_count, None)
@@ -421,11 +361,9 @@ def tearDown(self):
shutil.rmtree(self.test_dir)
def test_from_version_4_4_upgrades(self):
- notebook_name = "nb_version_4.4.ipynb"
- result_path = os.path.join(self.test_dir, f"output_{notebook_name}")
- execute_notebook(
- get_notebook_path(notebook_name), result_path, {"var": "It works"}
- )
+ notebook_name = 'nb_version_4.4.ipynb'
+ result_path = os.path.join(self.test_dir, f'output_{notebook_name}')
+ execute_notebook(get_notebook_path(notebook_name), result_path, {'var': 'It works'})
nb = load_notebook_node(result_path)
validate(nb)
@@ -438,11 +376,9 @@ def tearDown(self):
shutil.rmtree(self.test_dir)
def test_no_v3_language_backport(self):
- notebook_name = "blank-vscode.ipynb"
- result_path = os.path.join(self.test_dir, f"output_{notebook_name}")
- execute_notebook(
- get_notebook_path(notebook_name), result_path, {"var": "It works"}
- )
+ notebook_name = 'blank-vscode.ipynb'
+ result_path = os.path.join(self.test_dir, f'output_{notebook_name}')
+ execute_notebook(get_notebook_path(notebook_name), result_path, {'var': 'It works'})
nb = load_notebook_node(result_path)
validate(nb)
@@ -455,25 +391,21 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs):
@classmethod
def nb_kernel_name(cls, nb, name=None):
- return "my_custom_kernel"
+ return 'my_custom_kernel'
@classmethod
def nb_language(cls, nb, language=None):
- return "my_custom_language"
+ return 'my_custom_language'
def setUp(self):
self.test_dir = tempfile.mkdtemp()
- self.notebook_path = get_notebook_path("simple_execute.ipynb")
- self.nb_test_executed_fname = os.path.join(
- self.test_dir, "output_{}".format("simple_execute.ipynb")
- )
+ self.notebook_path = get_notebook_path('simple_execute.ipynb')
+ self.nb_test_executed_fname = os.path.join(self.test_dir, 'output_{}'.format('simple_execute.ipynb'))
self._orig_papermill_engines = deepcopy(engines.papermill_engines)
self._orig_translators = deepcopy(translators.papermill_translators)
- engines.papermill_engines.register("custom_engine", self.CustomEngine)
- translators.papermill_translators.register(
- "my_custom_language", translators.PythonTranslator()
- )
+ engines.papermill_engines.register('custom_engine', self.CustomEngine)
+ translators.papermill_translators.register('my_custom_language', translators.PythonTranslator())
def tearDown(self):
shutil.rmtree(self.test_dir)
@@ -482,46 +414,40 @@ def tearDown(self):
@patch.object(
CustomEngine,
- "execute_managed_notebook",
+ 'execute_managed_notebook',
wraps=CustomEngine.execute_managed_notebook,
)
@patch(
- "papermill.parameterize.translate_parameters",
+ 'papermill.parameterize.translate_parameters',
wraps=translators.translate_parameters,
)
- def test_custom_kernel_name_and_language(
- self, translate_parameters, execute_managed_notebook
- ):
+ def test_custom_kernel_name_and_language(self, translate_parameters, execute_managed_notebook):
"""Tests execute against engine with custom implementations to fetch
kernel name and language from the notebook object
"""
execute_notebook(
self.notebook_path,
self.nb_test_executed_fname,
- engine_name="custom_engine",
- parameters={"msg": "fake msg"},
- )
- self.assertEqual(
- execute_managed_notebook.call_args[0], (ANY, "my_custom_kernel")
+ engine_name='custom_engine',
+ parameters={'msg': 'fake msg'},
)
+ self.assertEqual(execute_managed_notebook.call_args[0], (ANY, 'my_custom_kernel'))
self.assertEqual(
translate_parameters.call_args[0],
- (ANY, "my_custom_language", {"msg": "fake msg"}, ANY),
+ (ANY, 'my_custom_language', {'msg': 'fake msg'}, ANY),
)
class TestNotebookNodeInput(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.TemporaryDirectory()
- self.result_path = os.path.join(self.test_dir.name, "output.ipynb")
+ self.result_path = os.path.join(self.test_dir.name, 'output.ipynb')
def tearDown(self):
self.test_dir.cleanup()
def test_notebook_node_input(self):
- input_nb = nbformat.read(
- get_notebook_path("simple_execute.ipynb"), as_version=4
- )
- execute_notebook(input_nb, self.result_path, {"msg": "Hello"})
+ input_nb = nbformat.read(get_notebook_path('simple_execute.ipynb'), as_version=4)
+ execute_notebook(input_nb, self.result_path, {'msg': 'Hello'})
test_nb = nbformat.read(self.result_path, as_version=4)
- self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"})
+ self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'})
diff --git a/papermill/tests/test_gcs.py b/papermill/tests/test_gcs.py
index 280deb8f..61de47b5 100644
--- a/papermill/tests/test_gcs.py
+++ b/papermill/tests/test_gcs.py
@@ -69,124 +69,100 @@ class GCSTest(unittest.TestCase):
def setUp(self):
self.gcs_handler = GCSHandler()
- @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper())
+ @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper())
def test_gcs_read(self, mock_gcs_filesystem):
client = self.gcs_handler._get_client()
- self.assertEqual(self.gcs_handler.read("gs://bucket/test.ipynb"), 1)
+ self.assertEqual(self.gcs_handler.read('gs://bucket/test.ipynb'), 1)
# Check that client is only generated once
self.assertIs(client, self.gcs_handler._get_client())
- @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper())
+ @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper())
def test_gcs_write(self, mock_gcs_filesystem):
client = self.gcs_handler._get_client()
- self.assertEqual(
- self.gcs_handler.write("new value", "gs://bucket/test.ipynb"), 1
- )
+ self.assertEqual(self.gcs_handler.write('new value', 'gs://bucket/test.ipynb'), 1)
# Check that client is only generated once
self.assertIs(client, self.gcs_handler._get_client())
- @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper())
+ @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper())
def test_gcs_listdir(self, mock_gcs_filesystem):
client = self.gcs_handler._get_client()
- self.gcs_handler.listdir("testdir")
+ self.gcs_handler.listdir('testdir')
# Check that client is only generated once
self.assertIs(client, self.gcs_handler._get_client())
@patch(
- "papermill.iorw.GCSFileSystem",
- side_effect=mock_gcs_fs_wrapper(
- GCSRateLimitException({"message": "test", "code": 429}), 10
- ),
+ 'papermill.iorw.GCSFileSystem',
+ side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': 429}), 10),
)
def test_gcs_handle_exception(self, mock_gcs_filesystem):
- with patch.object(GCSHandler, "RETRY_DELAY", 0):
- with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0):
- with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0):
+ with patch.object(GCSHandler, 'RETRY_DELAY', 0):
+ with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0):
+ with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0):
with self.assertRaises(PapermillRateLimitException):
- self.gcs_handler.write(
- "raise_limit_exception", "gs://bucket/test.ipynb"
- )
+ self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb')
@patch(
- "papermill.iorw.GCSFileSystem",
- side_effect=mock_gcs_fs_wrapper(
- GCSRateLimitException({"message": "test", "code": 429}), 1
- ),
+ 'papermill.iorw.GCSFileSystem',
+ side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': 429}), 1),
)
def test_gcs_retry(self, mock_gcs_filesystem):
- with patch.object(GCSHandler, "RETRY_DELAY", 0):
- with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0):
- with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0):
+ with patch.object(GCSHandler, 'RETRY_DELAY', 0):
+ with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0):
+ with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0):
self.assertEqual(
- self.gcs_handler.write(
- "raise_limit_exception", "gs://bucket/test.ipynb"
- ),
+ self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'),
2,
)
@patch(
- "papermill.iorw.GCSFileSystem",
- side_effect=mock_gcs_fs_wrapper(
- GCSHttpError({"message": "test", "code": 429}), 1
- ),
+ 'papermill.iorw.GCSFileSystem',
+ side_effect=mock_gcs_fs_wrapper(GCSHttpError({'message': 'test', 'code': 429}), 1),
)
def test_gcs_retry_older_exception(self, mock_gcs_filesystem):
- with patch.object(GCSHandler, "RETRY_DELAY", 0):
- with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0):
- with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0):
+ with patch.object(GCSHandler, 'RETRY_DELAY', 0):
+ with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0):
+ with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0):
self.assertEqual(
- self.gcs_handler.write(
- "raise_limit_exception", "gs://bucket/test.ipynb"
- ),
+ self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'),
2,
)
- @patch("papermill.iorw.gs_is_retriable", side_effect=fallback_gs_is_retriable)
+ @patch('papermill.iorw.gs_is_retriable', side_effect=fallback_gs_is_retriable)
@patch(
- "papermill.iorw.GCSFileSystem",
- side_effect=mock_gcs_fs_wrapper(
- GCSRateLimitException({"message": "test", "code": None}), 1
- ),
+ 'papermill.iorw.GCSFileSystem',
+ side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': None}), 1),
)
- def test_gcs_fallback_retry_unknown_failure_code(
- self, mock_gcs_filesystem, mock_gcs_retriable
- ):
- with patch.object(GCSHandler, "RETRY_DELAY", 0):
- with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0):
- with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0):
+ def test_gcs_fallback_retry_unknown_failure_code(self, mock_gcs_filesystem, mock_gcs_retriable):
+ with patch.object(GCSHandler, 'RETRY_DELAY', 0):
+ with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0):
+ with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0):
self.assertEqual(
- self.gcs_handler.write(
- "raise_limit_exception", "gs://bucket/test.ipynb"
- ),
+ self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'),
2,
)
- @patch("papermill.iorw.gs_is_retriable", return_value=False)
+ @patch('papermill.iorw.gs_is_retriable', return_value=False)
@patch(
- "papermill.iorw.GCSFileSystem",
- side_effect=mock_gcs_fs_wrapper(
- GCSRateLimitException({"message": "test", "code": 500}), 1
- ),
+ 'papermill.iorw.GCSFileSystem',
+ side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': 500}), 1),
)
def test_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable):
with self.assertRaises(GCSRateLimitException):
- self.gcs_handler.write("fatal_exception", "gs://bucket/test.ipynb")
+ self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb')
- @patch("papermill.iorw.gs_is_retriable", side_effect=fallback_gs_is_retriable)
+ @patch('papermill.iorw.gs_is_retriable', side_effect=fallback_gs_is_retriable)
@patch(
- "papermill.iorw.GCSFileSystem",
- side_effect=mock_gcs_fs_wrapper(
- GCSRateLimitException({"message": "test", "code": 500}), 1
- ),
+ 'papermill.iorw.GCSFileSystem',
+ side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': 500}), 1),
)
def test_fallback_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable):
with self.assertRaises(GCSRateLimitException):
- self.gcs_handler.write("fatal_exception", "gs://bucket/test.ipynb")
+ self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb')
@patch(
- "papermill.iorw.GCSFileSystem",
- side_effect=mock_gcs_fs_wrapper(ValueError("not-a-retry"), 1),
+ 'papermill.iorw.GCSFileSystem',
+ side_effect=mock_gcs_fs_wrapper(ValueError('not-a-retry'), 1),
)
def test_gcs_unretryable(self, mock_gcs_filesystem):
with self.assertRaises(ValueError):
- self.gcs_handler.write("no_a_rate_limit", "gs://bucket/test.ipynb")
+ self.gcs_handler.write('no_a_rate_limit', 'gs://bucket/test.ipynb')
diff --git a/papermill/tests/test_hdfs.py b/papermill/tests/test_hdfs.py
index 0577e1f5..e8c49dd2 100644
--- a/papermill/tests/test_hdfs.py
+++ b/papermill/tests/test_hdfs.py
@@ -8,7 +8,7 @@
class MockHadoopFileSystem(MagicMock):
def get_file_info(self, path):
- return [MockFileInfo("test1.ipynb"), MockFileInfo("test2.ipynb")]
+ return [MockFileInfo('test1.ipynb'), MockFileInfo('test2.ipynb')]
def open_input_stream(self, path):
return MockHadoopFile()
@@ -19,7 +19,7 @@ def open_output_stream(self, path):
class MockHadoopFile:
def __init__(self):
- self._content = b"Content of notebook"
+ self._content = b'Content of notebook'
def __enter__(self, *args):
return self
@@ -40,8 +40,8 @@ def __init__(self, path):
self.path = path
-@pytest.mark.skip(reason="No valid dep package for python 3.12 yet")
-@patch("papermill.iorw.HadoopFileSystem", side_effect=MockHadoopFileSystem())
+@pytest.mark.skip(reason='No valid dep package for python 3.12 yet')
+@patch('papermill.iorw.HadoopFileSystem', side_effect=MockHadoopFileSystem())
class HDFSTest(unittest.TestCase):
def setUp(self):
self.hdfs_handler = HDFSHandler()
@@ -49,8 +49,8 @@ def setUp(self):
def test_hdfs_listdir(self, mock_hdfs_filesystem):
client = self.hdfs_handler._get_client()
self.assertEqual(
- self.hdfs_handler.listdir("hdfs:///Projects/"),
- ["test1.ipynb", "test2.ipynb"],
+ self.hdfs_handler.listdir('hdfs:///Projects/'),
+ ['test1.ipynb', 'test2.ipynb'],
)
# Check if client is the same after calling
self.assertIs(client, self.hdfs_handler._get_client())
@@ -58,14 +58,12 @@ def test_hdfs_listdir(self, mock_hdfs_filesystem):
def test_hdfs_read(self, mock_hdfs_filesystem):
client = self.hdfs_handler._get_client()
self.assertEqual(
- self.hdfs_handler.read("hdfs:///Projects/test1.ipynb"),
- b"Content of notebook",
+ self.hdfs_handler.read('hdfs:///Projects/test1.ipynb'),
+ b'Content of notebook',
)
self.assertIs(client, self.hdfs_handler._get_client())
def test_hdfs_write(self, mock_hdfs_filesystem):
client = self.hdfs_handler._get_client()
- self.assertEqual(
- self.hdfs_handler.write("hdfs:///Projects/test1.ipynb", b"New content"), 1
- )
+ self.assertEqual(self.hdfs_handler.write('hdfs:///Projects/test1.ipynb', b'New content'), 1)
self.assertIs(client, self.hdfs_handler._get_client())
diff --git a/papermill/tests/test_inspect.py b/papermill/tests/test_inspect.py
index bab1df65..6d787e2d 100644
--- a/papermill/tests/test_inspect.py
+++ b/papermill/tests/test_inspect.py
@@ -3,11 +3,9 @@
import pytest
from click import Context
-
from papermill.inspection import display_notebook_help, inspect_notebook
-
-NOTEBOOKS_PATH = Path(__file__).parent / "notebooks"
+NOTEBOOKS_PATH = Path(__file__).parent / 'notebooks'
def _get_fullpath(name):
@@ -17,55 +15,55 @@ def _get_fullpath(name):
@pytest.fixture
def click_context():
mock = MagicMock(spec=Context, command=MagicMock())
- mock.command.get_usage.return_value = "Dummy usage"
+ mock.command.get_usage.return_value = 'Dummy usage'
return mock
@pytest.mark.parametrize(
- "name, expected",
+ 'name, expected',
[
- (_get_fullpath("no_parameters.ipynb"), {}),
+ (_get_fullpath('no_parameters.ipynb'), {}),
(
- _get_fullpath("simple_execute.ipynb"),
+ _get_fullpath('simple_execute.ipynb'),
{
- "msg": {
- "name": "msg",
- "inferred_type_name": "None",
- "default": "None",
- "help": "",
+ 'msg': {
+ 'name': 'msg',
+ 'inferred_type_name': 'None',
+ 'default': 'None',
+ 'help': '',
}
},
),
(
- _get_fullpath("complex_parameters.ipynb"),
+ _get_fullpath('complex_parameters.ipynb'),
{
- "msg": {
- "name": "msg",
- "inferred_type_name": "None",
- "default": "None",
- "help": "",
+ 'msg': {
+ 'name': 'msg',
+ 'inferred_type_name': 'None',
+ 'default': 'None',
+ 'help': '',
},
- "a": {
- "name": "a",
- "inferred_type_name": "float",
- "default": "2.25",
- "help": "Variable a",
+ 'a': {
+ 'name': 'a',
+ 'inferred_type_name': 'float',
+ 'default': '2.25',
+ 'help': 'Variable a',
},
- "b": {
- "name": "b",
- "inferred_type_name": "List[str]",
- "default": "['Hello','World']",
- "help": "Nice list",
+ 'b': {
+ 'name': 'b',
+ 'inferred_type_name': 'List[str]',
+ 'default': "['Hello','World']",
+ 'help': 'Nice list',
},
- "c": {
- "name": "c",
- "inferred_type_name": "NoneType",
- "default": "None",
- "help": "",
+ 'c': {
+ 'name': 'c',
+ 'inferred_type_name': 'NoneType',
+ 'default': 'None',
+ 'help': '',
},
},
),
- (_get_fullpath("notimplemented_translator.ipynb"), {}),
+ (_get_fullpath('notimplemented_translator.ipynb'), {}),
],
)
def test_inspect_notebook(name, expected):
@@ -74,50 +72,50 @@ def test_inspect_notebook(name, expected):
def test_str_path():
expected = {
- "msg": {
- "name": "msg",
- "inferred_type_name": "None",
- "default": "None",
- "help": "",
+ 'msg': {
+ 'name': 'msg',
+ 'inferred_type_name': 'None',
+ 'default': 'None',
+ 'help': '',
}
}
- assert inspect_notebook(str(_get_fullpath("simple_execute.ipynb"))) == expected
+ assert inspect_notebook(str(_get_fullpath('simple_execute.ipynb'))) == expected
@pytest.mark.parametrize(
- "name, expected",
+ 'name, expected',
[
(
- _get_fullpath("no_parameters.ipynb"),
+ _get_fullpath('no_parameters.ipynb'),
[
- "Dummy usage",
+ 'Dummy usage',
"\nParameters inferred for notebook '{name}':",
"\n No cell tagged 'parameters'",
],
),
(
- _get_fullpath("simple_execute.ipynb"),
+ _get_fullpath('simple_execute.ipynb'),
[
- "Dummy usage",
+ 'Dummy usage',
"\nParameters inferred for notebook '{name}':",
- " msg: Unknown type (default None)",
+ ' msg: Unknown type (default None)',
],
),
(
- _get_fullpath("complex_parameters.ipynb"),
+ _get_fullpath('complex_parameters.ipynb'),
[
- "Dummy usage",
+ 'Dummy usage',
"\nParameters inferred for notebook '{name}':",
- " msg: Unknown type (default None)",
- " a: float (default 2.25) Variable a",
+ ' msg: Unknown type (default None)',
+ ' a: float (default 2.25) Variable a',
" b: List[str] (default ['Hello','World'])\n Nice list",
- " c: NoneType (default None) ",
+ ' c: NoneType (default None) ',
],
),
(
- _get_fullpath("notimplemented_translator.ipynb"),
+ _get_fullpath('notimplemented_translator.ipynb'),
[
- "Dummy usage",
+ 'Dummy usage',
"\nParameters inferred for notebook '{name}':",
"\n Can't infer anything about this notebook's parameters. It may not have any parameter defined.", # noqa
],
@@ -125,7 +123,7 @@ def test_str_path():
],
)
def test_display_notebook_help(click_context, name, expected):
- with patch("papermill.inspection.click.echo") as echo:
+ with patch('papermill.inspection.click.echo') as echo:
display_notebook_help(click_context, str(name), None)
assert echo.call_count == len(expected)
diff --git a/papermill/tests/test_iorw.py b/papermill/tests/test_iorw.py
index 39ad12b0..cb1eab75 100644
--- a/papermill/tests/test_iorw.py
+++ b/papermill/tests/test_iorw.py
@@ -1,31 +1,31 @@
+import io
import json
-import unittest
import os
-import io
+import unittest
+from tempfile import TemporaryDirectory
+from unittest.mock import Mock, patch
+
import nbformat
import pytest
-
from requests.exceptions import ConnectionError
-from tempfile import TemporaryDirectory
-from unittest.mock import Mock, patch
from .. import iorw
+from ..exceptions import PapermillException
from ..iorw import (
+ ADLHandler,
HttpHandler,
LocalHandler,
NoIOHandler,
- ADLHandler,
NotebookNodeHandler,
- StreamHandler,
PapermillIO,
- read_yaml_file,
- papermill_io,
+ StreamHandler,
local_file_io_cwd,
+ papermill_io,
+ read_yaml_file,
)
-from ..exceptions import PapermillException
from . import get_notebook_path
-FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures")
+FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures')
class TestPapermillIO(unittest.TestCase):
@@ -38,16 +38,16 @@ def __init__(self, ver):
self.ver = ver
def read(self, path):
- return f"contents from {path} for version {self.ver}"
+ return f'contents from {path} for version {self.ver}'
def listdir(self, path):
- return ["fake", "contents"]
+ return ['fake', 'contents']
def write(self, buf, path):
- return f"wrote {buf}"
+ return f'wrote {buf}'
def pretty_path(self, path):
- return f"{path}/pretty/{self.ver}"
+ return f'{path}/pretty/{self.ver}'
class FakeByteHandler:
def __init__(self, ver):
@@ -59,13 +59,13 @@ def read(self, path):
return f.read()
def listdir(self, path):
- return ["fake", "contents"]
+ return ['fake', 'contents']
def write(self, buf, path):
- return f"wrote {buf}"
+ return f'wrote {buf}'
def pretty_path(self, path):
- return f"{path}/pretty/{self.ver}"
+ return f'{path}/pretty/{self.ver}'
def setUp(self):
self.papermill_io = PapermillIO()
@@ -73,8 +73,8 @@ def setUp(self):
self.fake1 = self.FakeHandler(1)
self.fake2 = self.FakeHandler(2)
self.fake_byte1 = self.FakeByteHandler(1)
- self.papermill_io.register("fake", self.fake1)
- self.papermill_io_bytes.register("notebooks", self.fake_byte1)
+ self.papermill_io.register('fake', self.fake1)
+ self.papermill_io_bytes.register('notebooks', self.fake_byte1)
self.old_papermill_io = iorw.papermill_io
iorw.papermill_io = self.papermill_io
@@ -83,117 +83,103 @@ def tearDown(self):
iorw.papermill_io = self.old_papermill_io
def test_get_handler(self):
- self.assertEqual(self.papermill_io.get_handler("fake"), self.fake1)
+ self.assertEqual(self.papermill_io.get_handler('fake'), self.fake1)
def test_get_local_handler(self):
with self.assertRaises(PapermillException):
- self.papermill_io.get_handler("dne")
+ self.papermill_io.get_handler('dne')
- self.papermill_io.register("local", self.fake2)
- self.assertEqual(self.papermill_io.get_handler("dne"), self.fake2)
+ self.papermill_io.register('local', self.fake2)
+ self.assertEqual(self.papermill_io.get_handler('dne'), self.fake2)
def test_get_no_io_handler(self):
self.assertIsInstance(self.papermill_io.get_handler(None), NoIOHandler)
def test_get_notebook_node_handler(self):
- test_nb = nbformat.read(
- get_notebook_path("test_notebooknode_io.ipynb"), as_version=4
- )
- self.assertIsInstance(
- self.papermill_io.get_handler(test_nb), NotebookNodeHandler
- )
+ test_nb = nbformat.read(get_notebook_path('test_notebooknode_io.ipynb'), as_version=4)
+ self.assertIsInstance(self.papermill_io.get_handler(test_nb), NotebookNodeHandler)
def test_entrypoint_register(self):
fake_entrypoint = Mock(load=Mock())
- fake_entrypoint.name = "fake-from-entry-point://"
+ fake_entrypoint.name = 'fake-from-entry-point://'
- with patch(
- "entrypoints.get_group_all", return_value=[fake_entrypoint]
- ) as mock_get_group_all:
+ with patch('entrypoints.get_group_all', return_value=[fake_entrypoint]) as mock_get_group_all:
self.papermill_io.register_entry_points()
- mock_get_group_all.assert_called_once_with("papermill.io")
- fake_ = self.papermill_io.get_handler("fake-from-entry-point://")
+ mock_get_group_all.assert_called_once_with('papermill.io')
+ fake_ = self.papermill_io.get_handler('fake-from-entry-point://')
assert fake_ == fake_entrypoint.load.return_value
def test_register_ordering(self):
# Should match fake1 with fake2 path
- self.assertEqual(self.papermill_io.get_handler("fake2/path"), self.fake1)
+ self.assertEqual(self.papermill_io.get_handler('fake2/path'), self.fake1)
self.papermill_io.reset()
- self.papermill_io.register("fake", self.fake1)
- self.papermill_io.register("fake2", self.fake2)
+ self.papermill_io.register('fake', self.fake1)
+ self.papermill_io.register('fake2', self.fake2)
# Should match fake1 with fake1 path, and NOT fake2 path/match
- self.assertEqual(self.papermill_io.get_handler("fake/path"), self.fake1)
+ self.assertEqual(self.papermill_io.get_handler('fake/path'), self.fake1)
# Should match fake2 with fake2 path
- self.assertEqual(self.papermill_io.get_handler("fake2/path"), self.fake2)
+ self.assertEqual(self.papermill_io.get_handler('fake2/path'), self.fake2)
def test_read(self):
- self.assertEqual(
- self.papermill_io.read("fake/path"), "contents from fake/path for version 1"
- )
+ self.assertEqual(self.papermill_io.read('fake/path'), 'contents from fake/path for version 1')
def test_read_bytes(self):
- self.assertIsNotNone(
- self.papermill_io_bytes.read(
- "notebooks/gcs/gcs_in/gcs-simple_notebook.ipynb"
- )
- )
+ self.assertIsNotNone(self.papermill_io_bytes.read('notebooks/gcs/gcs_in/gcs-simple_notebook.ipynb'))
def test_read_with_no_file_extension(self):
with pytest.warns(UserWarning):
- self.papermill_io.read("fake/path")
+ self.papermill_io.read('fake/path')
def test_read_with_invalid_file_extension(self):
with pytest.warns(UserWarning):
- self.papermill_io.read("fake/path/fakeinputpath.ipynb1")
+ self.papermill_io.read('fake/path/fakeinputpath.ipynb1')
def test_read_with_valid_file_extension(self):
with pytest.warns(None) as warns:
- self.papermill_io.read("fake/path/fakeinputpath.ipynb")
+ self.papermill_io.read('fake/path/fakeinputpath.ipynb')
self.assertEqual(len(warns), 0)
def test_read_yaml_with_no_file_extension(self):
with pytest.warns(UserWarning):
- read_yaml_file("fake/path")
+ read_yaml_file('fake/path')
def test_read_yaml_with_invalid_file_extension(self):
with pytest.warns(UserWarning):
- read_yaml_file("fake/path/fakeinputpath.ipynb")
+ read_yaml_file('fake/path/fakeinputpath.ipynb')
def test_read_stdin(self):
- file_content = "Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ"
- with patch("sys.stdin", io.StringIO(file_content)):
- self.assertEqual(self.old_papermill_io.read("-"), file_content)
+ file_content = 'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ'
+ with patch('sys.stdin', io.StringIO(file_content)):
+ self.assertEqual(self.old_papermill_io.read('-'), file_content)
def test_listdir(self):
- self.assertEqual(self.papermill_io.listdir("fake/path"), ["fake", "contents"])
+ self.assertEqual(self.papermill_io.listdir('fake/path'), ['fake', 'contents'])
def test_write(self):
- self.assertEqual(self.papermill_io.write("buffer", "fake/path"), "wrote buffer")
+ self.assertEqual(self.papermill_io.write('buffer', 'fake/path'), 'wrote buffer')
def test_write_with_no_file_extension(self):
with pytest.warns(UserWarning):
- self.papermill_io.write("buffer", "fake/path")
+ self.papermill_io.write('buffer', 'fake/path')
def test_write_with_path_of_none(self):
- self.assertIsNone(self.papermill_io.write("buffer", None))
+ self.assertIsNone(self.papermill_io.write('buffer', None))
def test_write_with_invalid_file_extension(self):
with pytest.warns(UserWarning):
- self.papermill_io.write("buffer", "fake/path/fakeoutputpath.ipynb1")
+ self.papermill_io.write('buffer', 'fake/path/fakeoutputpath.ipynb1')
def test_write_stdout(self):
- file_content = "Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ"
+ file_content = 'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ'
out = io.BytesIO()
- with patch("sys.stdout", out):
- self.old_papermill_io.write(file_content, "-")
- self.assertEqual(out.getvalue(), file_content.encode("utf-8"))
+ with patch('sys.stdout', out):
+ self.old_papermill_io.write(file_content, '-')
+ self.assertEqual(out.getvalue(), file_content.encode('utf-8'))
def test_pretty_path(self):
- self.assertEqual(
- self.papermill_io.pretty_path("fake/path"), "fake/path/pretty/1"
- )
+ self.assertEqual(self.papermill_io.pretty_path('fake/path'), 'fake/path/pretty/1')
class TestLocalHandler(unittest.TestCase):
@@ -202,36 +188,34 @@ class TestLocalHandler(unittest.TestCase):
"""
def test_read_utf8(self):
- self.assertEqual(
- LocalHandler().read(os.path.join(FIXTURE_PATH, "rock.txt")).strip(), "✄"
- )
+ self.assertEqual(LocalHandler().read(os.path.join(FIXTURE_PATH, 'rock.txt')).strip(), '✄')
def test_write_utf8(self):
with TemporaryDirectory() as temp_dir:
- path = os.path.join(temp_dir, "paper.txt")
- LocalHandler().write("✄", path)
- with open(path, encoding="utf-8") as f:
- self.assertEqual(f.read().strip(), "✄")
+ path = os.path.join(temp_dir, 'paper.txt')
+ LocalHandler().write('✄', path)
+ with open(path, encoding='utf-8') as f:
+ self.assertEqual(f.read().strip(), '✄')
def test_write_no_directory_exists(self):
with self.assertRaises(FileNotFoundError):
- LocalHandler().write("buffer", "fake/path/fakenb.ipynb")
+ LocalHandler().write('buffer', 'fake/path/fakenb.ipynb')
def test_write_local_directory(self):
- with patch.object(io, "open"):
+ with patch.object(io, 'open'):
# Shouldn't raise with missing directory
- LocalHandler().write("buffer", "local.ipynb")
+ LocalHandler().write('buffer', 'local.ipynb')
def test_write_passed_cwd(self):
with TemporaryDirectory() as temp_dir:
handler = LocalHandler()
handler.cwd(temp_dir)
- handler.write("✄", "paper.txt")
+ handler.write('✄', 'paper.txt')
- path = os.path.join(temp_dir, "paper.txt")
- with open(path, encoding="utf-8") as f:
- self.assertEqual(f.read().strip(), "✄")
+ path = os.path.join(temp_dir, 'paper.txt')
+ with open(path, encoding='utf-8') as f:
+ self.assertEqual(f.read().strip(), '✄')
def test_local_file_io_cwd(self):
with TemporaryDirectory() as temp_dir:
@@ -241,16 +225,16 @@ def test_local_file_io_cwd(self):
try:
local_handler = LocalHandler()
papermill_io.reset()
- papermill_io.register("local", local_handler)
+ papermill_io.register('local', local_handler)
with local_file_io_cwd(temp_dir):
- local_handler.write("✄", "paper.txt")
- self.assertEqual(local_handler.read("paper.txt"), "✄")
+ local_handler.write('✄', 'paper.txt')
+ self.assertEqual(local_handler.read('paper.txt'), '✄')
# Double check it used the tmpdir
- path = os.path.join(temp_dir, "paper.txt")
- with open(path, encoding="utf-8") as f:
- self.assertEqual(f.read().strip(), "✄")
+ path = os.path.join(temp_dir, 'paper.txt')
+ with open(path, encoding='utf-8') as f:
+ self.assertEqual(f.read().strip(), '✄')
finally:
papermill_io.handlers = handlers
@@ -263,7 +247,7 @@ def test_invalid_string(self):
# a string from which we can't extract a notebook is assumed to
# be a file and an IOError will be raised
with self.assertRaises(IOError):
- LocalHandler().read("a random string")
+ LocalHandler().read('a random string')
class TestNoIOHandler(unittest.TestCase):
@@ -276,10 +260,10 @@ def test_raises_on_listdir(self):
NoIOHandler().listdir(None)
def test_write_returns_none(self):
- self.assertIsNone(NoIOHandler().write("buf", None))
+ self.assertIsNone(NoIOHandler().write('buf', None))
def test_pretty_path(self):
- expect = "Notebook will not be saved"
+ expect = 'Notebook will not be saved'
self.assertEqual(NoIOHandler().pretty_path(None), expect)
@@ -291,20 +275,20 @@ class TestADLHandler(unittest.TestCase):
def setUp(self):
self.handler = ADLHandler()
self.handler._client = Mock(
- read=Mock(return_value=["foo", "bar", "baz"]),
- listdir=Mock(return_value=["foo", "bar", "baz"]),
+ read=Mock(return_value=['foo', 'bar', 'baz']),
+ listdir=Mock(return_value=['foo', 'bar', 'baz']),
write=Mock(),
)
def test_read(self):
- self.assertEqual(self.handler.read("some_path"), "foo\nbar\nbaz")
+ self.assertEqual(self.handler.read('some_path'), 'foo\nbar\nbaz')
def test_listdir(self):
- self.assertEqual(self.handler.listdir("some_path"), ["foo", "bar", "baz"])
+ self.assertEqual(self.handler.listdir('some_path'), ['foo', 'bar', 'baz'])
def test_write(self):
- self.handler.write("foo", "bar")
- self.handler._client.write.assert_called_once_with("foo", "bar")
+ self.handler.write('foo', 'bar')
+ self.handler._client.write.assert_called_once_with('foo', 'bar')
class TestHttpHandler(unittest.TestCase):
@@ -318,34 +302,32 @@ def test_listdir(self):
`listdir` function is not supported.
"""
with self.assertRaises(PapermillException) as e:
- HttpHandler.listdir("http://example.com")
+ HttpHandler.listdir('http://example.com')
- self.assertEqual(f"{e.exception}", "listdir is not supported by HttpHandler")
+ self.assertEqual(f'{e.exception}', 'listdir is not supported by HttpHandler')
def test_read(self):
"""
Tests that the `read` function performs a request to the giving path
and returns the response.
"""
- path = "http://example.com"
- text = "request test response"
+ path = 'http://example.com'
+ text = 'request test response'
- with patch("papermill.iorw.requests.get") as mock_get:
+ with patch('papermill.iorw.requests.get') as mock_get:
mock_get.return_value = Mock(text=text)
self.assertEqual(HttpHandler.read(path), text)
- mock_get.assert_called_once_with(
- path, headers={"Accept": "application/json"}
- )
+ mock_get.assert_called_once_with(path, headers={'Accept': 'application/json'})
def test_write(self):
"""
Tests that the `write` function performs a put request to the given
path.
"""
- path = "http://example.com"
+ path = 'http://example.com'
buf = '{"papermill": true}'
- with patch("papermill.iorw.requests.put") as mock_put:
+ with patch('papermill.iorw.requests.put') as mock_put:
HttpHandler.write(buf, path)
mock_put.assert_called_once_with(path, json=json.loads(buf))
@@ -353,7 +335,7 @@ def test_write_failure(self):
"""
Tests that the `write` function raises on failure to put the buffer.
"""
- path = "http://localhost:9999"
+ path = 'http://localhost:9999'
buf = '{"papermill": true}'
with self.assertRaises(ConnectionError):
@@ -361,36 +343,34 @@ def test_write_failure(self):
class TestStreamHandler(unittest.TestCase):
- @patch("sys.stdin", io.StringIO("mock stream"))
+ @patch('sys.stdin', io.StringIO('mock stream'))
def test_read_from_stdin(self):
- result = StreamHandler().read("foo")
- self.assertEqual(result, "mock stream")
+ result = StreamHandler().read('foo')
+ self.assertEqual(result, 'mock stream')
def test_raises_on_listdir(self):
with self.assertRaises(PapermillException):
StreamHandler().listdir(None)
- @patch("sys.stdout")
+ @patch('sys.stdout')
def test_write_to_stdout_buffer(self, mock_stdout):
mock_stdout.buffer = io.BytesIO()
- StreamHandler().write("mock stream", "foo")
- self.assertEqual(mock_stdout.buffer.getbuffer(), b"mock stream")
+ StreamHandler().write('mock stream', 'foo')
+ self.assertEqual(mock_stdout.buffer.getbuffer(), b'mock stream')
- @patch("sys.stdout", new_callable=io.BytesIO)
+ @patch('sys.stdout', new_callable=io.BytesIO)
def test_write_to_stdout(self, mock_stdout):
- StreamHandler().write("mock stream", "foo")
- self.assertEqual(mock_stdout.getbuffer(), b"mock stream")
+ StreamHandler().write('mock stream', 'foo')
+ self.assertEqual(mock_stdout.getbuffer(), b'mock stream')
def test_pretty_path_returns_input_path(self):
'''Should return the input str, which often is the default registered schema "-"'''
- self.assertEqual(StreamHandler().pretty_path("foo"), "foo")
+ self.assertEqual(StreamHandler().pretty_path('foo'), 'foo')
class TestNotebookNodeHandler(unittest.TestCase):
def test_read_notebook_node(self):
- input_nb = nbformat.read(
- get_notebook_path("test_notebooknode_io.ipynb"), as_version=4
- )
+ input_nb = nbformat.read(get_notebook_path('test_notebooknode_io.ipynb'), as_version=4)
result = NotebookNodeHandler().read(input_nb)
expect = (
'{\n "cells": [\n {\n "cell_type": "code",\n "execution_count": null,'
@@ -403,12 +383,12 @@ def test_read_notebook_node(self):
def test_raises_on_listdir(self):
with self.assertRaises(PapermillException):
- NotebookNodeHandler().listdir("foo")
+ NotebookNodeHandler().listdir('foo')
def test_raises_on_write(self):
with self.assertRaises(PapermillException):
- NotebookNodeHandler().write("foo", "bar")
+ NotebookNodeHandler().write('foo', 'bar')
def test_pretty_path(self):
- expect = "NotebookNode object"
- self.assertEqual(NotebookNodeHandler().pretty_path("foo"), expect)
+ expect = 'NotebookNode object'
+ self.assertEqual(NotebookNodeHandler().pretty_path('foo'), expect)
diff --git a/papermill/tests/test_parameterize.py b/papermill/tests/test_parameterize.py
index 4e2df4f4..fbd12ff0 100644
--- a/papermill/tests/test_parameterize.py
+++ b/papermill/tests/test_parameterize.py
@@ -1,205 +1,173 @@
import unittest
+from datetime import datetime
-from ..iorw import load_notebook_node
from ..exceptions import PapermillMissingParameterException
+from ..iorw import load_notebook_node
from ..parameterize import (
+ add_builtin_parameters,
parameterize_notebook,
parameterize_path,
- add_builtin_parameters,
)
from . import get_notebook_path
-from datetime import datetime
class TestNotebookParametrizing(unittest.TestCase):
def count_nb_injected_parameter_cells(self, nb):
- return len(
- [
- c
- for c in nb.cells
- if "injected-parameters" in c.get("metadata", {}).get("tags", [])
- ]
- )
+ return len([c for c in nb.cells if 'injected-parameters' in c.get('metadata', {}).get('tags', [])])
def test_no_tag_copying(self):
# Test that injected cell does not copy other tags
- test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb"))
- test_nb.cells[0]["metadata"]["tags"].append("some tag")
+ test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb'))
+ test_nb.cells[0]['metadata']['tags'].append('some tag')
- test_nb = parameterize_notebook(test_nb, {"msg": "Hello"})
+ test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'})
cell_zero = test_nb.cells[0]
- self.assertTrue("some tag" in cell_zero.get("metadata").get("tags"))
- self.assertTrue("parameters" in cell_zero.get("metadata").get("tags"))
+ self.assertTrue('some tag' in cell_zero.get('metadata').get('tags'))
+ self.assertTrue('parameters' in cell_zero.get('metadata').get('tags'))
cell_one = test_nb.cells[1]
- self.assertTrue("some tag" not in cell_one.get("metadata").get("tags"))
- self.assertTrue("injected-parameters" in cell_one.get("metadata").get("tags"))
+ self.assertTrue('some tag' not in cell_one.get('metadata').get('tags'))
+ self.assertTrue('injected-parameters' in cell_one.get('metadata').get('tags'))
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1)
def test_injected_parameters_tag(self):
- test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb"))
+ test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb'))
- test_nb = parameterize_notebook(test_nb, {"msg": "Hello"})
+ test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'})
cell_zero = test_nb.cells[0]
- self.assertTrue("parameters" in cell_zero.get("metadata").get("tags"))
- self.assertTrue(
- "injected-parameters" not in cell_zero.get("metadata").get("tags")
- )
+ self.assertTrue('parameters' in cell_zero.get('metadata').get('tags'))
+ self.assertTrue('injected-parameters' not in cell_zero.get('metadata').get('tags'))
cell_one = test_nb.cells[1]
- self.assertTrue("injected-parameters" in cell_one.get("metadata").get("tags"))
+ self.assertTrue('injected-parameters' in cell_one.get('metadata').get('tags'))
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1)
def test_repeated_run_injected_parameters_tag(self):
- test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb"))
+ test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb'))
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 0)
- test_nb = parameterize_notebook(test_nb, {"msg": "Hello"})
+ test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'})
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1)
- parameterize_notebook(test_nb, {"msg": "Hello"})
+ parameterize_notebook(test_nb, {'msg': 'Hello'})
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1)
def test_no_parameter_tag(self):
- test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb"))
- test_nb.cells[0]["metadata"]["tags"] = []
+ test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb'))
+ test_nb.cells[0]['metadata']['tags'] = []
- test_nb = parameterize_notebook(test_nb, {"msg": "Hello"})
+ test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'})
cell_zero = test_nb.cells[0]
- self.assertTrue("injected-parameters" in cell_zero.get("metadata").get("tags"))
- self.assertTrue("parameters" not in cell_zero.get("metadata").get("tags"))
+ self.assertTrue('injected-parameters' in cell_zero.get('metadata').get('tags'))
+ self.assertTrue('parameters' not in cell_zero.get('metadata').get('tags'))
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1)
def test_repeated_run_no_parameters_tag(self):
- test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb"))
- test_nb.cells[0]["metadata"]["tags"] = []
+ test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb'))
+ test_nb.cells[0]['metadata']['tags'] = []
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 0)
- test_nb = parameterize_notebook(test_nb, {"msg": "Hello"})
+ test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'})
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1)
- test_nb = parameterize_notebook(test_nb, {"msg": "Hello"})
+ test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'})
self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1)
def test_custom_comment(self):
- test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb"))
- test_nb = parameterize_notebook(
- test_nb, {"msg": "Hello"}, comment="This is a custom comment"
- )
+ test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb'))
+ test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}, comment='This is a custom comment')
cell_one = test_nb.cells[1]
- first_line = cell_one["source"].split("\n")[0]
- self.assertEqual(first_line, "# This is a custom comment")
+ first_line = cell_one['source'].split('\n')[0]
+ self.assertEqual(first_line, '# This is a custom comment')
class TestBuiltinParameters(unittest.TestCase):
def test_add_builtin_parameters_keeps_provided_parameters(self):
- with_builtin_parameters = add_builtin_parameters({"foo": "bar"})
- self.assertEqual(with_builtin_parameters["foo"], "bar")
+ with_builtin_parameters = add_builtin_parameters({'foo': 'bar'})
+ self.assertEqual(with_builtin_parameters['foo'], 'bar')
def test_add_builtin_parameters_adds_dict_of_builtins(self):
- with_builtin_parameters = add_builtin_parameters({"foo": "bar"})
- self.assertIn("pm", with_builtin_parameters)
- self.assertIsInstance(with_builtin_parameters["pm"], type({}))
+ with_builtin_parameters = add_builtin_parameters({'foo': 'bar'})
+ self.assertIn('pm', with_builtin_parameters)
+ self.assertIsInstance(with_builtin_parameters['pm'], type({}))
def test_add_builtin_parameters_allows_to_override_builtin(self):
- with_builtin_parameters = add_builtin_parameters({"pm": "foo"})
- self.assertEqual(with_builtin_parameters["pm"], "foo")
+ with_builtin_parameters = add_builtin_parameters({'pm': 'foo'})
+ self.assertEqual(with_builtin_parameters['pm'], 'foo')
def test_builtin_parameters_include_run_uuid(self):
- with_builtin_parameters = add_builtin_parameters({"foo": "bar"})
- self.assertIn("run_uuid", with_builtin_parameters["pm"])
+ with_builtin_parameters = add_builtin_parameters({'foo': 'bar'})
+ self.assertIn('run_uuid', with_builtin_parameters['pm'])
def test_builtin_parameters_include_current_datetime_local(self):
- with_builtin_parameters = add_builtin_parameters({"foo": "bar"})
- self.assertIn("current_datetime_local", with_builtin_parameters["pm"])
- self.assertIsInstance(
- with_builtin_parameters["pm"]["current_datetime_local"], datetime
- )
+ with_builtin_parameters = add_builtin_parameters({'foo': 'bar'})
+ self.assertIn('current_datetime_local', with_builtin_parameters['pm'])
+ self.assertIsInstance(with_builtin_parameters['pm']['current_datetime_local'], datetime)
def test_builtin_parameters_include_current_datetime_utc(self):
- with_builtin_parameters = add_builtin_parameters({"foo": "bar"})
- self.assertIn("current_datetime_utc", with_builtin_parameters["pm"])
- self.assertIsInstance(
- with_builtin_parameters["pm"]["current_datetime_utc"], datetime
- )
+ with_builtin_parameters = add_builtin_parameters({'foo': 'bar'})
+ self.assertIn('current_datetime_utc', with_builtin_parameters['pm'])
+ self.assertIsInstance(with_builtin_parameters['pm']['current_datetime_utc'], datetime)
class TestPathParameterizing(unittest.TestCase):
def test_plain_text_path_with_empty_parameters_object(self):
- self.assertEqual(parameterize_path("foo/bar", {}), "foo/bar")
+ self.assertEqual(parameterize_path('foo/bar', {}), 'foo/bar')
def test_plain_text_path_with_none_parameters(self):
- self.assertEqual(parameterize_path("foo/bar", None), "foo/bar")
+ self.assertEqual(parameterize_path('foo/bar', None), 'foo/bar')
def test_plain_text_path_with_unused_parameters(self):
- self.assertEqual(parameterize_path("foo/bar", {"baz": "quux"}), "foo/bar")
+ self.assertEqual(parameterize_path('foo/bar', {'baz': 'quux'}), 'foo/bar')
def test_path_with_single_parameter(self):
- self.assertEqual(
- parameterize_path("foo/bar/{baz}", {"baz": "quux"}), "foo/bar/quux"
- )
+ self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': 'quux'}), 'foo/bar/quux')
def test_path_with_boolean_parameter(self):
- self.assertEqual(
- parameterize_path("foo/bar/{baz}", {"baz": False}), "foo/bar/False"
- )
+ self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': False}), 'foo/bar/False')
def test_path_with_dict_parameter(self):
- self.assertEqual(
- parameterize_path("foo/{bar[baz]}/", {"bar": {"baz": "quux"}}), "foo/quux/"
- )
+ self.assertEqual(parameterize_path('foo/{bar[baz]}/', {'bar': {'baz': 'quux'}}), 'foo/quux/')
def test_path_with_list_parameter(self):
- self.assertEqual(
- parameterize_path("foo/{bar[0]}/", {"bar": [1, 2, 3]}), "foo/1/"
- )
- self.assertEqual(
- parameterize_path("foo/{bar[2]}/", {"bar": [1, 2, 3]}), "foo/3/"
- )
+ self.assertEqual(parameterize_path('foo/{bar[0]}/', {'bar': [1, 2, 3]}), 'foo/1/')
+ self.assertEqual(parameterize_path('foo/{bar[2]}/', {'bar': [1, 2, 3]}), 'foo/3/')
def test_path_with_none_parameter(self):
- self.assertEqual(
- parameterize_path("foo/bar/{baz}", {"baz": None}), "foo/bar/None"
- )
+ self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': None}), 'foo/bar/None')
def test_path_with_numeric_parameter(self):
- self.assertEqual(parameterize_path("foo/bar/{baz}", {"baz": 42}), "foo/bar/42")
+ self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': 42}), 'foo/bar/42')
def test_path_with_numeric_format_string(self):
- self.assertEqual(
- parameterize_path("foo/bar/{baz:03d}", {"baz": 42}), "foo/bar/042"
- )
+ self.assertEqual(parameterize_path('foo/bar/{baz:03d}', {'baz': 42}), 'foo/bar/042')
def test_path_with_float_format_string(self):
- self.assertEqual(
- parameterize_path("foo/bar/{baz:.03f}", {"baz": 0.3}), "foo/bar/0.300"
- )
+ self.assertEqual(parameterize_path('foo/bar/{baz:.03f}', {'baz': 0.3}), 'foo/bar/0.300')
def test_path_with_multiple_parameter(self):
- self.assertEqual(
- parameterize_path("{foo}/{baz}", {"foo": "bar", "baz": "quux"}), "bar/quux"
- )
+ self.assertEqual(parameterize_path('{foo}/{baz}', {'foo': 'bar', 'baz': 'quux'}), 'bar/quux')
def test_parameterized_path_with_undefined_parameter(self):
with self.assertRaises(PapermillMissingParameterException) as context:
- parameterize_path("{foo}", {})
+ parameterize_path('{foo}', {})
self.assertEqual(str(context.exception), "Missing parameter 'foo'")
def test_parameterized_path_with_none_parameters(self):
with self.assertRaises(PapermillMissingParameterException) as context:
- parameterize_path("{foo}", None)
+ parameterize_path('{foo}', None)
self.assertEqual(str(context.exception), "Missing parameter 'foo'")
def test_path_of_none_returns_none(self):
- self.assertIsNone(parameterize_path(path=None, parameters={"foo": "bar"}))
+ self.assertIsNone(parameterize_path(path=None, parameters={'foo': 'bar'}))
self.assertIsNone(parameterize_path(path=None, parameters=None))
def test_path_of_notebook_node_returns_input(self):
- test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb"))
+ test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb'))
result_nb = parameterize_path(test_nb, parameters=None)
self.assertIs(result_nb, test_nb)
diff --git a/papermill/tests/test_s3.py b/papermill/tests/test_s3.py
index 156b4a7a..de86f5b6 100644
--- a/papermill/tests/test_s3.py
+++ b/papermill/tests/test_s3.py
@@ -1,52 +1,52 @@
# The following tests are purposely limited to the exposed interface by iorw.py
import os.path
-import pytest
+
import boto3
import moto
-
+import pytest
from moto import mock_s3
-from ..s3 import Bucket, Prefix, Key, S3
+from ..s3 import S3, Bucket, Key, Prefix
@pytest.fixture
def bucket_no_service():
"""Returns a bucket instance with no services"""
- return Bucket("my_test_bucket")
+ return Bucket('my_test_bucket')
@pytest.fixture
def bucket_with_service():
"""Returns a bucket instance with a service"""
- return Bucket("my_sqs_bucket", ["sqs"])
+ return Bucket('my_sqs_bucket', ['sqs'])
@pytest.fixture
def bucket_sqs():
"""Returns a bucket instance with a sqs service"""
- return Bucket("my_sqs_bucket", ["sqs"])
+ return Bucket('my_sqs_bucket', ['sqs'])
@pytest.fixture
def bucket_ec2():
"""Returns a bucket instance with a ec2 service"""
- return Bucket("my_sqs_bucket", ["ec2"])
+ return Bucket('my_sqs_bucket', ['ec2'])
@pytest.fixture
def bucket_multiservice():
"""Returns a bucket instance with a ec2 service"""
- return Bucket("my_sqs_bucket", ["ec2", "sqs"])
+ return Bucket('my_sqs_bucket', ['ec2', 'sqs'])
def test_bucket_init():
- assert Bucket("my_test_bucket")
- assert Bucket("my_sqs_bucket", "sqs")
+ assert Bucket('my_test_bucket')
+ assert Bucket('my_sqs_bucket', 'sqs')
def test_bucket_defaults():
- name = "a bucket"
+ name = 'a bucket'
b1 = Bucket(name)
b2 = Bucket(name, None)
@@ -86,19 +86,19 @@ def test_prefix_init():
Prefix(service=None)
with pytest.raises(TypeError):
- Prefix("my_test_prefix")
+ Prefix('my_test_prefix')
- b1 = Bucket("my_test_bucket")
- p1 = Prefix(b1, "sqs_test", service="sqs")
- assert Prefix(b1, "test_bucket")
- assert Prefix(b1, "test_bucket", service=None)
- assert Prefix(b1, "test_bucket", None)
+ b1 = Bucket('my_test_bucket')
+ p1 = Prefix(b1, 'sqs_test', service='sqs')
+ assert Prefix(b1, 'test_bucket')
+ assert Prefix(b1, 'test_bucket', service=None)
+ assert Prefix(b1, 'test_bucket', None)
assert p1.bucket.service == p1.service
def test_prefix_defaults():
- bucket = Bucket("my data pool")
- name = "bigdata bucket"
+ bucket = Bucket('my data pool')
+ name = 'bigdata bucket'
p1 = Prefix(bucket, name)
p2 = Prefix(bucket, name, None)
@@ -107,13 +107,13 @@ def test_prefix_defaults():
def test_prefix_str(bucket_sqs):
- p1 = Prefix(bucket_sqs, "sqs_prefix_test", "sqs")
- assert str(p1) == "s3://" + str(bucket_sqs) + "/sqs_prefix_test"
+ p1 = Prefix(bucket_sqs, 'sqs_prefix_test', 'sqs')
+ assert str(p1) == 's3://' + str(bucket_sqs) + '/sqs_prefix_test'
def test_prefix_repr(bucket_sqs):
- p1 = Prefix(bucket_sqs, "sqs_prefix_test", "sqs")
- assert repr(p1) == "s3://" + str(bucket_sqs) + "/sqs_prefix_test"
+ p1 = Prefix(bucket_sqs, 'sqs_prefix_test', 'sqs')
+ assert repr(p1) == 's3://' + str(bucket_sqs) + '/sqs_prefix_test'
def test_key_init():
@@ -121,13 +121,13 @@ def test_key_init():
def test_key_repr():
- k = Key("foo", "bar")
- assert repr(k) == "s3://foo/bar"
+ k = Key('foo', 'bar')
+ assert repr(k) == 's3://foo/bar'
def test_key_defaults():
- bucket = Bucket("my data pool")
- name = "bigdata bucket"
+ bucket = Bucket('my data pool')
+ name = 'bigdata bucket'
k1 = Key(bucket, name)
k2 = Key(bucket, name, None, None, None, None, None)
@@ -148,36 +148,36 @@ def test_s3_defaults():
local_dir = os.path.dirname(os.path.abspath(__file__))
-test_bucket_name = "test-pm-bucket"
-test_string = "Hello"
-test_file_path = "notebooks/s3/s3_in/s3-simple_notebook.ipynb"
-test_empty_file_path = "notebooks/s3/s3_in/s3-empty.ipynb"
+test_bucket_name = 'test-pm-bucket'
+test_string = 'Hello'
+test_file_path = 'notebooks/s3/s3_in/s3-simple_notebook.ipynb'
+test_empty_file_path = 'notebooks/s3/s3_in/s3-empty.ipynb'
with open(os.path.join(local_dir, test_file_path)) as f:
test_nb_content = f.read()
-no_empty_lines = lambda s: "\n".join([l for l in s.split("\n") if len(l) > 0])
+no_empty_lines = lambda s: '\n'.join([l for l in s.split('\n') if len(l) > 0])
test_clean_nb_content = no_empty_lines(test_nb_content)
-read_from_gen = lambda g: "\n".join(g)
+read_from_gen = lambda g: '\n'.join(g)
-@pytest.fixture(scope="function")
+@pytest.fixture(scope='function')
def s3_client():
mock_s3 = moto.mock_s3()
mock_s3.start()
- client = boto3.client("s3")
+ client = boto3.client('s3')
client.create_bucket(
Bucket=test_bucket_name,
- CreateBucketConfiguration={"LocationConstraint": "us-west-2"},
+ CreateBucketConfiguration={'LocationConstraint': 'us-west-2'},
)
client.put_object(Bucket=test_bucket_name, Key=test_file_path, Body=test_nb_content)
- client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body="")
+ client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body='')
yield S3()
try:
client.delete_object(Bucket=test_bucket_name, Key=test_file_path)
- client.delete_object(Bucket=test_bucket_name, Key=test_file_path + ".txt")
+ client.delete_object(Bucket=test_bucket_name, Key=test_file_path + '.txt')
client.delete_object(Bucket=test_bucket_name, Key=test_empty_file_path)
except Exception:
pass
@@ -185,19 +185,19 @@ def s3_client():
def test_s3_read(s3_client):
- s3_path = f"s3://{test_bucket_name}/{test_file_path}"
+ s3_path = f's3://{test_bucket_name}/{test_file_path}'
data = read_from_gen(s3_client.read(s3_path))
assert data == test_clean_nb_content
def test_s3_read_empty(s3_client):
- s3_path = f"s3://{test_bucket_name}/{test_empty_file_path}"
+ s3_path = f's3://{test_bucket_name}/{test_empty_file_path}'
data = read_from_gen(s3_client.read(s3_path))
- assert data == ""
+ assert data == ''
def test_s3_write(s3_client):
- s3_path = f"s3://{test_bucket_name}/{test_file_path}.txt"
+ s3_path = f's3://{test_bucket_name}/{test_file_path}.txt'
s3_client.cp_string(test_string, s3_path)
data = read_from_gen(s3_client.read(s3_path))
@@ -205,7 +205,7 @@ def test_s3_write(s3_client):
def test_s3_overwrite(s3_client):
- s3_path = f"s3://{test_bucket_name}/{test_file_path}"
+ s3_path = f's3://{test_bucket_name}/{test_file_path}'
s3_client.cp_string(test_string, s3_path)
data = read_from_gen(s3_client.read(s3_path))
@@ -214,8 +214,8 @@ def test_s3_overwrite(s3_client):
def test_s3_listdir(s3_client):
dir_name = os.path.dirname(test_file_path)
- s3_dir = f"s3://{test_bucket_name}/{dir_name}"
- s3_path = f"s3://{test_bucket_name}/{test_file_path}"
+ s3_dir = f's3://{test_bucket_name}/{dir_name}'
+ s3_path = f's3://{test_bucket_name}/{test_file_path}'
dir_listings = s3_client.listdir(s3_dir)
assert len(dir_listings) == 2
assert s3_path in dir_listings
diff --git a/papermill/tests/test_translators.py b/papermill/tests/test_translators.py
index 906784f6..ab49475d 100644
--- a/papermill/tests/test_translators.py
+++ b/papermill/tests/test_translators.py
@@ -1,8 +1,7 @@
-import pytest
-
-from unittest.mock import Mock
from collections import OrderedDict
+from unittest.mock import Mock
+import pytest
from nbformat.v4 import new_code_cell
from .. import translators
@@ -11,29 +10,29 @@
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", '"foo"'),
+ ('foo', '"foo"'),
('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'),
- ({"foo": "bar"}, '{"foo": "bar"}'),
- ({"foo": '"bar"'}, '{"foo": "\\"bar\\""}'),
- ({"foo": ["bar"]}, '{"foo": ["bar"]}'),
- ({"foo": {"bar": "baz"}}, '{"foo": {"bar": "baz"}}'),
- ({"foo": {"bar": '"baz"'}}, '{"foo": {"bar": "\\"baz\\""}}'),
- (["foo"], '["foo"]'),
- (["foo", '"bar"'], '["foo", "\\"bar\\""]'),
- ([{"foo": "bar"}], '[{"foo": "bar"}]'),
- ([{"foo": '"bar"'}], '[{"foo": "\\"bar\\""}]'),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (float("nan"), "float('nan')"),
- (float("-inf"), "float('-inf')"),
- (float("inf"), "float('inf')"),
- (True, "True"),
- (False, "False"),
- (None, "None"),
+ ({'foo': 'bar'}, '{"foo": "bar"}'),
+ ({'foo': '"bar"'}, '{"foo": "\\"bar\\""}'),
+ ({'foo': ['bar']}, '{"foo": ["bar"]}'),
+ ({'foo': {'bar': 'baz'}}, '{"foo": {"bar": "baz"}}'),
+ ({'foo': {'bar': '"baz"'}}, '{"foo": {"bar": "\\"baz\\""}}'),
+ (['foo'], '["foo"]'),
+ (['foo', '"bar"'], '["foo", "\\"bar\\""]'),
+ ([{'foo': 'bar'}], '[{"foo": "bar"}]'),
+ ([{'foo': '"bar"'}], '[{"foo": "\\"bar\\""}]'),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (float('nan'), "float('nan')"),
+ (float('-inf'), "float('-inf')"),
+ (float('inf'), "float('inf')"),
+ (True, 'True'),
+ (False, 'False'),
+ (None, 'None'),
],
)
def test_translate_type_python(test_input, expected):
@@ -41,16 +40,16 @@ def test_translate_type_python(test_input, expected):
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'),
- ({"foo": True}, "# Parameters\nfoo = True\n"),
- ({"foo": 5}, "# Parameters\nfoo = 5\n"),
- ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"),
- ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = ["bar", "baz"]\n'),
- ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = {"bar": "baz"}\n'),
+ ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'),
+ ({'foo': True}, '# Parameters\nfoo = True\n'),
+ ({'foo': 5}, '# Parameters\nfoo = 5\n'),
+ ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'),
+ ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = ["bar", "baz"]\n'),
+ ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = {"bar": "baz"}\n'),
(
- OrderedDict([["foo", "bar"], ["baz", ["buz"]]]),
+ OrderedDict([['foo', 'bar'], ['baz', ['buz']]]),
'# Parameters\nfoo = "bar"\nbaz = ["buz"]\n',
),
],
@@ -60,39 +59,39 @@ def test_translate_codify_python(parameters, expected):
@pytest.mark.parametrize(
- "test_input,expected",
- [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")],
+ 'test_input,expected',
+ [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")],
)
def test_translate_comment_python(test_input, expected):
assert translators.PythonTranslator.comment(test_input) == expected
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("a = 2", [Parameter("a", "None", "2", "")]),
- ("a: int = 2", [Parameter("a", "int", "2", "")]),
- ("a = 2 # type:int", [Parameter("a", "int", "2", "")]),
+ ('a = 2', [Parameter('a', 'None', '2', '')]),
+ ('a: int = 2', [Parameter('a', 'int', '2', '')]),
+ ('a = 2 # type:int', [Parameter('a', 'int', '2', '')]),
(
- "a = False # Nice variable a",
- [Parameter("a", "None", "False", "Nice variable a")],
+ 'a = False # Nice variable a',
+ [Parameter('a', 'None', 'False', 'Nice variable a')],
),
(
- "a: float = 2.258 # type: int Nice variable a",
- [Parameter("a", "float", "2.258", "Nice variable a")],
+ 'a: float = 2.258 # type: int Nice variable a',
+ [Parameter('a', 'float', '2.258', 'Nice variable a')],
),
(
"a = 'this is a string' # type: int Nice variable a",
- [Parameter("a", "int", "'this is a string'", "Nice variable a")],
+ [Parameter('a', 'int', "'this is a string'", 'Nice variable a')],
),
(
"a: List[str] = ['this', 'is', 'a', 'string', 'list'] # Nice variable a",
[
Parameter(
- "a",
- "List[str]",
+ 'a',
+ 'List[str]',
"['this', 'is', 'a', 'string', 'list']",
- "Nice variable a",
+ 'Nice variable a',
)
],
),
@@ -100,10 +99,10 @@ def test_translate_comment_python(test_input, expected):
"a: List[str] = [\n 'this', # First\n 'is',\n 'a',\n 'string',\n 'list' # Last\n] # Nice variable a", # noqa
[
Parameter(
- "a",
- "List[str]",
+ 'a',
+ 'List[str]',
"['this','is','a','string','list']",
- "Nice variable a",
+ 'Nice variable a',
)
],
),
@@ -111,10 +110,10 @@ def test_translate_comment_python(test_input, expected):
"a: List[str] = [\n 'this',\n 'is',\n 'a',\n 'string',\n 'list'\n] # Nice variable a", # noqa
[
Parameter(
- "a",
- "List[str]",
+ 'a',
+ 'List[str]',
"['this','is','a','string','list']",
- "Nice variable a",
+ 'Nice variable a',
)
],
),
@@ -132,12 +131,12 @@ def test_translate_comment_python(test_input, expected):
""",
[
Parameter(
- "a",
- "List[str]",
+ 'a',
+ 'List[str]',
"['this','is','a','string','list']",
- "Nice variable a",
+ 'Nice variable a',
),
- Parameter("b", "float", "-2.3432", "My b variable"),
+ Parameter('b', 'float', '-2.3432', 'My b variable'),
],
),
],
@@ -148,26 +147,26 @@ def test_inspect_python(test_input, expected):
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", '"foo"'),
+ ('foo', '"foo"'),
('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'),
- ({"foo": "bar"}, 'list("foo" = "bar")'),
- ({"foo": '"bar"'}, 'list("foo" = "\\"bar\\"")'),
- ({"foo": ["bar"]}, 'list("foo" = list("bar"))'),
- ({"foo": {"bar": "baz"}}, 'list("foo" = list("bar" = "baz"))'),
- ({"foo": {"bar": '"baz"'}}, 'list("foo" = list("bar" = "\\"baz\\""))'),
- (["foo"], 'list("foo")'),
- (["foo", '"bar"'], 'list("foo", "\\"bar\\"")'),
- ([{"foo": "bar"}], 'list(list("foo" = "bar"))'),
- ([{"foo": '"bar"'}], 'list(list("foo" = "\\"bar\\""))'),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (True, "TRUE"),
- (False, "FALSE"),
- (None, "NULL"),
+ ({'foo': 'bar'}, 'list("foo" = "bar")'),
+ ({'foo': '"bar"'}, 'list("foo" = "\\"bar\\"")'),
+ ({'foo': ['bar']}, 'list("foo" = list("bar"))'),
+ ({'foo': {'bar': 'baz'}}, 'list("foo" = list("bar" = "baz"))'),
+ ({'foo': {'bar': '"baz"'}}, 'list("foo" = list("bar" = "\\"baz\\""))'),
+ (['foo'], 'list("foo")'),
+ (['foo', '"bar"'], 'list("foo", "\\"bar\\"")'),
+ ([{'foo': 'bar'}], 'list(list("foo" = "bar"))'),
+ ([{'foo': '"bar"'}], 'list(list("foo" = "\\"bar\\""))'),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (True, 'TRUE'),
+ (False, 'FALSE'),
+ (None, 'NULL'),
],
)
def test_translate_type_r(test_input, expected):
@@ -175,28 +174,28 @@ def test_translate_type_r(test_input, expected):
@pytest.mark.parametrize(
- "test_input,expected",
- [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")],
+ 'test_input,expected',
+ [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")],
)
def test_translate_comment_r(test_input, expected):
assert translators.RTranslator.comment(test_input) == expected
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'),
- ({"foo": True}, "# Parameters\nfoo = TRUE\n"),
- ({"foo": 5}, "# Parameters\nfoo = 5\n"),
- ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"),
- ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = list("bar", "baz")\n'),
- ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = list("bar" = "baz")\n'),
+ ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'),
+ ({'foo': True}, '# Parameters\nfoo = TRUE\n'),
+ ({'foo': 5}, '# Parameters\nfoo = 5\n'),
+ ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'),
+ ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = list("bar", "baz")\n'),
+ ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = list("bar" = "baz")\n'),
(
- OrderedDict([["foo", "bar"], ["baz", ["buz"]]]),
+ OrderedDict([['foo', 'bar'], ['baz', ['buz']]]),
'# Parameters\nfoo = "bar"\nbaz = list("buz")\n',
),
# Underscores remove
- ({"___foo": 5}, "# Parameters\nfoo = 5\n"),
+ ({'___foo': 5}, '# Parameters\nfoo = 5\n'),
],
)
def test_translate_codify_r(parameters, expected):
@@ -204,28 +203,28 @@ def test_translate_codify_r(parameters, expected):
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", '"foo"'),
+ ('foo', '"foo"'),
('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'),
- ({"foo": "bar"}, 'Map("foo" -> "bar")'),
- ({"foo": '"bar"'}, 'Map("foo" -> "\\"bar\\"")'),
- ({"foo": ["bar"]}, 'Map("foo" -> Seq("bar"))'),
- ({"foo": {"bar": "baz"}}, 'Map("foo" -> Map("bar" -> "baz"))'),
- ({"foo": {"bar": '"baz"'}}, 'Map("foo" -> Map("bar" -> "\\"baz\\""))'),
- (["foo"], 'Seq("foo")'),
- (["foo", '"bar"'], 'Seq("foo", "\\"bar\\"")'),
- ([{"foo": "bar"}], 'Seq(Map("foo" -> "bar"))'),
- ([{"foo": '"bar"'}], 'Seq(Map("foo" -> "\\"bar\\""))'),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (2147483648, "2147483648L"),
- (-2147483649, "-2147483649L"),
- (True, "true"),
- (False, "false"),
- (None, "None"),
+ ({'foo': 'bar'}, 'Map("foo" -> "bar")'),
+ ({'foo': '"bar"'}, 'Map("foo" -> "\\"bar\\"")'),
+ ({'foo': ['bar']}, 'Map("foo" -> Seq("bar"))'),
+ ({'foo': {'bar': 'baz'}}, 'Map("foo" -> Map("bar" -> "baz"))'),
+ ({'foo': {'bar': '"baz"'}}, 'Map("foo" -> Map("bar" -> "\\"baz\\""))'),
+ (['foo'], 'Seq("foo")'),
+ (['foo', '"bar"'], 'Seq("foo", "\\"bar\\"")'),
+ ([{'foo': 'bar'}], 'Seq(Map("foo" -> "bar"))'),
+ ([{'foo': '"bar"'}], 'Seq(Map("foo" -> "\\"bar\\""))'),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (2147483648, '2147483648L'),
+ (-2147483649, '-2147483649L'),
+ (True, 'true'),
+ (False, 'false'),
+ (None, 'None'),
],
)
def test_translate_type_scala(test_input, expected):
@@ -233,19 +232,19 @@ def test_translate_type_scala(test_input, expected):
@pytest.mark.parametrize(
- "test_input,expected",
- [("", "//"), ("foo", "// foo"), ("['best effort']", "// ['best effort']")],
+ 'test_input,expected',
+ [('', '//'), ('foo', '// foo'), ("['best effort']", "// ['best effort']")],
)
def test_translate_comment_scala(test_input, expected):
assert translators.ScalaTranslator.comment(test_input) == expected
@pytest.mark.parametrize(
- "input_name,input_value,expected",
+ 'input_name,input_value,expected',
[
- ("foo", '""', 'val foo = ""'),
- ("foo", '"bar"', 'val foo = "bar"'),
- ("foo", 'Map("foo" -> "bar")', 'val foo = Map("foo" -> "bar")'),
+ ('foo', '""', 'val foo = ""'),
+ ('foo', '"bar"', 'val foo = "bar"'),
+ ('foo', 'Map("foo" -> "bar")', 'val foo = Map("foo" -> "bar")'),
],
)
def test_translate_assign_scala(input_name, input_value, expected):
@@ -253,16 +252,16 @@ def test_translate_assign_scala(input_name, input_value, expected):
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, '// Parameters\nval foo = "bar"\n'),
- ({"foo": True}, "// Parameters\nval foo = true\n"),
- ({"foo": 5}, "// Parameters\nval foo = 5\n"),
- ({"foo": 1.1}, "// Parameters\nval foo = 1.1\n"),
- ({"foo": ["bar", "baz"]}, '// Parameters\nval foo = Seq("bar", "baz")\n'),
- ({"foo": {"bar": "baz"}}, '// Parameters\nval foo = Map("bar" -> "baz")\n'),
+ ({'foo': 'bar'}, '// Parameters\nval foo = "bar"\n'),
+ ({'foo': True}, '// Parameters\nval foo = true\n'),
+ ({'foo': 5}, '// Parameters\nval foo = 5\n'),
+ ({'foo': 1.1}, '// Parameters\nval foo = 1.1\n'),
+ ({'foo': ['bar', 'baz']}, '// Parameters\nval foo = Seq("bar", "baz")\n'),
+ ({'foo': {'bar': 'baz'}}, '// Parameters\nval foo = Map("bar" -> "baz")\n'),
(
- OrderedDict([["foo", "bar"], ["baz", ["buz"]]]),
+ OrderedDict([['foo', 'bar'], ['baz', ['buz']]]),
'// Parameters\nval foo = "bar"\nval baz = Seq("buz")\n',
),
],
@@ -273,26 +272,26 @@ def test_translate_codify_scala(parameters, expected):
# C# section
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", '"foo"'),
+ ('foo', '"foo"'),
('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'),
- ({"foo": "bar"}, 'new Dictionary{ { "foo" , "bar" } }'),
- ({"foo": '"bar"'}, 'new Dictionary{ { "foo" , "\\"bar\\"" } }'),
- (["foo"], 'new [] { "foo" }'),
- (["foo", '"bar"'], 'new [] { "foo", "\\"bar\\"" }'),
+ ({'foo': 'bar'}, 'new Dictionary{ { "foo" , "bar" } }'),
+ ({'foo': '"bar"'}, 'new Dictionary{ { "foo" , "\\"bar\\"" } }'),
+ (['foo'], 'new [] { "foo" }'),
+ (['foo', '"bar"'], 'new [] { "foo", "\\"bar\\"" }'),
(
- [{"foo": "bar"}],
+ [{'foo': 'bar'}],
'new [] { new Dictionary{ { "foo" , "bar" } } }',
),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (2147483648, "2147483648L"),
- (-2147483649, "-2147483649L"),
- (True, "true"),
- (False, "false"),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (2147483648, '2147483648L'),
+ (-2147483649, '-2147483649L'),
+ (True, 'true'),
+ (False, 'false'),
],
)
def test_translate_type_csharp(test_input, expected):
@@ -300,34 +299,34 @@ def test_translate_type_csharp(test_input, expected):
@pytest.mark.parametrize(
- "test_input,expected",
- [("", "//"), ("foo", "// foo"), ("['best effort']", "// ['best effort']")],
+ 'test_input,expected',
+ [('', '//'), ('foo', '// foo'), ("['best effort']", "// ['best effort']")],
)
def test_translate_comment_csharp(test_input, expected):
assert translators.CSharpTranslator.comment(test_input) == expected
@pytest.mark.parametrize(
- "input_name,input_value,expected",
- [("foo", '""', 'var foo = "";'), ("foo", '"bar"', 'var foo = "bar";')],
+ 'input_name,input_value,expected',
+ [('foo', '""', 'var foo = "";'), ('foo', '"bar"', 'var foo = "bar";')],
)
def test_translate_assign_csharp(input_name, input_value, expected):
assert translators.CSharpTranslator.assign(input_name, input_value) == expected
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, '// Parameters\nvar foo = "bar";\n'),
- ({"foo": True}, "// Parameters\nvar foo = true;\n"),
- ({"foo": 5}, "// Parameters\nvar foo = 5;\n"),
- ({"foo": 1.1}, "// Parameters\nvar foo = 1.1;\n"),
+ ({'foo': 'bar'}, '// Parameters\nvar foo = "bar";\n'),
+ ({'foo': True}, '// Parameters\nvar foo = true;\n'),
+ ({'foo': 5}, '// Parameters\nvar foo = 5;\n'),
+ ({'foo': 1.1}, '// Parameters\nvar foo = 1.1;\n'),
(
- {"foo": ["bar", "baz"]},
+ {'foo': ['bar', 'baz']},
'// Parameters\nvar foo = new [] { "bar", "baz" };\n',
),
(
- {"foo": {"bar": "baz"}},
+ {'foo': {'bar': 'baz'}},
'// Parameters\nvar foo = new Dictionary{ { "bar" , "baz" } };\n',
),
],
@@ -338,29 +337,29 @@ def test_translate_codify_csharp(parameters, expected):
# Powershell section
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", '"foo"'),
+ ('foo', '"foo"'),
('{"foo": "bar"}', '"{`"foo`": `"bar`"}"'),
- ({"foo": "bar"}, '@{"foo" = "bar"}'),
- ({"foo": '"bar"'}, '@{"foo" = "`"bar`""}'),
- ({"foo": ["bar"]}, '@{"foo" = @("bar")}'),
- ({"foo": {"bar": "baz"}}, '@{"foo" = @{"bar" = "baz"}}'),
- ({"foo": {"bar": '"baz"'}}, '@{"foo" = @{"bar" = "`"baz`""}}'),
- (["foo"], '@("foo")'),
- (["foo", '"bar"'], '@("foo", "`"bar`"")'),
- ([{"foo": "bar"}], '@(@{"foo" = "bar"})'),
- ([{"foo": '"bar"'}], '@(@{"foo" = "`"bar`""})'),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (float("nan"), "[double]::NaN"),
- (float("-inf"), "[double]::NegativeInfinity"),
- (float("inf"), "[double]::PositiveInfinity"),
- (True, "$True"),
- (False, "$False"),
- (None, "$Null"),
+ ({'foo': 'bar'}, '@{"foo" = "bar"}'),
+ ({'foo': '"bar"'}, '@{"foo" = "`"bar`""}'),
+ ({'foo': ['bar']}, '@{"foo" = @("bar")}'),
+ ({'foo': {'bar': 'baz'}}, '@{"foo" = @{"bar" = "baz"}}'),
+ ({'foo': {'bar': '"baz"'}}, '@{"foo" = @{"bar" = "`"baz`""}}'),
+ (['foo'], '@("foo")'),
+ (['foo', '"bar"'], '@("foo", "`"bar`"")'),
+ ([{'foo': 'bar'}], '@(@{"foo" = "bar"})'),
+ ([{'foo': '"bar"'}], '@(@{"foo" = "`"bar`""})'),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (float('nan'), '[double]::NaN'),
+ (float('-inf'), '[double]::NegativeInfinity'),
+ (float('inf'), '[double]::PositiveInfinity'),
+ (True, '$True'),
+ (False, '$False'),
+ (None, '$Null'),
],
)
def test_translate_type_powershell(test_input, expected):
@@ -368,16 +367,16 @@ def test_translate_type_powershell(test_input, expected):
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, '# Parameters\n$foo = "bar"\n'),
- ({"foo": True}, "# Parameters\n$foo = $True\n"),
- ({"foo": 5}, "# Parameters\n$foo = 5\n"),
- ({"foo": 1.1}, "# Parameters\n$foo = 1.1\n"),
- ({"foo": ["bar", "baz"]}, '# Parameters\n$foo = @("bar", "baz")\n'),
- ({"foo": {"bar": "baz"}}, '# Parameters\n$foo = @{"bar" = "baz"}\n'),
+ ({'foo': 'bar'}, '# Parameters\n$foo = "bar"\n'),
+ ({'foo': True}, '# Parameters\n$foo = $True\n'),
+ ({'foo': 5}, '# Parameters\n$foo = 5\n'),
+ ({'foo': 1.1}, '# Parameters\n$foo = 1.1\n'),
+ ({'foo': ['bar', 'baz']}, '# Parameters\n$foo = @("bar", "baz")\n'),
+ ({'foo': {'bar': 'baz'}}, '# Parameters\n$foo = @{"bar" = "baz"}\n'),
(
- OrderedDict([["foo", "bar"], ["baz", ["buz"]]]),
+ OrderedDict([['foo', 'bar'], ['baz', ['buz']]]),
'# Parameters\n$foo = "bar"\n$baz = @("buz")\n',
),
],
@@ -387,16 +386,16 @@ def test_translate_codify_powershell(parameters, expected):
@pytest.mark.parametrize(
- "input_name,input_value,expected",
- [("foo", '""', '$foo = ""'), ("foo", '"bar"', '$foo = "bar"')],
+ 'input_name,input_value,expected',
+ [('foo', '""', '$foo = ""'), ('foo', '"bar"', '$foo = "bar"')],
)
def test_translate_assign_powershell(input_name, input_value, expected):
assert translators.PowershellTranslator.assign(input_name, input_value) == expected
@pytest.mark.parametrize(
- "test_input,expected",
- [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")],
+ 'test_input,expected',
+ [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")],
)
def test_translate_comment_powershell(test_input, expected):
assert translators.PowershellTranslator.comment(test_input) == expected
@@ -404,23 +403,23 @@ def test_translate_comment_powershell(test_input, expected):
# F# section
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", '"foo"'),
+ ('foo', '"foo"'),
('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'),
- ({"foo": "bar"}, '[ ("foo", "bar" :> IComparable) ] |> Map.ofList'),
- ({"foo": '"bar"'}, '[ ("foo", "\\"bar\\"" :> IComparable) ] |> Map.ofList'),
- (["foo"], '[ "foo" ]'),
- (["foo", '"bar"'], '[ "foo"; "\\"bar\\"" ]'),
- ([{"foo": "bar"}], '[ [ ("foo", "bar" :> IComparable) ] |> Map.ofList ]'),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (2147483648, "2147483648L"),
- (-2147483649, "-2147483649L"),
- (True, "true"),
- (False, "false"),
+ ({'foo': 'bar'}, '[ ("foo", "bar" :> IComparable) ] |> Map.ofList'),
+ ({'foo': '"bar"'}, '[ ("foo", "\\"bar\\"" :> IComparable) ] |> Map.ofList'),
+ (['foo'], '[ "foo" ]'),
+ (['foo', '"bar"'], '[ "foo"; "\\"bar\\"" ]'),
+ ([{'foo': 'bar'}], '[ [ ("foo", "bar" :> IComparable) ] |> Map.ofList ]'),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (2147483648, '2147483648L'),
+ (-2147483649, '-2147483649L'),
+ (True, 'true'),
+ (False, 'false'),
],
)
def test_translate_type_fsharp(test_input, expected):
@@ -428,10 +427,10 @@ def test_translate_type_fsharp(test_input, expected):
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("", "(* *)"),
- ("foo", "(* foo *)"),
+ ('', '(* *)'),
+ ('foo', '(* foo *)'),
("['best effort']", "(* ['best effort'] *)"),
],
)
@@ -440,23 +439,23 @@ def test_translate_comment_fsharp(test_input, expected):
@pytest.mark.parametrize(
- "input_name,input_value,expected",
- [("foo", '""', 'let foo = ""'), ("foo", '"bar"', 'let foo = "bar"')],
+ 'input_name,input_value,expected',
+ [('foo', '""', 'let foo = ""'), ('foo', '"bar"', 'let foo = "bar"')],
)
def test_translate_assign_fsharp(input_name, input_value, expected):
assert translators.FSharpTranslator.assign(input_name, input_value) == expected
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, '(* Parameters *)\nlet foo = "bar"\n'),
- ({"foo": True}, "(* Parameters *)\nlet foo = true\n"),
- ({"foo": 5}, "(* Parameters *)\nlet foo = 5\n"),
- ({"foo": 1.1}, "(* Parameters *)\nlet foo = 1.1\n"),
- ({"foo": ["bar", "baz"]}, '(* Parameters *)\nlet foo = [ "bar"; "baz" ]\n'),
+ ({'foo': 'bar'}, '(* Parameters *)\nlet foo = "bar"\n'),
+ ({'foo': True}, '(* Parameters *)\nlet foo = true\n'),
+ ({'foo': 5}, '(* Parameters *)\nlet foo = 5\n'),
+ ({'foo': 1.1}, '(* Parameters *)\nlet foo = 1.1\n'),
+ ({'foo': ['bar', 'baz']}, '(* Parameters *)\nlet foo = [ "bar"; "baz" ]\n'),
(
- {"foo": {"bar": "baz"}},
+ {'foo': {'bar': 'baz'}},
'(* Parameters *)\nlet foo = [ ("bar", "baz" :> IComparable) ] |> Map.ofList\n',
),
],
@@ -466,26 +465,26 @@ def test_translate_codify_fsharp(parameters, expected):
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", '"foo"'),
+ ('foo', '"foo"'),
('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'),
- ({"foo": "bar"}, 'Dict("foo" => "bar")'),
- ({"foo": '"bar"'}, 'Dict("foo" => "\\"bar\\"")'),
- ({"foo": ["bar"]}, 'Dict("foo" => ["bar"])'),
- ({"foo": {"bar": "baz"}}, 'Dict("foo" => Dict("bar" => "baz"))'),
- ({"foo": {"bar": '"baz"'}}, 'Dict("foo" => Dict("bar" => "\\"baz\\""))'),
- (["foo"], '["foo"]'),
- (["foo", '"bar"'], '["foo", "\\"bar\\""]'),
- ([{"foo": "bar"}], '[Dict("foo" => "bar")]'),
- ([{"foo": '"bar"'}], '[Dict("foo" => "\\"bar\\"")]'),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (True, "true"),
- (False, "false"),
- (None, "nothing"),
+ ({'foo': 'bar'}, 'Dict("foo" => "bar")'),
+ ({'foo': '"bar"'}, 'Dict("foo" => "\\"bar\\"")'),
+ ({'foo': ['bar']}, 'Dict("foo" => ["bar"])'),
+ ({'foo': {'bar': 'baz'}}, 'Dict("foo" => Dict("bar" => "baz"))'),
+ ({'foo': {'bar': '"baz"'}}, 'Dict("foo" => Dict("bar" => "\\"baz\\""))'),
+ (['foo'], '["foo"]'),
+ (['foo', '"bar"'], '["foo", "\\"bar\\""]'),
+ ([{'foo': 'bar'}], '[Dict("foo" => "bar")]'),
+ ([{'foo': '"bar"'}], '[Dict("foo" => "\\"bar\\"")]'),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (True, 'true'),
+ (False, 'false'),
+ (None, 'nothing'),
],
)
def test_translate_type_julia(test_input, expected):
@@ -493,16 +492,16 @@ def test_translate_type_julia(test_input, expected):
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'),
- ({"foo": True}, "# Parameters\nfoo = true\n"),
- ({"foo": 5}, "# Parameters\nfoo = 5\n"),
- ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"),
- ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = ["bar", "baz"]\n'),
- ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = Dict("bar" => "baz")\n'),
+ ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'),
+ ({'foo': True}, '# Parameters\nfoo = true\n'),
+ ({'foo': 5}, '# Parameters\nfoo = 5\n'),
+ ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'),
+ ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = ["bar", "baz"]\n'),
+ ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = Dict("bar" => "baz")\n'),
(
- OrderedDict([["foo", "bar"], ["baz", ["buz"]]]),
+ OrderedDict([['foo', 'bar'], ['baz', ['buz']]]),
'# Parameters\nfoo = "bar"\nbaz = ["buz"]\n',
),
],
@@ -512,44 +511,44 @@ def test_translate_codify_julia(parameters, expected):
@pytest.mark.parametrize(
- "test_input,expected",
- [("", "#"), ("foo", "# foo"), ('["best effort"]', '# ["best effort"]')],
+ 'test_input,expected',
+ [('', '#'), ('foo', '# foo'), ('["best effort"]', '# ["best effort"]')],
)
def test_translate_comment_julia(test_input, expected):
assert translators.JuliaTranslator.comment(test_input) == expected
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", '"foo"'),
+ ('foo', '"foo"'),
('{"foo": "bar"}', '"{""foo"": ""bar""}"'),
- ({1: "foo"}, "containers.Map({'1'}, {\"foo\"})"),
- ({1.0: "foo"}, "containers.Map({'1.0'}, {\"foo\"})"),
- ({None: "foo"}, "containers.Map({'None'}, {\"foo\"})"),
- ({True: "foo"}, "containers.Map({'True'}, {\"foo\"})"),
- ({"foo": "bar"}, "containers.Map({'foo'}, {\"bar\"})"),
- ({"foo": '"bar"'}, 'containers.Map({\'foo\'}, {"""bar"""})'),
- ({"foo": ["bar"]}, "containers.Map({'foo'}, {{\"bar\"}})"),
+ ({1: 'foo'}, 'containers.Map({\'1\'}, {"foo"})'),
+ ({1.0: 'foo'}, 'containers.Map({\'1.0\'}, {"foo"})'),
+ ({None: 'foo'}, 'containers.Map({\'None\'}, {"foo"})'),
+ ({True: 'foo'}, 'containers.Map({\'True\'}, {"foo"})'),
+ ({'foo': 'bar'}, 'containers.Map({\'foo\'}, {"bar"})'),
+ ({'foo': '"bar"'}, 'containers.Map({\'foo\'}, {"""bar"""})'),
+ ({'foo': ['bar']}, 'containers.Map({\'foo\'}, {{"bar"}})'),
(
- {"foo": {"bar": "baz"}},
+ {'foo': {'bar': 'baz'}},
"containers.Map({'foo'}, {containers.Map({'bar'}, {\"baz\"})})",
),
(
- {"foo": {"bar": '"baz"'}},
+ {'foo': {'bar': '"baz"'}},
'containers.Map({\'foo\'}, {containers.Map({\'bar\'}, {"""baz"""})})',
),
- (["foo"], '{"foo"}'),
- (["foo", '"bar"'], '{"foo", """bar"""}'),
- ([{"foo": "bar"}], "{containers.Map({'foo'}, {\"bar\"})}"),
- ([{"foo": '"bar"'}], '{containers.Map({\'foo\'}, {"""bar"""})}'),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (True, "true"),
- (False, "false"),
- (None, "NaN"),
+ (['foo'], '{"foo"}'),
+ (['foo', '"bar"'], '{"foo", """bar"""}'),
+ ([{'foo': 'bar'}], '{containers.Map({\'foo\'}, {"bar"})}'),
+ ([{'foo': '"bar"'}], '{containers.Map({\'foo\'}, {"""bar"""})}'),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (True, 'true'),
+ (False, 'false'),
+ (None, 'NaN'),
],
)
def test_translate_type_matlab(test_input, expected):
@@ -557,19 +556,19 @@ def test_translate_type_matlab(test_input, expected):
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, '% Parameters\nfoo = "bar";\n'),
- ({"foo": True}, "% Parameters\nfoo = true;\n"),
- ({"foo": 5}, "% Parameters\nfoo = 5;\n"),
- ({"foo": 1.1}, "% Parameters\nfoo = 1.1;\n"),
- ({"foo": ["bar", "baz"]}, '% Parameters\nfoo = {"bar", "baz"};\n'),
+ ({'foo': 'bar'}, '% Parameters\nfoo = "bar";\n'),
+ ({'foo': True}, '% Parameters\nfoo = true;\n'),
+ ({'foo': 5}, '% Parameters\nfoo = 5;\n'),
+ ({'foo': 1.1}, '% Parameters\nfoo = 1.1;\n'),
+ ({'foo': ['bar', 'baz']}, '% Parameters\nfoo = {"bar", "baz"};\n'),
(
- {"foo": {"bar": "baz"}},
- "% Parameters\nfoo = containers.Map({'bar'}, {\"baz\"});\n",
+ {'foo': {'bar': 'baz'}},
+ '% Parameters\nfoo = containers.Map({\'bar\'}, {"baz"});\n',
),
(
- OrderedDict([["foo", "bar"], ["baz", ["buz"]]]),
+ OrderedDict([['foo', 'bar'], ['baz', ['buz']]]),
'% Parameters\nfoo = "bar";\nbaz = {"buz"};\n',
),
],
@@ -579,8 +578,8 @@ def test_translate_codify_matlab(parameters, expected):
@pytest.mark.parametrize(
- "test_input,expected",
- [("", "%"), ("foo", "% foo"), ("['best effort']", "% ['best effort']")],
+ 'test_input,expected',
+ [('', '%'), ('foo', '% foo'), ("['best effort']", "% ['best effort']")],
)
def test_translate_comment_matlab(test_input, expected):
assert translators.MatlabTranslator.comment(test_input) == expected
@@ -589,44 +588,32 @@ def test_translate_comment_matlab(test_input, expected):
def test_find_translator_with_exact_kernel_name():
my_new_kernel_translator = Mock()
my_new_language_translator = Mock()
- translators.papermill_translators.register(
- "my_new_kernel", my_new_kernel_translator
- )
- translators.papermill_translators.register(
- "my_new_language", my_new_language_translator
- )
+ translators.papermill_translators.register('my_new_kernel', my_new_kernel_translator)
+ translators.papermill_translators.register('my_new_language', my_new_language_translator)
assert (
- translators.papermill_translators.find_translator(
- "my_new_kernel", "my_new_language"
- )
+ translators.papermill_translators.find_translator('my_new_kernel', 'my_new_language')
is my_new_kernel_translator
)
def test_find_translator_with_exact_language():
my_new_language_translator = Mock()
- translators.papermill_translators.register(
- "my_new_language", my_new_language_translator
- )
+ translators.papermill_translators.register('my_new_language', my_new_language_translator)
assert (
- translators.papermill_translators.find_translator(
- "unregistered_kernel", "my_new_language"
- )
+ translators.papermill_translators.find_translator('unregistered_kernel', 'my_new_language')
is my_new_language_translator
)
def test_find_translator_with_no_such_kernel_or_language():
with pytest.raises(PapermillException):
- translators.papermill_translators.find_translator(
- "unregistered_kernel", "unregistered_language"
- )
+ translators.papermill_translators.find_translator('unregistered_kernel', 'unregistered_language')
def test_translate_uses_str_representation_of_unknown_types():
class FooClass:
def __str__(self):
- return "foo"
+ return 'foo'
obj = FooClass()
assert translators.Translator.translate(obj) == '"foo"'
@@ -637,7 +624,7 @@ class MyNewTranslator(translators.Translator):
pass
with pytest.raises(NotImplementedError):
- MyNewTranslator.translate_dict({"foo": "bar"})
+ MyNewTranslator.translate_dict({'foo': 'bar'})
def test_translator_must_implement_translate_list():
@@ -645,7 +632,7 @@ class MyNewTranslator(translators.Translator):
pass
with pytest.raises(NotImplementedError):
- MyNewTranslator.translate_list(["foo", "bar"])
+ MyNewTranslator.translate_list(['foo', 'bar'])
def test_translator_must_implement_comment():
@@ -653,24 +640,24 @@ class MyNewTranslator(translators.Translator):
pass
with pytest.raises(NotImplementedError):
- MyNewTranslator.comment("foo")
+ MyNewTranslator.comment('foo')
# Bash/sh section
@pytest.mark.parametrize(
- "test_input,expected",
+ 'test_input,expected',
[
- ("foo", "foo"),
- ("foo space", "'foo space'"),
+ ('foo', 'foo'),
+ ('foo space', "'foo space'"),
("foo's apostrophe", "'foo'\"'\"'s apostrophe'"),
- ("shell ( is ) ", "'shell ( is ) '"),
- (12345, "12345"),
- (-54321, "-54321"),
- (1.2345, "1.2345"),
- (-5432.1, "-5432.1"),
- (True, "true"),
- (False, "false"),
- (None, ""),
+ ('shell ( is ) ', "'shell ( is ) '"),
+ (12345, '12345'),
+ (-54321, '-54321'),
+ (1.2345, '1.2345'),
+ (-5432.1, '-5432.1'),
+ (True, 'true'),
+ (False, 'false'),
+ (None, ''),
],
)
def test_translate_type_sh(test_input, expected):
@@ -678,23 +665,23 @@ def test_translate_type_sh(test_input, expected):
@pytest.mark.parametrize(
- "test_input,expected",
- [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")],
+ 'test_input,expected',
+ [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")],
)
def test_translate_comment_sh(test_input, expected):
assert translators.BashTranslator.comment(test_input) == expected
@pytest.mark.parametrize(
- "parameters,expected",
+ 'parameters,expected',
[
- ({"foo": "bar"}, "# Parameters\nfoo=bar\n"),
- ({"foo": "shell ( is ) "}, "# Parameters\nfoo='shell ( is ) '\n"),
- ({"foo": True}, "# Parameters\nfoo=true\n"),
- ({"foo": 5}, "# Parameters\nfoo=5\n"),
- ({"foo": 1.1}, "# Parameters\nfoo=1.1\n"),
+ ({'foo': 'bar'}, '# Parameters\nfoo=bar\n'),
+ ({'foo': 'shell ( is ) '}, "# Parameters\nfoo='shell ( is ) '\n"),
+ ({'foo': True}, '# Parameters\nfoo=true\n'),
+ ({'foo': 5}, '# Parameters\nfoo=5\n'),
+ ({'foo': 1.1}, '# Parameters\nfoo=1.1\n'),
(
- OrderedDict([["foo", "bar"], ["baz", "$dumb(shell)"]]),
+ OrderedDict([['foo', 'bar'], ['baz', '$dumb(shell)']]),
"# Parameters\nfoo=bar\nbaz='$dumb(shell)'\n",
),
],
diff --git a/papermill/tests/test_utils.py b/papermill/tests/test_utils.py
index 519fa383..4d058fb2 100644
--- a/papermill/tests/test_utils.py
+++ b/papermill/tests/test_utils.py
@@ -1,59 +1,53 @@
-import pytest
import warnings
-
-from unittest.mock import Mock, call
-from tempfile import TemporaryDirectory
from pathlib import Path
+from tempfile import TemporaryDirectory
+from unittest.mock import Mock, call
-from nbformat.v4 import new_notebook, new_code_cell
+import pytest
+from nbformat.v4 import new_code_cell, new_notebook
+from ..exceptions import PapermillParameterOverwriteWarning
from ..utils import (
any_tagged_cell,
- retry,
chdir,
merge_kwargs,
remove_args,
+ retry,
)
-from ..exceptions import PapermillParameterOverwriteWarning
def test_no_tagged_cell():
nb = new_notebook(
- cells=[new_code_cell("a = 2", metadata={"tags": []})],
+ cells=[new_code_cell('a = 2', metadata={'tags': []})],
)
- assert not any_tagged_cell(nb, "parameters")
+ assert not any_tagged_cell(nb, 'parameters')
def test_tagged_cell():
nb = new_notebook(
- cells=[new_code_cell("a = 2", metadata={"tags": ["parameters"]})],
+ cells=[new_code_cell('a = 2', metadata={'tags': ['parameters']})],
)
- assert any_tagged_cell(nb, "parameters")
+ assert any_tagged_cell(nb, 'parameters')
def test_merge_kwargs():
with warnings.catch_warnings(record=True) as wrn:
- assert merge_kwargs({"a": 1, "b": 2}, a=3) == {"a": 3, "b": 2}
+ assert merge_kwargs({'a': 1, 'b': 2}, a=3) == {'a': 3, 'b': 2}
assert len(wrn) == 1
assert issubclass(wrn[0].category, PapermillParameterOverwriteWarning)
- assert (
- wrn[0].message.__str__()
- == "Callee will overwrite caller's argument(s): a=3"
- )
+ assert wrn[0].message.__str__() == "Callee will overwrite caller's argument(s): a=3"
def test_remove_args():
- assert remove_args(["a"], a=1, b=2, c=3) == {"c": 3, "b": 2}
+ assert remove_args(['a'], a=1, b=2, c=3) == {'c': 3, 'b': 2}
def test_retry():
- m = Mock(
- side_effect=RuntimeError(), __name__="m", __module__="test_s3", __doc__="m"
- )
+ m = Mock(side_effect=RuntimeError(), __name__='m', __module__='test_s3', __doc__='m')
wrapped_m = retry(3)(m)
with pytest.raises(RuntimeError):
- wrapped_m("foo")
- m.assert_has_calls([call("foo"), call("foo"), call("foo")])
+ wrapped_m('foo')
+ m.assert_has_calls([call('foo'), call('foo'), call('foo')])
def test_chdir():
diff --git a/papermill/translators.py b/papermill/translators.py
index ace316bf..0086f84f 100644
--- a/papermill/translators.py
+++ b/papermill/translators.py
@@ -6,7 +6,6 @@
from .exceptions import PapermillException
from .models import Parameter
-
logger = logging.getLogger(__name__)
@@ -29,9 +28,7 @@ def find_translator(self, kernel_name, language):
elif language in self._translators:
return self._translators[language]
raise PapermillException(
- "No parameter translator functions specified for kernel '{}' or language '{}'".format(
- kernel_name, language
- )
+ f"No parameter translator functions specified for kernel '{kernel_name}' or language '{language}'"
)
@@ -39,15 +36,15 @@ class Translator:
@classmethod
def translate_raw_str(cls, val):
"""Reusable by most interpreters"""
- return f"{val}"
+ return f'{val}'
@classmethod
def translate_escaped_str(cls, str_val):
"""Reusable by most interpreters"""
if isinstance(str_val, str):
- str_val = str_val.encode("unicode_escape")
- str_val = str_val.decode("utf-8")
- str_val = str_val.replace('"', r"\"")
+ str_val = str_val.encode('unicode_escape')
+ str_val = str_val.decode('utf-8')
+ str_val = str_val.replace('"', r'\"')
return f'"{str_val}"'
@classmethod
@@ -73,15 +70,15 @@ def translate_float(cls, val):
@classmethod
def translate_bool(cls, val):
"""Default behavior for translation"""
- return "true" if val else "false"
+ return 'true' if val else 'false'
@classmethod
def translate_dict(cls, val):
- raise NotImplementedError(f"dict type translation not implemented for {cls}")
+ raise NotImplementedError(f'dict type translation not implemented for {cls}')
@classmethod
def translate_list(cls, val):
- raise NotImplementedError(f"list type translation not implemented for {cls}")
+ raise NotImplementedError(f'list type translation not implemented for {cls}')
@classmethod
def translate(cls, val):
@@ -106,17 +103,17 @@ def translate(cls, val):
@classmethod
def comment(cls, cmt_str):
- raise NotImplementedError(f"comment translation not implemented for {cls}")
+ raise NotImplementedError(f'comment translation not implemented for {cls}')
@classmethod
def assign(cls, name, str_val):
- return f"{name} = {str_val}"
+ return f'{name} = {str_val}'
@classmethod
- def codify(cls, parameters, comment="Parameters"):
- content = f"{cls.comment(comment)}\n"
+ def codify(cls, parameters, comment='Parameters'):
+ content = f'{cls.comment(comment)}\n'
for name, val in parameters.items():
- content += f"{cls.assign(name, cls.translate(val))}\n"
+ content += f'{cls.assign(name, cls.translate(val))}\n'
return content
@classmethod
@@ -140,7 +137,7 @@ def inspect(cls, parameters_cell):
List[Parameter]
A list of all parameters
"""
- raise NotImplementedError(f"parameters introspection not implemented for {cls}")
+ raise NotImplementedError(f'parameters introspection not implemented for {cls}')
class PythonTranslator(Translator):
@@ -166,22 +163,20 @@ def translate_bool(cls, val):
@classmethod
def translate_dict(cls, val):
- escaped = ", ".join(
- [f"{cls.translate_str(k)}: {cls.translate(v)}" for k, v in val.items()]
- )
- return f"{{{escaped}}}"
+ escaped = ', '.join([f'{cls.translate_str(k)}: {cls.translate(v)}' for k, v in val.items()])
+ return f'{{{escaped}}}'
@classmethod
def translate_list(cls, val):
- escaped = ", ".join([cls.translate(v) for v in val])
- return f"[{escaped}]"
+ escaped = ', '.join([cls.translate(v) for v in val])
+ return f'[{escaped}]'
@classmethod
def comment(cls, cmt_str):
- return f"# {cmt_str}".strip()
+ return f'# {cmt_str}'.strip()
@classmethod
- def codify(cls, parameters, comment="Parameters"):
+ def codify(cls, parameters, comment='Parameters'):
content = super().codify(parameters, comment)
try:
# Put content through the Black Python code formatter
@@ -192,7 +187,7 @@ def codify(cls, parameters, comment="Parameters"):
except ImportError:
logger.debug("Black is not installed, parameters won't be formatted")
except AttributeError as aerr:
- logger.warning(f"Black encountered an error, skipping formatting ({aerr})")
+ logger.warning(f'Black encountered an error, skipping formatting ({aerr})')
return content
@classmethod
@@ -213,7 +208,7 @@ def inspect(cls, parameters_cell):
A list of all parameters
"""
params = []
- src = parameters_cell["source"]
+ src = parameters_cell['source']
def flatten_accumulator(accumulator):
"""Flatten a multilines variable definition.
@@ -225,10 +220,10 @@ def flatten_accumulator(accumulator):
Returns:
Flatten definition
"""
- flat_string = ""
+ flat_string = ''
for line in accumulator[:-1]:
- if "#" in line:
- comment_pos = line.index("#")
+ if '#' in line:
+ comment_pos = line.index('#')
flat_string += line[:comment_pos].strip()
else:
flat_string += line.strip()
@@ -244,10 +239,10 @@ def flatten_accumulator(accumulator):
grouped_variable = []
accumulator = []
for iline, line in enumerate(src.splitlines()):
- if len(line.strip()) == 0 or line.strip().startswith("#"):
+ if len(line.strip()) == 0 or line.strip().startswith('#'):
continue # Skip blank and comment
- nequal = line.count("=")
+ nequal = line.count('=')
if nequal > 0:
grouped_variable.append(flatten_accumulator(accumulator))
accumulator = []
@@ -265,16 +260,16 @@ def flatten_accumulator(accumulator):
match = re.match(cls.PARAMETER_PATTERN, definition)
if match is not None:
attr = match.groupdict()
- if attr["target"] is None: # Fail to get variable name
+ if attr['target'] is None: # Fail to get variable name
continue
- type_name = str(attr["annotation"] or attr["type_comment"] or None)
+ type_name = str(attr['annotation'] or attr['type_comment'] or None)
params.append(
Parameter(
- name=attr["target"].strip(),
+ name=attr['target'].strip(),
inferred_type_name=type_name.strip(),
- default=str(attr["value"]).strip(),
- help=str(attr["help"] or "").strip(),
+ default=str(attr['value']).strip(),
+ help=str(attr['help'] or '').strip(),
)
)
@@ -284,85 +279,79 @@ def flatten_accumulator(accumulator):
class RTranslator(Translator):
@classmethod
def translate_none(cls, val):
- return "NULL"
+ return 'NULL'
@classmethod
def translate_bool(cls, val):
- return "TRUE" if val else "FALSE"
+ return 'TRUE' if val else 'FALSE'
@classmethod
def translate_dict(cls, val):
- escaped = ", ".join(
- [f"{cls.translate_str(k)} = {cls.translate(v)}" for k, v in val.items()]
- )
- return f"list({escaped})"
+ escaped = ', '.join([f'{cls.translate_str(k)} = {cls.translate(v)}' for k, v in val.items()])
+ return f'list({escaped})'
@classmethod
def translate_list(cls, val):
- escaped = ", ".join([cls.translate(v) for v in val])
- return f"list({escaped})"
+ escaped = ', '.join([cls.translate(v) for v in val])
+ return f'list({escaped})'
@classmethod
def comment(cls, cmt_str):
- return f"# {cmt_str}".strip()
+ return f'# {cmt_str}'.strip()
@classmethod
def assign(cls, name, str_val):
# Leading '_' aren't legal R variable names -- so we drop them when injecting
- while name.startswith("_"):
+ while name.startswith('_'):
name = name[1:]
- return f"{name} = {str_val}"
+ return f'{name} = {str_val}'
class ScalaTranslator(Translator):
@classmethod
def translate_int(cls, val):
strval = cls.translate_raw_str(val)
- return strval + "L" if (val > 2147483647 or val < -2147483648) else strval
+ return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval
@classmethod
def translate_dict(cls, val):
"""Translate dicts to scala Maps"""
- escaped = ", ".join(
- [f"{cls.translate_str(k)} -> {cls.translate(v)}" for k, v in val.items()]
- )
- return f"Map({escaped})"
+ escaped = ', '.join([f'{cls.translate_str(k)} -> {cls.translate(v)}' for k, v in val.items()])
+ return f'Map({escaped})'
@classmethod
def translate_list(cls, val):
"""Translate list to scala Seq"""
- escaped = ", ".join([cls.translate(v) for v in val])
- return f"Seq({escaped})"
+ escaped = ', '.join([cls.translate(v) for v in val])
+ return f'Seq({escaped})'
@classmethod
def comment(cls, cmt_str):
- return f"// {cmt_str}".strip()
+ return f'// {cmt_str}'.strip()
@classmethod
def assign(cls, name, str_val):
- return f"val {name} = {str_val}"
+ return f'val {name} = {str_val}'
class JuliaTranslator(Translator):
@classmethod
def translate_none(cls, val):
- return "nothing"
+ return 'nothing'
@classmethod
def translate_dict(cls, val):
- escaped = ", ".join(
- [f"{cls.translate_str(k)} => {cls.translate(v)}" for k, v in val.items()]
- )
- return f"Dict({escaped})"
+ escaped = ', '.join([f'{cls.translate_str(k)} => {cls.translate(v)}' for k, v in val.items()])
+ return f'Dict({escaped})'
@classmethod
def translate_list(cls, val):
- escaped = ", ".join([cls.translate(v) for v in val])
- return f"[{escaped}]"
+ escaped = ', '.join([cls.translate(v) for v in val])
+ return f'[{escaped}]'
@classmethod
def comment(cls, cmt_str):
- return f"# {cmt_str}".strip()
+ return f'# {cmt_str}'.strip()
class MatlabTranslator(Translator):
@@ -370,8 +359,8 @@ class MatlabTranslator(Translator):
def translate_escaped_str(cls, str_val):
"""Translate a string to an escaped Matlab string"""
if isinstance(str_val, str):
- str_val = str_val.encode("unicode_escape")
- str_val = str_val.decode("utf-8")
+ str_val = str_val.encode('unicode_escape')
+ str_val = str_val.decode('utf-8')
str_val = str_val.replace('"', '""')
return f'"{str_val}"'
@@ -379,35 +368,35 @@ def translate_escaped_str(cls, str_val):
def __translate_char_array(str_val):
"""Translates a string to a Matlab char array"""
if isinstance(str_val, str):
- str_val = str_val.encode("unicode_escape")
- str_val = str_val.decode("utf-8")
+ str_val = str_val.encode('unicode_escape')
+ str_val = str_val.decode('utf-8')
str_val = str_val.replace("'", "''")
return f"'{str_val}'"
@classmethod
def translate_none(cls, val):
- return "NaN"
+ return 'NaN'
@classmethod
def translate_dict(cls, val):
- keys = ", ".join([f"{cls.__translate_char_array(k)}" for k, v in val.items()])
- vals = ", ".join([f"{cls.translate(v)}" for k, v in val.items()])
- return f"containers.Map({{{keys}}}, {{{vals}}})"
+ keys = ', '.join([f'{cls.__translate_char_array(k)}' for k, v in val.items()])
+ vals = ', '.join([f'{cls.translate(v)}' for k, v in val.items()])
+ return f'containers.Map({{{keys}}}, {{{vals}}})'
@classmethod
def translate_list(cls, val):
- escaped = ", ".join([cls.translate(v) for v in val])
- return f"{{{escaped}}}"
+ escaped = ', '.join([cls.translate(v) for v in val])
+ return f'{{{escaped}}}'
@classmethod
def comment(cls, cmt_str):
- return f"% {cmt_str}".strip()
+ return f'% {cmt_str}'.strip()
@classmethod
- def codify(cls, parameters, comment="Parameters"):
- content = f"{cls.comment(comment)}\n"
+ def codify(cls, parameters, comment='Parameters'):
+ content = f'{cls.comment(comment)}\n'
for name, val in parameters.items():
- content += f"{cls.assign(name, cls.translate(val))};\n"
+ content += f'{cls.assign(name, cls.translate(val))};\n'
return content
@@ -415,80 +404,70 @@ class CSharpTranslator(Translator):
@classmethod
def translate_none(cls, val):
# Can't figure out how to do this as nullable
- raise NotImplementedError("Option type not implemented for C#.")
+ raise NotImplementedError('Option type not implemented for C#.')
@classmethod
def translate_bool(cls, val):
- return "true" if val else "false"
+ return 'true' if val else 'false'
@classmethod
def translate_int(cls, val):
strval = cls.translate_raw_str(val)
- return strval + "L" if (val > 2147483647 or val < -2147483648) else strval
+ return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval
@classmethod
def translate_dict(cls, val):
"""Translate dicts to nontyped dictionary"""
- kvps = ", ".join(
- [
- f"{{ {cls.translate_str(k)} , {cls.translate(v)} }}"
- for k, v in val.items()
- ]
- )
- return f"new Dictionary{{ {kvps} }}"
+ kvps = ', '.join([f'{{ {cls.translate_str(k)} , {cls.translate(v)} }}' for k, v in val.items()])
+ return f'new Dictionary{{ {kvps} }}'
@classmethod
def translate_list(cls, val):
"""Translate list to array"""
- escaped = ", ".join([cls.translate(v) for v in val])
- return f"new [] {{ {escaped} }}"
+ escaped = ', '.join([cls.translate(v) for v in val])
+ return f'new [] {{ {escaped} }}'
@classmethod
def comment(cls, cmt_str):
- return f"// {cmt_str}".strip()
+ return f'// {cmt_str}'.strip()
@classmethod
def assign(cls, name, str_val):
- return f"var {name} = {str_val};"
+ return f'var {name} = {str_val};'
class FSharpTranslator(Translator):
@classmethod
def translate_none(cls, val):
- return "None"
+ return 'None'
@classmethod
def translate_bool(cls, val):
- return "true" if val else "false"
+ return 'true' if val else 'false'
@classmethod
def translate_int(cls, val):
strval = cls.translate_raw_str(val)
- return strval + "L" if (val > 2147483647 or val < -2147483648) else strval
+ return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval
@classmethod
def translate_dict(cls, val):
- tuples = "; ".join(
- [
- f"({cls.translate_str(k)}, {cls.translate(v)} :> IComparable)"
- for k, v in val.items()
- ]
- )
- return f"[ {tuples} ] |> Map.ofList"
+ tuples = '; '.join([f'({cls.translate_str(k)}, {cls.translate(v)} :> IComparable)' for k, v in val.items()])
+ return f'[ {tuples} ] |> Map.ofList'
@classmethod
def translate_list(cls, val):
- escaped = "; ".join([cls.translate(v) for v in val])
- return f"[ {escaped} ]"
+ escaped = '; '.join([cls.translate(v) for v in val])
+ return f'[ {escaped} ]'
@classmethod
def comment(cls, cmt_str):
- return f"(* {cmt_str} *)".strip()
+ return f'(* {cmt_str} *)'.strip()
@classmethod
def assign(cls, name, str_val):
- return f"let {name} = {str_val}"
+ return f'let {name} = {str_val}'
class PowershellTranslator(Translator):
@@ -496,8 +475,8 @@ class PowershellTranslator(Translator):
def translate_escaped_str(cls, str_val):
"""Translate a string to an escaped Matlab string"""
if isinstance(str_val, str):
- str_val = str_val.encode("unicode_escape")
- str_val = str_val.decode("utf-8")
+ str_val = str_val.encode('unicode_escape')
+ str_val = str_val.decode('utf-8')
str_val = str_val.replace('"', '`"')
return f'"{str_val}"'
@@ -506,49 +485,47 @@ def translate_float(cls, val):
if math.isfinite(val):
return cls.translate_raw_str(val)
elif math.isnan(val):
- return "[double]::NaN"
+ return '[double]::NaN'
elif val < 0:
- return "[double]::NegativeInfinity"
+ return '[double]::NegativeInfinity'
else:
- return "[double]::PositiveInfinity"
+ return '[double]::PositiveInfinity'
@classmethod
def translate_none(cls, val):
- return "$Null"
+ return '$Null'
@classmethod
def translate_bool(cls, val):
- return "$True" if val else "$False"
+ return '$True' if val else '$False'
@classmethod
def translate_dict(cls, val):
- kvps = "\n ".join(
- [f"{cls.translate_str(k)} = {cls.translate(v)}" for k, v in val.items()]
- )
- return f"@{{{kvps}}}"
+ kvps = '\n '.join([f'{cls.translate_str(k)} = {cls.translate(v)}' for k, v in val.items()])
+ return f'@{{{kvps}}}'
@classmethod
def translate_list(cls, val):
- escaped = ", ".join([cls.translate(v) for v in val])
- return f"@({escaped})"
+ escaped = ', '.join([cls.translate(v) for v in val])
+ return f'@({escaped})'
@classmethod
def comment(cls, cmt_str):
- return f"# {cmt_str}".strip()
+ return f'# {cmt_str}'.strip()
@classmethod
def assign(cls, name, str_val):
- return f"${name} = {str_val}"
+ return f'${name} = {str_val}'
class BashTranslator(Translator):
@classmethod
def translate_none(cls, val):
- return ""
+ return ''
@classmethod
def translate_bool(cls, val):
- return "true" if val else "false"
+ return 'true' if val else 'false'
@classmethod
def translate_escaped_str(cls, str_val):
@@ -556,35 +533,33 @@ def translate_escaped_str(cls, str_val):
@classmethod
def translate_list(cls, val):
- escaped = " ".join([cls.translate(v) for v in val])
- return f"({escaped})"
+ escaped = ' '.join([cls.translate(v) for v in val])
+ return f'({escaped})'
@classmethod
def comment(cls, cmt_str):
- return f"# {cmt_str}".strip()
+ return f'# {cmt_str}'.strip()
@classmethod
def assign(cls, name, str_val):
- return f"{name}={str_val}"
+ return f'{name}={str_val}'
# Instantiate a PapermillIO instance and register Handlers.
papermill_translators = PapermillTranslators()
-papermill_translators.register("python", PythonTranslator)
-papermill_translators.register("R", RTranslator)
-papermill_translators.register("scala", ScalaTranslator)
-papermill_translators.register("julia", JuliaTranslator)
-papermill_translators.register("matlab", MatlabTranslator)
-papermill_translators.register(".net-csharp", CSharpTranslator)
-papermill_translators.register(".net-fsharp", FSharpTranslator)
-papermill_translators.register(".net-powershell", PowershellTranslator)
-papermill_translators.register("pysparkkernel", PythonTranslator)
-papermill_translators.register("sparkkernel", ScalaTranslator)
-papermill_translators.register("sparkrkernel", RTranslator)
-papermill_translators.register("bash", BashTranslator)
-
-
-def translate_parameters(kernel_name, language, parameters, comment="Parameters"):
- return papermill_translators.find_translator(kernel_name, language).codify(
- parameters, comment
- )
+papermill_translators.register('python', PythonTranslator)
+papermill_translators.register('R', RTranslator)
+papermill_translators.register('scala', ScalaTranslator)
+papermill_translators.register('julia', JuliaTranslator)
+papermill_translators.register('matlab', MatlabTranslator)
+papermill_translators.register('.net-csharp', CSharpTranslator)
+papermill_translators.register('.net-fsharp', FSharpTranslator)
+papermill_translators.register('.net-powershell', PowershellTranslator)
+papermill_translators.register('pysparkkernel', PythonTranslator)
+papermill_translators.register('sparkkernel', ScalaTranslator)
+papermill_translators.register('sparkrkernel', RTranslator)
+papermill_translators.register('bash', BashTranslator)
+
+
+def translate_parameters(kernel_name, language, parameters, comment='Parameters'):
+ return papermill_translators.find_translator(kernel_name, language).codify(parameters, comment)
diff --git a/papermill/utils.py b/papermill/utils.py
index 532a5a43..e69b710a 100644
--- a/papermill/utils.py
+++ b/papermill/utils.py
@@ -1,13 +1,12 @@
-import os
import logging
+import os
import warnings
-
from contextlib import contextmanager
from functools import wraps
from .exceptions import PapermillParameterOverwriteWarning
-logger = logging.getLogger("papermill.utils")
+logger = logging.getLogger('papermill.utils')
def any_tagged_cell(nb, tag):
@@ -48,9 +47,9 @@ def nb_kernel_name(nb, name=None):
ValueError
If no kernel name is found or provided
"""
- name = name or nb.metadata.get("kernelspec", {}).get("name")
+ name = name or nb.metadata.get('kernelspec', {}).get('name')
if not name:
- raise ValueError("No kernel name found in notebook and no override provided.")
+ raise ValueError('No kernel name found in notebook and no override provided.')
return name
@@ -74,12 +73,12 @@ def nb_language(nb, language=None):
ValueError
If no notebook language is found or provided
"""
- language = language or nb.metadata.get("language_info", {}).get("name")
+ language = language or nb.metadata.get('language_info', {}).get('name')
if not language:
# v3 language path for old notebooks that didn't convert cleanly
- language = language or nb.metadata.get("kernelspec", {}).get("language")
+ language = language or nb.metadata.get('kernelspec', {}).get('language')
if not language:
- raise ValueError("No language found in notebook and no override provided.")
+ raise ValueError('No language found in notebook and no override provided.')
return language
@@ -128,9 +127,7 @@ def merge_kwargs(caller_args, **callee_args):
"""
conflicts = set(caller_args) & set(callee_args)
if conflicts:
- args = format(
- "; ".join([f"{key}={value}" for key, value in callee_args.items()])
- )
+ args = format('; '.join([f'{key}={value}' for key, value in callee_args.items()]))
msg = f"Callee will overwrite caller's argument(s): {args}"
warnings.warn(msg, PapermillParameterOverwriteWarning)
return dict(caller_args, **callee_args)
@@ -167,7 +164,7 @@ def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
- logger.debug(f"Retrying after: {e}")
+ logger.debug(f'Retrying after: {e}')
exception = e
else:
raise exception
diff --git a/papermill/version.py b/papermill/version.py
index 824cbf24..3d98bc1d 100644
--- a/papermill/version.py
+++ b/papermill/version.py
@@ -1 +1 @@
-version = "2.5.0"
+version = '2.5.0'