tf.train.experimental.PythonState
A mixin for putting Python state in an object-based checkpoint.
This is an abstract class which allows extensions to TensorFlow's object-based
checkpointing (see tf.train.Checkpoint). For example, a wrapper for NumPy
arrays:
import io
import numpy
import tensorflow as tf

class NumpyWrapper(tf.train.experimental.PythonState):

  def __init__(self, array):
    self.array = array

  def serialize(self):
    # Write the array to an in-memory buffer and return its bytes.
    string_file = io.BytesIO()
    try:
      numpy.save(string_file, self.array, allow_pickle=False)
      serialized = string_file.getvalue()
    finally:
      string_file.close()
    return serialized

  def deserialize(self, string_value):
    # Rebuild the array from the bytes produced by serialize().
    string_file = io.BytesIO(string_value)
    try:
      self.array = numpy.load(string_file, allow_pickle=False)
    finally:
      string_file.close()
Instances of NumpyWrapper are checkpointable objects, and will be saved and
restored from checkpoints along with TensorFlow state like variables.
# `prefix` is the checkpoint path prefix passed to save(); it is not defined in this snippet.
root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
save_path = root.save(prefix)
root.numpy.array *= 2.
assert [2.] == root.numpy.array
root.restore(save_path)
assert [1.] == root.numpy.array
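The same pattern can be applied to Python state other than arrays. The sketch below uses a hypothetical `JsonState` wrapper (not part of TensorFlow) that stores a plain dict as JSON, assuming the dict holds only JSON-serializable values. The fallback base class is an assumption added only so that the serialize/deserialize round trip can be run where TensorFlow is not installed; saving to a checkpoint needs the real base class.

```python
import json

try:
    import tensorflow as tf
    _Base = tf.train.experimental.PythonState
except ImportError:
    # Fallback so the round trip below still runs without TensorFlow;
    # checkpointing itself requires the real base class.
    _Base = object


class JsonState(_Base):
    """Hypothetical wrapper that checkpoints a JSON-serializable dict."""

    def __init__(self, data):
        self.data = data

    def serialize(self):
        # Encode to UTF-8 so the returned value is a byte string.
        return json.dumps(self.data, sort_keys=True).encode("utf-8")

    def deserialize(self, string_value):
        # json.loads accepts both bytes and str input.
        self.data = json.loads(string_value)


state = JsonState({"epoch": 3, "best_loss": 0.25})
saved = state.serialize()
state.data["epoch"] = 4
state.deserialize(saved)  # rolls the dict back to the serialized value
```

With TensorFlow available, such an object would be attached to a checkpoint in the same way as `NumpyWrapper` above, for example `tf.train.Checkpoint(state=JsonState({...}))`.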
Methods
deserialize
@abc.abstractmethod
deserialize(
    string_value
)
Callback to deserialize the object. string_value is the serialized value previously returned by serialize.
serialize
@abc.abstractmethod
serialize()
Callback to serialize the object. Returns a string.
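Because the two callbacks must agree on a format, it can help to check an implementation outside of any checkpoint. The following is a minimal sketch of such a check. `StepCounter` and `round_trips` are hypothetical names introduced for illustration, and the class deliberately omits the `tf.train.experimental.PythonState` base so that the sketch runs without TensorFlow; a real implementation would subclass it.

```python
class StepCounter:
    """Hypothetical stand-in with the same two-method interface.

    In real use this would subclass tf.train.experimental.PythonState.
    """

    def __init__(self, step=0):
        self.step = step

    def serialize(self):
        # Return the state as a byte string.
        return str(self.step).encode("ascii")

    def deserialize(self, string_value):
        # Accept bytes or str and rebuild the Python state from it.
        if isinstance(string_value, bytes):
            string_value = string_value.decode("ascii")
        self.step = int(string_value)


def round_trips(source, target):
    """Return True if `target` reproduces `source`'s serialized state."""
    payload = source.serialize()
    target.deserialize(payload)
    return target.serialize() == payload


original = StepCounter(step=42)
fresh = StepCounter()
ok = round_trips(original, fresh)
```

A check like this only exercises the callbacks themselves; it does not cover how TensorFlow stores or restores the value in a checkpoint.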
Last updated 2024-04-26 UTC.