tff.learning.models.FunctionalModel
A model that parameterizes its forward pass by model weights.
tff.learning.models.FunctionalModel(
*,
initial_weights: ModelWeights,
predict_on_batch_fn: Callable[[ModelWeights, Any, bool], Any],
loss_fn: Callable[[Any, Any, Any], Any],
metrics_fns: tuple[InitializeMetricsStateFn, UpdateMetricsStateFn, FinalizeMetricsFn] = (empty_metrics_state, noop_update_metrics, noop_finalize_metrics),
input_spec: Any
)
Args
initial_weights | A 2-tuple (trainable, non_trainable) where the two elements are sequences of weights. Weights must be values convertible to tf.Tensor (e.g. numpy.ndarray, Python sequences, etc.), but not tf.Tensor values.
predict_on_batch_fn | A tf.function-decorated callable that takes three arguments: model_weights, with the same structure as initial_weights; x, the first element of batch_input (or input_spec); and training, a boolean determining whether the call is during a training pass (e.g. for Dropout, BatchNormalization, etc.). It must return either a tensor of predictions or a structure whose first element (as determined by tf.nest.flatten()) is a tensor of predictions.
loss_fn | A callable that takes three arguments: output, the tensor(s) returned by predict_on_batch in a form interpretable by the loss function; label, the second element of batch_input; and an optional sample_weight that weights the output.
metrics_fns | A 3-tuple of callables that initialize the metrics state, update the metrics state, and finalize the metrics values, respectively. This can be the result of tff.learning.metrics.create_functional_metric_fns or custom user-written callables.
input_spec | A 2-tuple of (x, y) where each element is a nested structure of tf.TensorSpec. x corresponds to batched model inputs and defines the shape and dtype of x passed to predict_on_batch_fn, while y corresponds to batched labels for those inputs and defines the shape and dtype of label passed to loss_fn.
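As a rough illustration of the calling conventions described in the Args table, the sketch below defines stand-ins for initial_weights, predict_on_batch_fn, and loss_fn for a hypothetical linear regression model. It uses NumPy only so that it runs without TensorFlow; that is an assumption made for illustration, since the documented predict_on_batch_fn is a tf.function-decorated callable. The model, shapes, and mean-squared-error loss are all hypothetical.

```python
import numpy as np

# Hypothetical linear model. Weights are plain NumPy arrays (values convertible
# to tf.Tensor), arranged as the 2-tuple (trainable, non_trainable).
trainable = (np.zeros((2, 1), dtype=np.float32), np.zeros((1,), dtype=np.float32))
non_trainable = ()
initial_weights = (trainable, non_trainable)


def predict_on_batch_fn(model_weights, x, training=True):
    """Forward pass computed from the weights passed in, not from stored state."""
    del training  # This linear model behaves the same in training and inference.
    (kernel, bias), _ = model_weights
    return x @ kernel + bias


def loss_fn(output, label, sample_weight=None):
    """Mean squared error, optionally weighted per example."""
    per_example = np.mean(np.square(output - label), axis=-1)
    if sample_weight is None:
        return float(np.mean(per_example))
    w = np.asarray(sample_weight, dtype=np.float32)
    return float(np.sum(per_example * w) / np.sum(w))


# One batch matching a hypothetical input_spec of (x: [None, 2], y: [None, 1]).
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
y = np.array([[1.0], [3.0]], dtype=np.float32)

output = predict_on_batch_fn(initial_weights, x, training=False)
batch_loss = loss_fn(output, y)
```

In a real setup these callables would presumably be written with TensorFlow ops and passed to the constructor by keyword, together with an input_spec of tf.TensorSpec structures matching x and y.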
Attributes
initial_weights
input_spec
Methods
finalize_metrics
@tf.function
finalize_metrics(
state: types.MetricsState
) -> collections.OrderedDict[str, Any]
initialize_metrics_state
@tf.function
initialize_metrics_state() -> types.MetricsState
loss
loss(
output: Any, label: Any, sample_weight: Optional[Any] = None
) -> float
Returns the loss value based on the model output and the label.
predict_on_batch
@tf.function
predict_on_batch(
model_weights: ModelWeights, x: Any, training: bool = True
)
Returns tensor(s) interpretable by the loss function.
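To make the roles of model_weights and training more concrete, here is a minimal sketch of a forward pass that takes both as arguments. It is a NumPy stand-in rather than a tf.function, and the single-kernel model, the 0.5 dropout rate, and the fixed random seed are assumptions chosen only for illustration.

```python
import numpy as np


def predict_on_batch_fn(model_weights, x, training=True):
    """Hypothetical forward pass with dropout applied only during training."""
    (kernel,), _ = model_weights
    if training:
        # Inverted dropout with a fixed seed so the sketch is reproducible.
        rng = np.random.default_rng(0)
        keep = rng.random(x.shape) >= 0.5
        x = np.where(keep, x / 0.5, 0.0)
    return x @ kernel


kernel = np.ones((3, 1), dtype=np.float32)
weights = ((kernel,), ())
x = np.ones((2, 3), dtype=np.float32)

# Inference: no dropout, so the output depends only on the weights and x.
eval_out = predict_on_batch_fn(weights, x, training=False)

# Training: some inputs are zeroed and the rest are rescaled.
train_out = predict_on_batch_fn(weights, x, training=True)

# Because weights are an argument, the same function can be evaluated
# under a different set of weights without mutating any model object.
scaled_out = predict_on_batch_fn(((2.0 * kernel,), ()), x, training=False)
```

Passing the weights explicitly is what lets the same callable be reused across different weight values, which matches the description of this class as a model whose forward pass is parameterized by model weights.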
update_metrics_state
@tf.function
update_metrics_state(
state: GenericMetricsState,
labels: Any,
batch_output: tff.learning.models.BatchOutput,
sample_weight: Optional[Any] = None
) -> GenericMetricsState
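The three metrics methods follow an initialize, update, finalize lifecycle. The sketch below mimics that lifecycle in plain Python with a hypothetical running mean-loss metric. The BatchOutput namedtuple is a stand-in defined here for illustration, with assumed fields loss, predictions, and num_examples; it is not the TFF class, and the state layout is likewise an assumption.

```python
import collections

# Hypothetical stand-in for tff.learning.models.BatchOutput (illustration only).
BatchOutput = collections.namedtuple(
    "BatchOutput", ["loss", "predictions", "num_examples"]
)


def initialize_metrics_state():
    """Returns an empty accumulator."""
    return collections.OrderedDict(loss_sum=0.0, num_examples=0)


def update_metrics_state(state, labels, batch_output, sample_weight=None):
    """Returns a new state with one batch folded in; the input state is not mutated."""
    del labels, sample_weight  # Unused by this simple metric.
    return collections.OrderedDict(
        loss_sum=state["loss_sum"] + batch_output.loss * batch_output.num_examples,
        num_examples=state["num_examples"] + batch_output.num_examples,
    )


def finalize_metrics(state):
    """Converts the accumulated state into named metric values."""
    n = state["num_examples"]
    return collections.OrderedDict(
        mean_loss=state["loss_sum"] / n if n else 0.0,
        num_examples=n,
    )


state = initialize_metrics_state()
for batch in [BatchOutput(0.5, None, 4), BatchOutput(1.0, None, 2)]:
    state = update_metrics_state(state, labels=None, batch_output=batch)
metrics = finalize_metrics(state)
```

Keeping the state as an explicit value that each call returns, rather than as mutable object attributes, mirrors the functional style of the method signatures above.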
Last updated 2024-09-20 UTC.