class TensorFlowModelDataset(AbstractVersionedDataSet[
TensorFlowModelDataset loads and saves TensorFlow models. The underlying functionality is supported by, and passes input arguments through to, the TensorFlow 2.x tf.keras.models.load_model and tf.keras.models.save_model functions.
Example usage for the YAML API:
```yaml
tensorflow_model:
  type: tensorflow.TensorFlowModelDataset
  filepath: data/06_models/tensorflow_model.h5
  load_args:
    compile: False
  save_args:
    overwrite: True
    include_optimizer: False
  credentials: tf_creds
```
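In a Kedro catalog, the entry's `type` selects the dataset class and the remaining keys become constructor arguments, while `credentials: tf_creds` names an entry in a separate credentials configuration. The following is a simplified, pure-Python sketch of that mapping, not Kedro's actual catalog code; `split_catalog_entry` and the contents of the `credentials` dict are hypothetical.

```python
from typing import Any, Dict, Tuple

# The YAML entry above, written as the dict a YAML parser would produce.
catalog_entry: Dict[str, Any] = {
    "type": "tensorflow.TensorFlowModelDataset",
    "filepath": "data/06_models/tensorflow_model.h5",
    "load_args": {"compile": False},
    "save_args": {"overwrite": True, "include_optimizer": False},
    "credentials": "tf_creds",
}

# Hypothetical credentials store; in a project these live in separate config.
credentials: Dict[str, Dict[str, Any]] = {"tf_creds": {"token": None}}


def split_catalog_entry(
    entry: Dict[str, Any], creds: Dict[str, Dict[str, Any]]
) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: return the dataset type and its constructor kwargs."""
    kwargs = dict(entry)  # shallow copy so the original entry is not mutated
    dataset_type = kwargs.pop("type")
    # A string value for "credentials" is treated as a key into the credentials store.
    if isinstance(kwargs.get("credentials"), str):
        kwargs["credentials"] = creds[kwargs["credentials"]]
    return dataset_type, kwargs


dataset_type, kwargs = split_catalog_entry(catalog_entry, credentials)
print(dataset_type)           # tensorflow.TensorFlowModelDataset
print(kwargs["credentials"])  # {'token': None}
```

With Kedro installed, the resulting keyword arguments correspond to the parameters of TensorFlowModelDataset.__init__ (filepath, load_args, save_args, credentials).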
Example usage for the Python API:
```python
>>> from kedro.extras.datasets.tensorflow import TensorFlowModelDataset
>>> import tensorflow as tf
>>> import numpy as np
>>>
>>> data_set = TensorFlowModelDataset("data/06_models/tensorflow_model.h5")
>>> # model should be a built Keras model; [...] stands for its input data
>>> model = tf.keras.Model()
>>> predictions = model.predict([...])
>>>
>>> data_set.save(model)
>>> loaded_model = data_set.load()
>>> new_predictions = loaded_model.predict([...])
>>> np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
```
Method | __init__ | Creates a new instance of TensorFlowModelDataset.
Constant | DEFAULT | Undocumented
Constant | DEFAULT | Undocumented
Method | _describe | Undocumented
Method | _exists | Undocumented
Method | _invalidate | Invalidate underlying filesystem caches.
Method | _load | Undocumented
Method | _release | Undocumented
Method | _save | Undocumented
Instance Variable | _fs | Undocumented
Instance Variable | _is | Undocumented
Instance Variable | _load | Undocumented
Instance Variable | _protocol | Undocumented
Instance Variable | _save | Undocumented
Instance Variable | _tmp | Undocumented
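The underscore-prefixed methods above (_load, _save, _exists, _describe) are the hooks a concrete dataset implements; the inherited public load, save and exists delegate to them. The sketch below illustrates that delegation with a hypothetical in-memory dataset. `MiniDataSet` and `InMemoryModelDataSet` are illustrative names only, and Kedro's real base classes add versioning, logging and error wrapping on top of this pattern.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class MiniDataSet(ABC):
    """Hypothetical, simplified base class showing hook-method delegation."""

    def load(self) -> Any:
        # Public API: delegates to the subclass's _load hook.
        return self._load()

    def save(self, data: Any) -> None:
        # This sketch rejects None before delegating to _save.
        if data is None:
            raise ValueError("Saving 'None' is not allowed")
        self._save(data)

    def exists(self) -> bool:
        return self._exists()

    def __str__(self) -> str:
        # Human-readable summary built from the subclass's _describe hook.
        details = ", ".join(f"{k}={v}" for k, v in self._describe().items())
        return f"{type(self).__name__}({details})"

    @abstractmethod
    def _load(self) -> Any: ...

    @abstractmethod
    def _save(self, data: Any) -> None: ...

    @abstractmethod
    def _exists(self) -> bool: ...

    @abstractmethod
    def _describe(self) -> Dict[str, Any]: ...


class InMemoryModelDataSet(MiniDataSet):
    """Keeps an object in memory instead of writing a TensorFlow model to disk."""

    def __init__(self, filepath: str) -> None:
        self._filepath = filepath
        self._data: Optional[Any] = None

    def _load(self) -> Any:
        if self._data is None:
            raise FileNotFoundError(self._filepath)
        return self._data

    def _save(self, data: Any) -> None:
        self._data = data

    def _exists(self) -> bool:
        return self._data is not None

    def _describe(self) -> Dict[str, Any]:
        return {"filepath": self._filepath}


dataset = InMemoryModelDataSet("data/06_models/tensorflow_model.h5")
print(dataset.exists())  # False
dataset.save({"weights": [0.1, 0.2]})
print(dataset.load())  # {'weights': [0.1, 0.2]}
print(dataset)  # InMemoryModelDataSet(filepath=data/06_models/tensorflow_model.h5)
```

Keeping the public methods in the base class means every dataset gets the same checks and behaviour, while subclasses only supply the storage-specific hooks.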
Inherited from AbstractVersionedDataSet:
Method | exists | Checks whether a data set's output already exists by calling the provided _exists() method.
Method | load | Loads data by delegation to the provided load method.
Method | resolve | Compute the version the dataset should be loaded with.
Method | resolve | Compute the version the dataset should be saved with.
Method | save | Saves data by delegation to the provided save method.
Method | _fetch | Undocumented
Method | _fetch | Generate and cache the current save version.
Method | _get | Undocumented
Method | _get | Undocumented
Method | _get | Undocumented
Instance Variable | _exists | Undocumented
Instance Variable | _filepath | Undocumented
Instance Variable | _glob | Undocumented
Instance Variable | _version | Undocumented
Instance Variable | _version | Undocumented
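The two version-resolution methods above follow the rules given for the version parameter: a None load version means the latest saved version is loaded, and a None save version is autogenerated. Below is a pure-Python sketch of those rules, assuming timestamp-style version strings that sort chronologically as plain text. `VersionSketch`, `generate_version` and `resolve_versions` are hypothetical names (the real pair type is kedro.io.core.Version), and unlike this sketch, a real dataset with no saved versions would report an error rather than return None.

```python
from datetime import datetime, timezone
from typing import Iterable, NamedTuple, Optional, Tuple


class VersionSketch(NamedTuple):
    """Hypothetical stand-in for a (load, save) version pair."""

    load: Optional[str]
    save: Optional[str]


def generate_version(now: Optional[datetime] = None) -> str:
    # Assumed format: UTC timestamp truncated to milliseconds, e.g. 2024-05-01T12.30.45.123Z.
    now = now or datetime.now(tz=timezone.utc)
    stamp = now.strftime("%Y-%m-%dT%H.%M.%S.%f")  # %f yields 6 digits of microseconds
    return stamp[:-3] + "Z"  # keep milliseconds only


def resolve_versions(
    version: Optional[VersionSketch], existing: Iterable[str]
) -> Tuple[Optional[str], Optional[str]]:
    """Return (load_version, save_version) following the rules described above."""
    if version is None:
        return None, None  # no Version given: the dataset is not versioned
    ordered = sorted(existing)  # timestamps of this format sort chronologically
    load = version.load
    if load is None:
        load = ordered[-1] if ordered else None  # latest saved version, if any
    save = version.save or generate_version()  # autogenerate when not pinned
    return load, save


load_v, save_v = resolve_versions(
    VersionSketch(load=None, save=None),
    ["2023-12-30T09.00.00.000Z", "2023-12-31T10.00.00.000Z"],
)
print(load_v)  # 2023-12-31T10.00.00.000Z
```

Using sortable timestamps as version names is what makes "latest" a simple maximum over the existing versions.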
Inherited from AbstractDataSet (via AbstractVersionedDataSet):
Class Method | from | Create a data set instance using the configuration provided.
Method | __str__ | Undocumented
Method | release | Release any cached data.
Method | _copy | Undocumented
Property | _logger | Undocumented
```python
def __init__(
    self,
    filepath: str,
    load_args: Dict[str, Any] = None,
    save_args: Dict[str, Any] = None,
    version: Version = None,
    credentials: Dict[str, Any] = None,
    fs_args: Dict[str, Any] = None,
):
```
Creates a new instance of TensorFlowModelDataset.
Parameters:
filepath: str | Filepath in POSIX format to a TensorFlow model directory, prefixed with a protocol such as s3://. If no prefix is provided, the file protocol (local filesystem) is used. The prefix can be any protocol supported by fsspec. Note: http(s) doesn't support versioning.
load_args: Dict[str, Any] | TensorFlow options for loading models. All available arguments are listed at https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model. All defaults are preserved.
save_args: Dict[str, Any] | TensorFlow options for saving models. All available arguments are listed at https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model. All defaults are preserved, except for "save_format", which is set to "tf".
version: Version | If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, the save version will be autogenerated.
credentials: Dict[str, Any] | Credentials required to access the underlying filesystem. For example, for GCSFileSystem it should look like {'token': None}.
fs_args: Dict[str, Any] | Extra arguments to pass into the underlying filesystem class constructor (e.g. {"project": "my-project"} for GCSFileSystem).
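The save_args description says every TensorFlow default is preserved except "save_format", which defaults to "tf", and filepath falls back to the local file protocol when no prefix is given. The sketch below shows both rules in plain Python. `DEFAULT_LOAD_ARGS`, `DEFAULT_SAVE_ARGS`, `merge_args` and `split_protocol` are assumed names for illustration (the member list only shows two truncated DEFAULT constants), and Kedro uses its own fsspec-oriented helpers for protocol parsing rather than this simplified split.

```python
from typing import Any, Dict, Optional, Tuple

# Assumed defaults, matching the parameter descriptions above.
DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
DEFAULT_SAVE_ARGS: Dict[str, Any] = {"save_format": "tf"}


def merge_args(
    defaults: Dict[str, Any], user_args: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Start from the defaults; user-supplied keys take precedence."""
    merged = dict(defaults)  # copy so the module-level defaults are never mutated
    if user_args:
        merged.update(user_args)
    return merged


def split_protocol(filepath: str) -> Tuple[str, str]:
    """Split 'proto://path' into (proto, path); no prefix means the local 'file' protocol."""
    protocol, sep, path = filepath.partition("://")
    if not sep:
        return "file", filepath
    return protocol, path


save_args = merge_args(DEFAULT_SAVE_ARGS, {"overwrite": True, "include_optimizer": False})
print(save_args)
# {'save_format': 'tf', 'overwrite': True, 'include_optimizer': False}
print(split_protocol("s3://bucket/models/tensorflow_model"))
# ('s3', 'bucket/models/tensorflow_model')
print(split_protocol("data/06_models/tensorflow_model.h5"))
# ('file', 'data/06_models/tensorflow_model.h5')
```

Under this sketch, passing save_args={"save_format": "h5"} replaces the default in the same way, since user keys are applied after the defaults.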