torch.export Tutorial
Created On: Oct 02, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024
Author: William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan
Warning
torch.export and its related features are in prototype status and are subject to backwards compatibility breaking changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.5.
torch.export() is the PyTorch 2.X way to export PyTorch models into standardized model representations, intended to be run on different (i.e. Python-less) environments. The official documentation can be found here.
In this tutorial, you will learn how to use torch.export() to extract ExportedProgram objects (i.e. single-graph representations) from PyTorch programs. We also detail some considerations and modifications that you may need to make your model compatible with torch.export.
Basic Usage
torch.export extracts single-graph representations from PyTorch programs by tracing the target function, given example inputs. torch.export.export() is the main entry point for torch.export.
In this tutorial, torch.export and torch.export.export() are practically synonymous, though torch.export generally refers to the PyTorch 2.X export process, and torch.export.export() generally refers to the actual function call.
The signature of torch.export.export() is:
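```python
export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    # ... further keyword-only options (for example, strict, used later in this tutorial)
) -> ExportedProgram
```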
torch.export.export() traces the tensor computation graph from calling mod(*args, **kwargs) and wraps it in an ExportedProgram, which can be serialized or executed later with different inputs. To execute the ExportedProgram, we can call .module() on it to return a torch.nn.Module which is callable, just like the original program.
We will detail the dynamic_shapes argument later in the tutorial.
```python
import torch
from torch.export import export


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)


mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
```
<class 'torch.export.exported_program.ExportedProgram'>
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.1344, 0.0000, 0.5624, 0.3435, 0.6017,
2.7909],
[0.0000, 0.0000, 0.0000, 1.0477, 0.0000, 0.0000, 2.0461, 0.0000, 0.7889,
0.0000],
[0.0000, 0.0000, 0.9877, 0.0000, 0.5176, 0.5654, 0.0000, 1.5727, 0.6437,
0.0000],
[0.5942, 0.7034, 0.0564, 0.0000, 0.4003, 0.0000, 0.0000, 0.2568, 0.0000,
0.0000],
[0.0000, 0.2937, 0.3767, 0.5728, 0.0000, 1.2748, 1.0475, 1.4762, 1.1964,
0.0000],
[0.0032, 0.0920, 0.3563, 0.0403, 0.2645, 0.0000, 0.7189, 0.7865, 0.0000,
0.0000],
[0.9222, 0.0000, 0.0495, 0.7542, 1.1252, 0.8088, 0.0000, 0.0000, 0.7890,
0.0707],
[0.0000, 0.5030, 0.0000, 0.2587, 0.8563, 0.6805, 1.0002, 0.9677, 0.0000,
1.3133]], grad_fn=<ReluBackward0>)
Let’s review some attributes of ExportedProgram that are of interest.
The graph attribute is an FX graph traced from the function we exported, that is, the computation graph of all PyTorch operations. The FX graph is in “ATen IR”, meaning that it contains only “ATen-level” operations.
The graph_signature attribute gives a more detailed description of the input and output nodes in the exported graph, describing which ones are parameters, buffers, user inputs, or user outputs.
The range_constraints attribute will be covered later.
```python
print(exported_mod)
```
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_lin_weight: "f32[10, 100]", p_lin_bias: "f32[10]", x: "f32[8, 100]", y: "f32[8, 100]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
add: "f32[8, 100]" = torch.ops.aten.add.Tensor(x, y); x = y = None
# File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[8, 10]" = torch.ops.aten.linear.default(add, p_lin_weight, p_lin_bias); add = p_lin_weight = p_lin_bias = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
relu_: "f32[8, 10]" = torch.ops.aten.relu_.default(linear); linear = None
return (relu_,)
Graph signature:
# inputs
p_lin_weight: PARAMETER target='lin.weight'
p_lin_bias: PARAMETER target='lin.bias'
x: USER_INPUT
y: USER_INPUT
# outputs
relu_: USER_OUTPUT
Range constraints: {}
See the torch.export documentation for more details.
Graph Breaks
Although torch.export shares components with torch.compile, the key limitation of torch.export, especially when compared to torch.compile, is that it does not support graph breaks. This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with the export use case. Therefore, in order to make your model code compatible with torch.export, you will need to modify your code to remove graph breaks.
A graph break is necessary in cases such as:
data-dependent control flow
```python
class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)


import traceback as tb

try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
```
def forward(self, arg0_1: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:116 in forward, code: if x.sum() > 0:
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
ne: "b8[]" = torch.ops.aten.ne.Scalar(gt, 0); gt = None
item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne); ne = item = None
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 122, in <module>
export(Bad1(), (torch.randn(3, 3),))
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
ep = _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
export_artifact = export_func(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1793, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1922, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1706, in _make_fx_helper
gm = make_fx(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2429, in wrapped
return make_fx_tracer.trace(f, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2356, in trace
return self._trace_inner(f, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2318, in _trace_inner
t = dispatch_trace(
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1303, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1908, in trace
res = super().trace(root, concrete_args)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 868, in trace
(self.create_arg(fn(*args)),),
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1361, in wrapped
out = f(*tensors) # type:ignore[call-arg]
File "<string>", line 1, in <lambda>
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1593, in wrapped_fn
return tuple(flat_fn(*args))
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
tree_out = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
out = mod(*args[params_len:], **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
ret_val = forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1906, in forward
tree_out = mod(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
ret_val = forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
if x.sum() > 0:
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1409, in __torch_function__
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1479, in __torch_function__
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 1066, in __torch_function__
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 538, in guard_bool
r = self.evaluate()
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7233, in evaluate_sym_node
return self.evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7333, in evaluate_expr
return self._inner_evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
return retlog(fn(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7356, in _inner_evaluate_expr
return self._evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7574, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)
consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_trueCaused by: (_export/non_strict_utils.py:1066 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
if x.sum() > 0:
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
accessing tensor data with .data
```python
class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x


try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
```
calling unsupported functions (such as many built-in functions)
```python
class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)


try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
```
Non-Strict Export
To trace the program, torch.export by default uses TorchDynamo, a bytecode analysis engine, to symbolically analyze the Python code and build a graph based on the results. This analysis allows torch.export to provide stronger guarantees about safety, but not all Python code is supported, and unsupported code causes these graph breaks.
To address this issue, in PyTorch 2.3, we introduced a new mode of exporting called non-strict mode, where we trace through the program using the Python interpreter, executing it exactly as it would run in eager mode and allowing us to skip over unsupported Python features. Non-strict mode is enabled by passing the strict=False flag.
Looking at some of the previous examples which resulted in graph breaks:
Calling unsupported functions (such as many built-in functions) traces through, but in this case, id(x) gets specialized as a constant integer in the graph. This is because id(x) is not a tensor operation, so the operation is not recorded in the graph.
```python
class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)


bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
```
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:179 in forward, code: x = x + 1
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, 1); x = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:180 in forward, code: return x + id(x)
add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 140226861055920); add = None
return (add_1,)
Graph signature:
# inputs
x: USER_INPUT
# outputs
add_1: USER_OUTPUT
Range constraints: {}
tensor([[1.4023e+14, 1.4023e+14, 1.4023e+14],
[1.4023e+14, 1.4023e+14, 1.4023e+14],
[1.4023e+14, 1.4023e+14, 1.4023e+14]])
However, there are still some features that require rewrites to the original module:
Control Flow Ops
torch.export does in fact support data-dependent control flow, but it needs to be expressed using control flow ops. For example, we can fix the control flow example above using the cond op, like so:
```python
class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)

        def false_fn(x):
            return torch.cos(x)

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])


exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed)
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
```
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:205 in forward, code: return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
# File: <eval_with_key>.35:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_,)); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x,)); gt = true_graph_0 = false_graph_0 = x = None
getitem: "f32[3, 3]" = cond[0]; cond = None
return (getitem,)
class true_graph_0(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: <eval_with_key>.32:6 in forward, code: sin = torch.sin(l_args_3_0__1); l_args_3_0__1 = None
sin: "f32[3, 3]" = torch.ops.aten.sin.default(x); x = None
return (sin,)
class false_graph_0(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: <eval_with_key>.33:6 in forward, code: cos = torch.cos(l_args_3_0__1); l_args_3_0__1 = None
cos: "f32[3, 3]" = torch.ops.aten.cos.default(x); x = None
return (cos,)
Graph signature:
# inputs
x: USER_INPUT
# outputs
getitem: USER_OUTPUT
Range constraints: {}
tensor([[0.8415, 0.8415, 0.8415],
[0.8415, 0.8415, 0.8415],
[0.8415, 0.8415, 0.8415]])
tensor([[0.5403, 0.5403, 0.5403],
[0.5403, 0.5403, 0.5403],
[0.5403, 0.5403, 0.5403]])
There are limitations to cond that one should be aware of:
- The predicate (i.e. x.sum() > 0) must result in a boolean or a single-element tensor.
- The operands (i.e. [x]) must be tensors.
- The branch function (i.e. true_fn and false_fn) signature must match the operands, and they must both return a single tensor with the same metadata (for example, dtype, shape, etc.).
- Branch functions cannot mutate input or global variables.
- Branch functions cannot access closure variables, except for self if the function is defined in the scope of a method.
For more details about cond, check out the cond documentation.
We can also use map, which applies a function across the first dimension of the first tensor argument.
```python
from torch._higher_order_ops.map import map as torch_map


class MapModule(torch.nn.Module):
    def forward(self, xs, y, z):
        def body(x, y, z):
            return x + y + z

        return torch_map(body, xs, y, z)


inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
exported_map_example = export(MapModule(), inps)
print(exported_map_example)
print(exported_map_example.module()(*inps))
```
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, xs: "f32[6, 4]", y: "i64[]", z: "i64[]"):
# File: <eval_with_key>.64:9 in forward, code: map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_flat_xs_0_], [l_flat_args_0_, l_flat_args_1_]); map_body_0 = l_flat_xs_0_ = l_flat_args_0_ = l_flat_args_1_ = None
body_graph_0 = self.body_graph_0
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y, z]); body_graph_0 = xs = y = z = None
getitem: "f32[6, 4]" = map_impl[0]; map_impl = None
return (getitem,)
class body_graph_0(torch.nn.Module):
def forward(self, xs: "f32[4]", y: "i64[]", z: "i64[]"):
# File: <eval_with_key>.62:5 in forward, code: add = child + l_flat_args_0_; child = l_flat_args_0_ = None
add: "f32[4]" = torch.ops.aten.add.Tensor(xs, y); xs = y = None
# File: <eval_with_key>.62:6 in forward, code: add_1 = add + l_flat_args_1_; add = l_flat_args_1_ = None
add_1: "f32[4]" = torch.ops.aten.add.Tensor(add, z); add = z = None
return (add_1,)
Graph signature:
# inputs
xs: USER_INPUT
y: USER_INPUT
z: USER_INPUT
# outputs
getitem: USER_OUTPUT
Range constraints: {}
tensor([[10., 10., 10., 10.],
[10., 10., 10., 10.],
[10., 10., 10., 10.],
[10., 10., 10., 10.],
[10., 10., 10., 10.],
[10., 10., 10., 10.]])
Other control flow ops include while_loop, associative_scan, and scan. For more documentation on each operator, please refer to this page.
Constraints/Dynamic Shapes
This section covers dynamic behavior and representation of exported programs. Dynamic behavior is specific to the particular model being exported, so for most of this tutorial, we’ll focus on this particular toy model (with the resulting tensor shapes annotated):
```python
class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3
```
By default, torch.export produces a static program. One consequence of this is that at runtime, the program won’t work on inputs with different shapes, even if they’re valid in eager mode.
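```python
w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))

# Valid in eager mode: x broadcasts against a [3, 4] y, and z has 3 * 4 = 12 elements.
model(w, x, torch.randn(3, 4), torch.randn(12))

# The static exported program rejects the new shapes.
try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception:
    tb.print_exc()
```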
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 286, in <module>
ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 837, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 413, in __call__
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
return inner()
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1829, in inner
result = forward_call(*args, **kwargs)
File "<eval_with_key>.95", line 8, in forward
_guards_fn = self._guards_fn(w, x, y, z); _guards_fn = None
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "<string>", line 6, in _
File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2185, in _assert
assert condition, message
AssertionError: Guard failed: y.size()[0] == 8
Basic concepts: symbols and guards
To enable dynamism, export() provides a dynamic_shapes argument. The easiest way to work with dynamic shapes is using Dim.AUTO and looking at the program that’s returned. Dynamic behavior is specified at the level of individual input dimensions; for each input we can specify a tuple of values:
Before we look at the program that’s produced, let’s understand what specifying dynamic_shapes entails, and how that interacts with export. For every input dimension where a Dim object is specified, a symbol is allocated, taking on a range of [2, inf] (why not [0, inf] or [1, inf]? We’ll explain later in the 0/1 specialization section).
Export then runs model tracing, looking at each operation that’s performed by the model. Each individual operation can emit what are called “guards”: boolean conditions that are required to be true for the program to be valid. When guards involve symbols allocated for input dimensions, the program contains restrictions on what input shapes are valid, i.e. the program’s dynamic behavior. The symbolic shapes subsystem is the part responsible for taking in all the emitted guards and producing a final program representation that adheres to all of these guards. Before we see this “final representation” in an ExportedProgram, let’s look at the guards emitted by the toy model we’re tracing.
Here, each forward input tensor is annotated with the symbol allocated at the start of tracing:
```python
class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()  # no guard added here
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3
```
Let’s understand each of the operations and the emitted guards:
- x0 = x + y: This is an element-wise add with broadcasting, since x is a 1-d tensor and y a 2-d tensor. x is broadcast along the first dimension of y, so its only dimension must match the last dimension of y, emitting the guard s2 == s4.
- x1 = self.l(w): Calling nn.Linear() performs a matrix multiplication with model parameters. In export, parameters, buffers, and constants are considered program state, which is treated as static, so this is a matmul between a dynamic input (w: [s0, s1]) and a statically-shaped tensor. This emits the guard s1 == 5.
- x2 = x0.flatten(): This call actually doesn’t emit any guards (at least none relevant to input shapes).
- x3 = x2 + z: x2 has shape [s3*s4] after flattening, and this element-wise add emits the guard s3 * s4 == s5.
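To make the guards concrete, a small sketch with a hypothetical helper satisfies_guards that evaluates the three guards for given input shapes, cross-checked against running the eager model. All sizes are at least 2, since broadcasting treats size-1 dimensions specially:

```python
import torch


class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        return self.l(w), (x + y).flatten() + z


def satisfies_guards(w_shape, x_shape, y_shape, z_shape):
    """Hypothetical helper: evaluates the three guards listed above."""
    s0, s1 = w_shape
    (s2,) = x_shape
    s3, s4 = y_shape
    (s5,) = z_shape
    return s2 == s4 and s1 == 5 and s3 * s4 == s5


model = DynamicModel()


def runs_in_eager(shapes):
    """Returns True if the eager model accepts random inputs of these shapes."""
    try:
        model(*(torch.randn(*shape) for shape in shapes))
        return True
    except RuntimeError:
        return False


cases = [
    ((6, 5), (4,), (8, 4), (32,)),  # original example shapes
    ((2, 5), (3,), (7, 3), (21,)),  # new sizes, all guards hold
    ((6, 4), (4,), (8, 4), (32,)),  # violates s1 == 5
    ((6, 5), (4,), (8, 3), (24,)),  # violates s2 == s4
    ((6, 5), (4,), (8, 4), (30,)),  # violates s3 * s4 == s5
]
results = [(satisfies_guards(*c), runs_in_eager(c)) for c in cases]
for shapes, (guarded, eager_ok) in zip(cases, results):
    print(shapes, guarded, eager_ok)
```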
Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:
w: [s0, 5]
x: [s2]
y: [s3, s2]
z: [s2*s3]
And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the corresponding inputs:
```python
print(ep)
```
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_l_weight: "f32[3, 5]", p_l_bias: "f32[3]", w: "f32[s15, 5]", x: "f32[s77]", y: "f32[s17, s77]", z: "f32[s17*s77]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward, code: x0 = x + y # [8, 4]
add: "f32[s17, s77]" = torch.ops.aten.add.Tensor(x, y); x = y = None
# File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s15, 3]" = torch.ops.aten.linear.default(w, p_l_weight, p_l_bias); w = p_l_weight = p_l_bias = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:270 in forward, code: x2 = x0.flatten() # [32]
flatten: "f32[s17*s77]" = torch.ops.aten.flatten.using_ints(add); add = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward, code: x3 = x2 + z # [32]
add_1: "f32[s17*s77]" = torch.ops.aten.add.Tensor(flatten, z); flatten = z = None
return (linear, add_1)
Graph signature:
# inputs
p_l_weight: PARAMETER target='l.weight'
p_l_bias: PARAMETER target='l.bias'
w: USER_INPUT
x: USER_INPUT
y: USER_INPUT
z: USER_INPUT
# outputs
linear: USER_OUTPUT
add_1: USER_OUTPUT
Range constraints: {s15: VR[2, int_oo], s77: VR[2, int_oo], s17: VR[2, int_oo], s17*s77: VR[4, int_oo]}
Another feature to notice is the range_constraints field above, which contains a valid range for each symbol. This isn’t so interesting currently, since this export call doesn’t emit any guards related to symbol bounds and each base symbol has a generic bound, but this will come up later.
So far, because we’ve been exporting this toy model, this experience has not been representative of how hard it typically is to debug dynamic shapes guards and issues. In most cases it isn’t obvious what guards are being emitted, and which operations and parts of user code are responsible. For this toy model we can pinpoint the exact lines, and the guards are rather intuitive.
In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment variable TORCH_LOGS="+dynamic", or interactively with torch._logging.set_logs(dynamic=10):
I1015 19:14:15.352000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.354000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.354000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.356000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.358000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.358000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.360000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.366000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.367000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.367000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.369000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.369000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.371000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.371000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.372000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.373000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.374000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.376000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
I1015 19:14:15.377000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = s77 (solve) VR[2, int_oo]
V1015 19:14:15.379000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.385000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2248 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V1015 19:14:15.386000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s21 = VR[5, 5] (update)
I1015 19:14:15.386000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V1015 19:14:15.398000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.401000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.403000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V1015 19:14:15.405000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s68 = VR[4, int_oo] (update)
I1015 19:14:15.405000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s68 = s17*s77 (solve) VR[4, int_oo]
I1015 19:14:15.410000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.411000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[0] s15 None
V1015 19:14:15.411000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[1] 5 None
V1015 19:14:15.412000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[0] 5 None
V1015 19:14:15.412000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[1] 1 None
V1015 19:14:15.412000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].storage_offset() 0 None
V1015 19:14:15.412000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] s77 None
V1015 19:14:15.413000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 1 None
V1015 19:14:15.413000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.413000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 None
V1015 19:14:15.413000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] s77 None
V1015 19:14:15.414000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] s77 None
V1015 19:14:15.414000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.414000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.415000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].size()[0] s17*s77 None
V1015 19:14:15.415000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].stride()[0] 1 None
V1015 19:14:15.415000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].storage_offset() 0 None
V1015 19:14:15.428000 18276 torch/fx/experimental/symbolic_shapes.py:7471] eval 5 [trivial]
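This spits out quite a lot, even for this simple toy model. The log excerpts below come from a different run than the output above, so their symbol names (s0 to s5) differ from the ones just printed (s15, s21, s77, and so on). Their lines have also been trimmed at the front and end to drop unnecessary information. Looking through the logs, we can see the lines relevant to what we described above, e.g. the allocation of symbols: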
"""
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""
The create_symbol lines show when a new symbol has been allocated, along with the tensor variable name and dimension it was allocated for. Other lines show the guards being emitted:
"""
runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""
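The following small sketch can help you read these logs. It maps each symbol back to the tensor dimension it was created for, then rewrites every added guard in those terms. It uses only Python's re module and assumes the line formats shown above ("create_symbol <sym> = <value> for <source>" and "... [guard added]"). These formats are not a stable interface and may change between PyTorch releases. The helper names symbol_sources and attributed_guards are our own and are not part of PyTorch.

```python
import re

# e.g. "create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (...)"
CREATE_SYMBOL = re.compile(r"create_symbol (?P<sym>s\d+) = \d+ for (?P<source>\S+)")
# e.g. "eval Eq(s77, s94) [guard added] (...)" or "runtime_assert Eq(s2, s4) [guard added] ..."
GUARD_ADDED = re.compile(r"(?:eval|runtime_assert) (?P<guard>.+?) \[guard added\]")
# A bare symbol name such as s15 (word boundaries avoid partial matches)
SYMBOL = re.compile(r"\bs\d+\b")


def symbol_sources(log_text):
    """Map each allocated symbol (e.g. "s15") to the tensor dimension it stands for."""
    return {m["sym"]: m["source"] for m in CREATE_SYMBOL.finditer(log_text)}


def attributed_guards(log_text):
    """Return every added guard, with known symbols replaced by tensor dimensions."""
    sources = symbol_sources(log_text)
    return [
        SYMBOL.sub(lambda s: sources.get(s.group(0), s.group(0)), m["guard"])
        for m in GUARD_ADDED.finditer(log_text)
    ]


sample = """\
create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (...)
create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (...)
eval False == False [statically known]
eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size)
"""
print(attributed_guards(sample))  # ["Eq(L['x'].size()[0], L['y'].size()[1])"]
```

You could apply the same helper to the full log printed earlier. It would turn Eq(s17*s77, s68) into a guard relating L['y'].size()[0], L['x'].size()[0] and L['z'].size()[0].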
Next to the [guard added] messages, we also see the responsible lines of user code; luckily, the model here is simple enough. In many real-world cases it’s not so straightforward: high-level torch operations can have complicated fake-kernel implementations or operator decompositions that complicate where and which guards are emitted. In such cases, the best way to dig deeper is to follow the logs’ suggestion and re-run with the environment variable TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..." to further attribute the guard of interest.
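Re-running with the extra debug variable is simplest from a fresh process. The sketch below assumes the export code lives in its own script. It also assumes, as we understand it, that PyTorch reads these environment variables when torch is imported. That is why the sketch launches a new interpreter instead of setting them inside an already running session. The helper names guard_debug_env and rerun_with_guard_debug are hypothetical.

```python
import os
import subprocess
import sys


def guard_debug_env(guard_expr, base_env=None):
    """Return a copy of the environment with guard-attribution debugging turned on."""
    env = dict(os.environ if base_env is None else base_env)
    # The value should match the guard string printed in the log, e.g. "Eq(s21, 5)"
    env["TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED"] = guard_expr
    # Verbose symbolic-shapes logging, as suggested by the export error messages
    env["TORCH_LOGS"] = "+dynamic"
    return env


def rerun_with_guard_debug(script_path, guard_expr):
    """Run `script_path` in a new Python process with the debug variables set."""
    return subprocess.run(
        [sys.executable, script_path], env=guard_debug_env(guard_expr), check=False
    )
```

For example, rerun_with_guard_debug("export_script.py", "Eq(s21, 5)") would re-run a hypothetical export_script.py and ask for more detail about where the Eq(s21, 5) guard from the log above was added. Copying the environment instead of mutating os.environ keeps the current process unaffected.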
Dim.AUTO is just one of the available options for interacting with dynamic_shapes; at the time of writing, two other options are available: Dim.DYNAMIC and Dim.STATIC. Dim.STATIC simply marks a dimension static. Dim.DYNAMIC is similar to Dim.AUTO in all ways except one: it raises an error when the dimension is specialized to a constant, which is designed to maintain dynamism. See, for example, what happens when a static guard is emitted on a dynamically-marked dimension:
I1015 19:14:15.432000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.434000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.434000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.436000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.438000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.438000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.440000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.446000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.447000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.448000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.449000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.449000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.451000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.451000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.452000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.453000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.454000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.456000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
I1015 19:14:15.457000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = s77 (solve) VR[2, int_oo]
V1015 19:14:15.459000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.465000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2248 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V1015 19:14:15.466000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s21 = VR[5, 5] (update)
I1015 19:14:15.466000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V1015 19:14:15.478000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.481000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.483000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V1015 19:14:15.485000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s68 = VR[4, int_oo] (update)
I1015 19:14:15.485000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s68 = s17*s77 (solve) VR[4, int_oo]
I1015 19:14:15.490000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.491000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[0] s15 None
V1015 19:14:15.491000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[1] 5 RelaxedUnspecConstraint(warn_only=False)
V1015 19:14:15.491000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[0] 5 None
V1015 19:14:15.492000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[1] 1 None
V1015 19:14:15.492000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].storage_offset() 0 None
V1015 19:14:15.492000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] s77 None
V1015 19:14:15.493000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 1 None
V1015 19:14:15.493000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.493000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 None
V1015 19:14:15.493000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] s77 None
V1015 19:14:15.494000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] s77 None
V1015 19:14:15.494000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.494000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.494000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].size()[0] s17*s77 None
V1015 19:14:15.495000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].stride()[0] 1 None
V1015 19:14:15.495000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].storage_offset() 0 None
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1821, in _export_to_aten_ir_make_fx
produce_guards_callback(gm)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1968, in _produce_guards_callback
return produce_guards_and_solve_constraints(
File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 533, in produce_guards_and_solve_constraints
raise constraint_violation_error
File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 500, in produce_guards_and_solve_constraints
shape_env.produce_guards(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5204, in produce_guards
return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5928, in produce_guards_verbose
raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- You marked L['w'].size()[1] as dynamic but your code specialized it to be a constant (5). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 418, in <module>
export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
ep = _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
export_artifact = export_func(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1823, in _export_to_aten_ir_make_fx
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- You marked L['w'].size()[1] as dynamic but your code specialized it to be a constant (5). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
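The behavior of the three markers can be summarized with a toy decision function. This is our own simplification for illustration, not how torch.export implements it, and ToyDim and resolve_dim are hypothetical names. It only captures the outcome for a single dimension: whether it stays dynamic, becomes static, or is rejected.

```python
from enum import Enum


class ToyDim(Enum):
    AUTO = "auto"
    DYNAMIC = "dynamic"
    STATIC = "static"


def resolve_dim(marker, specialized_to=None):
    """Toy rule for one dimension. `specialized_to` is the constant that tracing
    specialized it to, or None if the dimension's symbol survived tracing."""
    if marker is ToyDim.STATIC:
        return "static"  # no symbol is allocated; the example value is used
    if specialized_to is None:
        return "dynamic"  # the symbol survives tracing
    if marker is ToyDim.AUTO:
        return "static"  # Dim.AUTO silently accepts the specialization
    # Dim.DYNAMIC: specializing is an error, like the ConstraintViolationError above
    raise ValueError(
        f"marked as dynamic but specialized to a constant ({specialized_to})"
    )


print(resolve_dim(ToyDim.AUTO, specialized_to=5))  # static
```

In this model, ToyDim.DYNAMIC with specialized_to=5 raises, which corresponds to the w.shape[1] failure shown above.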
Static guards also aren’t always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape specializations is when the user specifies conflicting markers for equivalent dimensions: one dynamic and the other static. The same error type is raised when this is the case for x.shape[0] (marked static) and y.shape[1] (marked dynamic):
I1015 19:14:15.507000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.508000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.508000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.511000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.512000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.514000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.520000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.520000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.521000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.523000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.523000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.524000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.525000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.526000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.531000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s94, 4) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s94, 4)"
V1015 19:14:15.531000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s94 = VR[4, 4] (update)
I1015 19:14:15.532000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = 4 (range_refined_to_singleton) VR[4, 4]
I1015 19:14:15.539000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2248 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V1015 19:14:15.540000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s21 = VR[5, 5] (update)
I1015 19:14:15.541000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
I1015 19:14:15.560000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(4*s17, s68) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(4*s17, s68)"
V1015 19:14:15.564000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s68 = VR[8, int_oo] (update)
I1015 19:14:15.565000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s68 = 4*s17 (solve) VR[8, int_oo]
I1015 19:14:15.570000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.570000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[0] s15 None
V1015 19:14:15.571000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[1] 5 None
V1015 19:14:15.571000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[0] 5 None
V1015 19:14:15.571000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[1] 1 None
V1015 19:14:15.572000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].storage_offset() 0 None
V1015 19:14:15.572000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] 4 None
V1015 19:14:15.572000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 1 None
V1015 19:14:15.573000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.573000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 None
V1015 19:14:15.573000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] 4 RelaxedUnspecConstraint(warn_only=False)
V1015 19:14:15.574000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 4 None
V1015 19:14:15.574000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.574000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.574000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].size()[0] 4*s17 None
V1015 19:14:15.575000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].stride()[0] 1 None
V1015 19:14:15.575000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].storage_offset() 0 None
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1821, in _export_to_aten_ir_make_fx
produce_guards_callback(gm)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1968, in _produce_guards_callback
return produce_guards_and_solve_constraints(
File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 533, in produce_guards_and_solve_constraints
raise constraint_violation_error
File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 500, in produce_guards_and_solve_constraints
shape_env.produce_guards(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5204, in produce_guards
return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5928, in produce_guards_verbose
raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- You marked L['y'].size()[1] as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 431, in <module>
export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
ep = _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
export_artifact = export_func(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1823, in _export_to_aten_ir_make_fx
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- You marked L['y'].size()[1] as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
Here you might ask why export “specializes”, i.e. why we resolve this static/dynamic conflict by going the static route. The answer lies in the symbolic shapes system described above, with its symbols and guards. When x.shape[0] is marked static, we don’t allocate a symbol for it, and we trace treating this shape as the concrete integer 4. A symbol is allocated for y.shape[1] (s94 in the log above), so the broadcast in x + y finally emits the guard s94 == 4 (logged as Eq(s94, 4)), leading to specialization.
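To make the specialization argument concrete, here is a toy version of the symbol bookkeeping. ToyShapeEnv is a hypothetical stand-in, far simpler than PyTorch's real shape environment. Its replacements mimic the set_replacement lines in the logs: equating two symbols replaces one with the other, while equating a symbol with a concrete integer specializes that symbol.

```python
class ToyShapeEnv:
    """Toy model of symbols, guards and replacements (not PyTorch's implementation)."""

    def __init__(self):
        self.replacements = {}  # symbol name -> int constant or another symbol name
        self.guards = []        # guard strings, in the order they were added

    def resolve(self, value):
        # Follow replacements until we reach a constant or an unreplaced symbol
        while isinstance(value, str) and value in self.replacements:
            value = self.replacements[value]
        return value

    def guard_equal(self, a, b):
        """Record that sizes a and b must be equal, as broadcasting in x + y requires."""
        a, b = self.resolve(a), self.resolve(b)
        if a == b:
            return  # already known to be equal; no new guard
        if isinstance(a, int) and isinstance(b, int):
            raise ValueError(f"shape mismatch: {a} != {b}")
        if isinstance(a, int):
            a, b = b, a  # keep the symbol on the left, like Eq(s94, 4)
        self.guards.append(f"Eq({a}, {b})")
        if isinstance(b, int):
            self.replacements[a] = b  # the symbol is specialized to a constant
        else:
            self.replacements[b] = a  # one symbol is replaced by the other


both_dynamic = ToyShapeEnv()
both_dynamic.guard_equal("s77", "s94")  # x.shape[0] and y.shape[1] both symbolic
print(both_dynamic.guards, both_dynamic.resolve("s94"))  # ['Eq(s77, s94)'] s77

x_static = ToyShapeEnv()
x_static.guard_equal(4, "s94")  # x.shape[0] static (concrete 4), y.shape[1] symbolic
print(x_static.guards, x_static.resolve("s94"))  # ['Eq(s94, 4)'] 4
```

Once one side of the equality is a plain integer, the only consistent outcome is to replace the symbol with that integer. That is the specialization which Dim.DYNAMIC then reports as an error.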
One feature of export is that during tracing, statements like asserts, torch._check(), and if/else conditions also emit guards.
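The following toy model shows why branching and asserts produce guards. During tracing, a condition on a symbolic size is evaluated using the example value, and the outcome is recorded. The exported program is then only valid for inputs that take the same path. ToySymInt is a hypothetical class for illustration. PyTorch's real symbolic integers are far more involved and print guards in their own format.

```python
class ToySymInt:
    """Toy symbolic size: carries an example value (hint) and records guards."""

    def __init__(self, name, hint, guards):
        self.name, self.hint, self.guards = name, hint, guards

    def _record(self, op, other, result):
        # Record whichever outcome the example value produced, so the traced
        # program is only trusted for inputs that behave the same way.
        expr = f"{self.name} {op} {other}"
        self.guards.append(expr if result else f"Not({expr})")
        return result

    def __le__(self, other):
        return self._record("<=", other, self.hint <= other)

    def __ge__(self, other):
        return self._record(">=", other, self.hint >= other)

    def __eq__(self, other):
        return self._record("==", other, self.hint == other)


guards = []
w0 = ToySymInt("s15", 6, guards)  # stands in for w.shape[0], example value 6
assert w0 <= 512  # records "s15 <= 512"
branch = "then" if w0 == 8 else "else"  # records "Not(s15 == 8)"
print(branch, guards)  # else ['s15 <= 512', 'Not(s15 == 8)']
```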
See what happens when we augment the existing model with such statements:
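import traceback as tb

import torch
from torch.export import Dim, export

# Example inputs with the shapes seen in the logs: w [6, 5], x [4], y [8, 4], z [32]
w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)


class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        # Each of these statements emits a guard during tracing
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w


dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
try:
    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()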
I1015 19:14:15.585000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.587000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.587000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.589000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.591000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.591000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.593000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.599000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.600000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.601000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.602000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.603000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.604000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.604000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.605000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.606000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.607000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.613000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval s15 <= 512 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:450 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s15 <= 512"
V1015 19:14:15.613000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s15 = VR[2, 512] (update)
I1015 19:14:15.617000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval s77 >= 4 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:451 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s77 >= 4"
V1015 19:14:15.617000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s77 = VR[4, int_oo] (update)
I1015 19:14:15.622000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s15, s77 + 2) [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:452 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s15, s77 + 2)"
V1015 19:14:15.624000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s77 = VR[4, 510] (update)
V1015 19:14:15.625000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s15 = VR[6, 512] (update)
I1015 19:14:15.626000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s15 = s77 + 2 (solve) VR[6, 512]
I1015 19:14:15.630000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
V1015 19:14:15.631000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s94 = VR[4, 510] (update)
I1015 19:14:15.631000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = s77 (solve) VR[4, 510]
V1015 19:14:15.635000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.642000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2248 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V1015 19:14:15.643000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s21 = VR[5, 5] (update)
I1015 19:14:15.644000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V1015 19:14:15.658000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.661000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.669000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V1015 19:14:15.670000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s68 = VR[8, int_oo] (update)
I1015 19:14:15.671000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s68 = s17*s77 (solve) VR[8, int_oo]
I1015 19:14:15.676000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.677000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[0] s77 + 2 None
V1015 19:14:15.677000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[1] 5 None
V1015 19:14:15.678000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[0] 5 None
V1015 19:14:15.678000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[1] 1 None
V1015 19:14:15.678000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].storage_offset() 0 None
V1015 19:14:15.679000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] s77 None
V1015 19:14:15.679000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 1 None
V1015 19:14:15.679000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.680000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 None
V1015 19:14:15.680000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] s77 None
V1015 19:14:15.680000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] s77 None
V1015 19:14:15.681000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.681000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.681000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].size()[0] s17*s77 None
V1015 19:14:15.682000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].stride()[0] 1 None
V1015 19:14:15.682000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].storage_offset() 0 None
V1015 19:14:15.698000 18276 torch/fx/experimental/symbolic_shapes.py:7471] eval 5 [trivial]
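Each of these statements emits an additional guard, and the logs above show the effect: s15 (the symbol for w.shape[0]) is eliminated in favor of s77 + 2, and s77 (the symbol for x.shape[0]) now has both lower and upper bounds, VR[4, 510]. These bounds are what the exported program reflects in its range_constraints.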
For the if/else condition, you might ask why the True branch was taken, and why the w.shape[0] != x.shape[0] + 2 guard wasn’t the one emitted from tracing. The answer is that export is guided by the sample inputs provided for tracing, and specializes on the branches taken. If different sample input shapes were provided that fail the if condition, export would trace and emit guards corresponding to the else branch.
Additionally, you might ask why we traced only the if branch, and whether it’s possible to maintain control flow in your program and keep both branches alive. For that, refer to rewriting your model code following the Control Flow Ops section above.
0/1 specialization
Since we’re talking about guards and specializations, it’s a good time to talk about the 0/1 specialization issue we brought up earlier. The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that don’t generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 … . This just means that you should specify 0/1 sample inputs when you’d like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens at runtime when we export this linear layer:
ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
try:
    ep.module()(torch.randn(2, 4))
except Exception:
    tb.print_exc()
I1015 19:14:15.703000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.715000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.716000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].size()[0] 1 None
V1015 19:14:15.716000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].size()[1] 4 None
V1015 19:14:15.716000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].stride()[0] 4 None
V1015 19:14:15.716000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].stride()[1] 1 None
V1015 19:14:15.717000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].storage_offset() 0 None
W1015 19:14:15.719000 18276 torch/_export/non_strict_utils.py:564] dimension inputs['input'].shape[0] 0/1 specialized; Dim.AUTO was specified along with a sample input with hint = 1.
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 500, in <module>
ep.module()(torch.randn(2, 4))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 837, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 413, in __call__
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
return inner()
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1829, in inner
result = forward_call(*args, **kwargs)
File "<eval_with_key>.125", line 9, in forward
_guards_fn = self._guards_fn(input_1); _guards_fn = None
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "<string>", line 3, in _
File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2185, in _assert
assert condition, message
AssertionError: Guard failed: input.size()[0] == 1
Named Dims
So far we’ve only been talking about 3 ways to specify dynamic shapes: Dim.AUTO, Dim.DYNAMIC, and Dim.STATIC. The attraction of these is the low-friction user experience: all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic dimensions is automatically figured out underneath export. The dynamic shapes subsystem essentially acts as a “discovery” process, summarizing these guards and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or beliefs about the dynamic behavior of these models. Maybe there is a strong desire for dynamism, and specializations on particular dimensions are to be avoided at all costs; or maybe we just want to catch changes in dynamic behavior caused by changes to the original model code, or to underlying decompositions or meta-kernels. These changes won’t be detected, and the export() call will most likely succeed, unless tests are in place that check the resulting ExportedProgram representation.
For such cases, our stance is to recommend the “traditional” way of specifying dynamic shapes, which longer-term users of export might be familiar with: named Dims:
This style of dynamic shapes allows the user to specify which symbols are allocated for input dimensions and the min/max bounds on those symbols, and it places restrictions on the dynamic behavior of the ExportedProgram produced; ConstraintViolation errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic specifications given. For example, in the above specification, the following is asserted:
- x.shape[0] is to have range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].
- x.shape[1] is static.
- y.shape[1] has range [2, 512], and is unrelated to any other dimension.
In this design, we allow relations between dimensions to be specified with univariate linear expressions: A * dim + B can be specified for any dimension. This allows users to specify more complex constraints like integer divisibility for dynamic dimensions:
Constraint violations, suggested fixes
One common issue with this specification style (before Dim.AUTO was introduced) is that the specification would often be mismatched with what was produced by model tracing. That would lead to ConstraintViolation errors and suggested fixes from export. See, for example, this model and specification, where the model inherently requires equality between dimension 0 of x and y, and requires dimension 1 to be static.
class Foo(torch.nn.Module):
    def forward(self, x, y):
        w = x + y
        return w + torch.ones(4)

dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
try:
    ep = export(
        Foo(),
        (torch.randn(6, 4), torch.randn(6, 4)),
        dynamic_shapes={
            "x": (dx, d1),
            "y": (dy, d1),
        },
    )
except Exception:
    tb.print_exc()
I1015 19:14:15.728000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.730000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s77 = 6 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.732000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s27 = 4 for L['x'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s27" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.734000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 6 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.735000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.741000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.743000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.744000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.745000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.746000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.747000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.751000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s27, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s27, s94)"
I1015 19:14:15.752000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = s27 (solve) VR[2, int_oo]
I1015 19:14:15.754000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s77, s17) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s17)"
I1015 19:14:15.755000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s77 = s17 (solve) VR[2, int_oo]
V1015 19:14:15.757000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.767000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s27, 4) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s27, 4)"
V1015 19:14:15.768000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s27 = VR[4, 4] (update)
I1015 19:14:15.768000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s27 = 4 (range_refined_to_singleton) VR[4, 4]
I1015 19:14:15.774000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.774000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s94 = VR[4, 4] (update)
I1015 19:14:15.775000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = 4 (find) VR[4, 4]
V1015 19:14:15.775000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] s17 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V1015 19:14:15.776000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V1015 19:14:15.776000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 4 None
V1015 19:14:15.776000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[1] 1 None
V1015 19:14:15.777000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.777000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V1015 19:14:15.777000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V1015 19:14:15.778000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 4 None
V1015 19:14:15.778000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.778000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1821, in _export_to_aten_ir_make_fx
produce_guards_callback(gm)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1968, in _produce_guards_callback
return produce_guards_and_solve_constraints(
File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 533, in produce_guards_and_solve_constraints
raise constraint_violation_error
File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 500, in produce_guards_and_solve_constraints
shape_env.produce_guards(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5204, in produce_guards
return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5928, in produce_guards_verbose
raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
- You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
- You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
- The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
Suggested fixes:
d1 = 4
dy = dx
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 557, in <module>
ep = export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
ep = _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
export_artifact = export_func(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1823, in _export_to_aten_ir_make_fx
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
- You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
- You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
- The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
Suggested fixes:
d1 = 4
dy = dx
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
Lastly, there are a couple of nice-to-knows about the options for specification:
- None is a good option for static behavior:
  - dynamic_shapes=None (default) exports with the entire model being static.
  - specifying None at an input-level exports with all tensor dimensions static, and is also required for non-tensor inputs.
  - specifying None at a dimension-level specializes that dimension, though this is deprecated in favor of Dim.STATIC.
- specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification.
These options are combined in the inputs & dynamic shapes spec below:
inputs = (
    torch.randn(4, 4),
    torch.randn(3, 3),
    16,
    False,
)
dynamic_shapes = {
    "tensor_0": (Dim.AUTO, None),
    "tensor_1": None,
    "int_val": None,
    "bool_val": None,
}
Data-dependent errors
While trying to export models, you may have encountered errors like “Could not guard on data-dependent expression” or “Could not extract specialized integer from data-dependent expression”.
These errors exist because torch.export() compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. While these have equivalent symbolic properties (e.g. sizes, strides, dtypes), they diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may be unable to compile, out of the box, parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that the value is not available.
Data-dependent values appear in many places, and common sources are calls like item(), tolist(), or torch.unbind() that extract scalar values from tensors.
How are these values represented in the exported program? In the Constraints/Dynamic Shapes section, we talked about allocating symbols to represent dynamic input dimensions. The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are “unbacked” symbols, in contrast to the “backed” symbols allocated for input dimensions. The “backed/unbacked” nomenclature refers to the presence/absence of a “hint” for the symbol: a concrete value backing the symbol that can inform the compiler on how to proceed.
In the input shape symbol case (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties. For data-dependent values, the symbols are taken from FakeTensor “data” during tracing, and so the compiler doesn’t know the actual values (hints) that these symbols would take on.
Let’s see how these show up in exported programs:
class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.tolist()
        return b + [a]

inps = (
    torch.tensor(1),
    torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)
I1015 19:14:15.787000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.794000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.795000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.798000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.799000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u1]
I1015 19:14:15.800000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u2 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.800000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u2]
I1015 19:14:15.802000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.803000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.803000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 2 None
V1015 19:14:15.803000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.804000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "i64[2]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward, code: b = y.tolist()
            unbind = torch.ops.aten.unbind.int(y); y = None
            getitem: "i64[]" = unbind[0]
            getitem_1: "i64[]" = unbind[1]; unbind = None
            item_1: "Sym(u1)" = torch.ops.aten.item.default(getitem); getitem = None
            item_2: "Sym(u2)" = torch.ops.aten.item.default(getitem_1); getitem_1 = None
            return (item_1, item_2, item)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    item_1: USER_OUTPUT
    item_2: USER_OUTPUT
    item: USER_OUTPUT

Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[-int_oo, int_oo]}
The result is that 3 unbacked symbols (notice they’re prefixed with “u”, instead of the usual “s” for input shape/backed symbols) are allocated and returned: 1 for the item() call, and 1 for each of the elements of y with the tolist() call. Note from the range constraints field that these take on ranges of [-int_oo, int_oo], not the default [0, int_oo] range allocated to input shape symbols, since we have no information on what these values are: they don’t represent sizes, so they are not necessarily non-negative.
Guards, torch._check()
But the case above is easy to export, because the concrete values of these symbols aren’t used in any compiler decision-making; all that’s relevant is that the return values are unbacked symbols. The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered:
Here we actually need the “hint”, or the concrete value of a, for the compiler to decide whether to trace return y + 2 or return y * 5 as the output. Because we trace with FakeTensors, we don’t know what a // 2 >= 5 actually evaluates to, and export errors out with “Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)”.
So how do we export this toy model? Unlike torch.compile(), export requires full graph compilation, and we can’t just graph break on this. Here are some basic options:
1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by using torch.compiler.is_compiling() to guard what’s traced at compile-time.
2. torch.cond(): we could rewrite the control-flow code to use torch.cond() so we don’t specialize on a branch.
While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and torch.cond() is not a comprehensive system for handling data-dependent errors. As we will see, there are data-dependent errors that do not involve control flow.
The generally recommended approach is to start with torch._check() calls. While these give the impression of being purely assert statements, they are in fact a system for informing the compiler about properties of symbols. While a torch._check() call does act as an assertion at runtime, when it is traced at compile-time the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true are stored (provided the subsystem is smart enough to infer them). So even if unbacked symbols don’t have hints, if we’re able to communicate properties that are generally true for these symbols via torch._check() calls, we can potentially bypass data-dependent guards without rewriting the offending model code.
For example, in the model above, inserting torch._check(a >= 10) would tell the compiler that y + 2 can always be returned, and torch._check(a == 4) would tell it to return y * 5.
Let’s see what happens when we re-export the model with torch._check(a >= 10) inserted, plus an upper bound via torch._check(a <= 60).
class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 10)
        torch._check(a <= 60)
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

inps = (
    torch.tensor(32),
    torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)
I1015 19:14:15.811000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.816000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.817000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.819000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 >= 10 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:673 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 10"
V1015 19:14:15.819000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[10, int_oo] (update)
I1015 19:14:15.824000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 <= 60 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:674 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 <= 60"
V1015 19:14:15.825000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[10, 60] (update)
V1015 19:14:15.830000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == True [statically known]
I1015 19:14:15.834000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.834000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.834000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 4 None
V1015 19:14:15.835000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.835000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.837000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 >= 10 == True [statically known]
V1015 19:14:15.838000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 <= 60 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[4]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
            ge_2: "Sym(u0 >= 10)" = item >= 10
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 10 on node 'ge_2'"); ge_2 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 60)" = item <= 60; item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 60 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:676 in forward, code: return y + 2
            add: "f32[4]" = torch.ops.aten.add.Tensor(y, 2); y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {u0: VR[10, 60]}
Export succeeds, and note from the range constraints field that u0 takes on a range of [10, 60].
So what information do torch._check() calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are generally true:
- Equality with non-data-dependent expressions: torch._check() calls that communicate equalities like u0 == s0 + 4 or u0 == 5.
- Range refinement: calls that provide lower or upper bounds for symbols, like the above.
- Some basic reasoning around more complicated expressions: inserting torch._check(a < 4) will typically tell the compiler that a >= 4 is false. Checks on complex expressions like torch._check(a ** 2 - 3 * a <= 10) will typically get you past identical guards.
As mentioned previously, torch._check() calls are useful beyond data-dependent control flow. For example, here’s a model where torch._check() insertion succeeds while manual specialization and torch.cond() do not:
class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps)
except Exception:
    tb.print_exc()
I1015 19:14:15.844000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.850000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.850000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.852000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate u0 >= 0 due to data dependency, it was assumed to be False with no runtime assertions (_subclasses/fake_impls.py:388 in meta_select)
I1015 19:14:15.852000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.853000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate u0 < 0 due to data dependency, it was assumed to be False with no runtime assertions (_subclasses/fake_impls.py:390 in meta_select)
I1015 19:14:15.853000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.854000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:402 in meta_select)
I1015 19:14:15.855000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate u1 >= 0 due to data dependency, it was assumed to be True with no runtime assertions (utils/_stats.py:28 in wrapper)
I1015 19:14:15.855000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.856000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u1]
I1015 19:14:15.858000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.859000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.859000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 60 None
V1015 19:14:15.859000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.859000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
Here is a scenario where torch._check() insertion can be needed simply to prevent an operation from failing. In PyTorch 2.5, the export call fails with “Could not guard on data-dependent expression -u0 > 60”. This means the compiler doesn’t know whether this is a valid indexing operation, i.e. whether the value of x is out of bounds for y. (The log above was produced by a newer build. There, the compiler instead assumes the bounds checks evaluate to False and continues without runtime assertions, so no error is printed.) Here, manual specialization is too prohibitive, and torch.cond() has no place. Instead, informing the compiler of u0’s range is sufficient:
class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a < y.shape[0])
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)
I1015 19:14:15.863000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.869000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.869000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.871000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 >= 0 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:722 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V1015 19:14:15.871000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[0, int_oo] (update)
I1015 19:14:15.874000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 < 60 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:723 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 < 60"
V1015 19:14:15.875000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[0, 59] (update)
V1015 19:14:15.877000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == True [statically known]
V1015 19:14:15.878000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == True [statically known]
I1015 19:14:15.881000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.881000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.882000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 60 None
V1015 19:14:15.882000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.882000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.884000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 >= 0 == True [statically known]
V1015 19:14:15.886000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 <= 59 == True [statically known]
V1015 19:14:15.887000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 < 60 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
            le: "Sym(u0 <= 59)" = item <= 59
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 59 on node 'le'"); le = _assert_scalar_default_1 = None
            #
            lt_1: "Sym(u0 < 60)" = item < 60
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u0 < 60 on node 'lt_1'"); lt_1 = _assert_scalar_default_2 = None
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:724 in forward, code: return y[a]
            select: "f32[]" = torch.ops.aten.select.int(y, 0, item); y = item = None
            return (select,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    select: USER_OUTPUT

Range constraints: {u0: VR[0, 59]}
Specialized values#
Another category of data-dependent error happens when the program attempts to extract a concrete data-dependent integer/float value while tracing. The error looks something like “Could not extract specialized integer from data-dependent expression”. It is analogous to the previous class of errors: these errors occur when evaluating concrete integer/float values, whereas data-dependent guard errors arise when evaluating concrete boolean values.
This error typically occurs when there is an explicit or implicit int() cast on a data-dependent expression. For example, this list comprehension has a range() call that implicitly does an int() cast on the size of the list:
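Here is a torch-free sketch of why range() counts as an implicit cast. The Recorder class is hypothetical and only loosely stands in for a traced symbolic integer. It logs each time Python asks it for a concrete int. range() requests one through __index__, and an explicit int() goes through __int__. With an unbacked symbol, these are the points where a concrete value would be demanded and none is available.

```python
class Recorder:
    """Hypothetical stand-in for a symbolic integer that logs concretization requests."""

    def __init__(self, value):
        self.value = value
        self.calls = []

    def __index__(self):
        self.calls.append("__index__")
        return self.value

    def __int__(self):
        self.calls.append("__int__")
        return self.value


n = Recorder(3)
items = [i for i in range(n)]  # range() needs a concrete int: calls __index__
total = sum(items) + int(n)    # explicit int() cast: calls __int__
print(items, total, n.calls)
```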
class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # range(a) needs a concrete int; use _ so the loop variable doesn't shadow y
        b = torch.cat([y for _ in range(a)], dim=0)
        return b + int(a)

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps, strict=False)
except Exception:
    tb.print_exc()
I1015 19:14:15.893000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.899000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.899000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] Data dependent variable 'u0' allocated at:
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/bin/sphinx-build", line 7, in <module>
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] sys.exit(main())
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 339, in main
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return make_main(argv)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 213, in make_main
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return make_mode.run_make_mode(argv[1:])
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 181, in run_make_mode
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return make.run_generic_build(args[0])
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 169, in run_generic_build
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return build_main(args + opts)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 293, in build_main
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 272, in __init__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] self._init_builder()
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 343, in _init_builder
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] self.events.emit('builder-inited')
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 97, in emit
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] results.append(listener.handler(self.app, *args))
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 757, in generate_gallery_rst
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] ) = generate_dir_rst(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 606, in generate_dir_rst
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] results = parallel(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 607, in <genexpr>
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] p_fun(fname, target_dir, src_dir, gallery_conf) for fname in iterator
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/var/lib/workspace/conf.py", line 85, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] p.start()
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] self._popen = self._Popen(self)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return _default_context.get_context().Process._Popen(process_obj)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return Popen(process_obj)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] self._launch(process_obj)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 71, in _launch
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] code = process_obj._bootstrap(parent_sentinel=child_r)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] self.run()
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] self._target(*self._args, **self._kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/var/lib/workspace/conf.py", line 73, in call_fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] result = func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1374, in generate_file_rst
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] output_blocks, time_elapsed = execute_script(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1192, in execute_script
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] execute_code_block(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1048, in execute_code_block
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] is_last_expr, mem_max = _exec_and_get_memory(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 876, in _exec_and_get_memory
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] mem_max, _ = call_memory(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1725, in _sg_call_memory_noop
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return 0.0, func()
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 794, in __call__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] exec(self.code, self.fake_main.__dict__)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] export(Foo(), inps, strict=False)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return _export(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] ep = fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] ep = _export_for_training(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] ep = fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] export_artifact = export_func(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] aten_export_artifact = _to_aten_func( # type: ignore[operator]
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1793, in _export_to_aten_ir_make_fx
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] gm, graph_signature = transform(_make_fx_helper)(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1922, in _aot_export_non_strict
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1706, in _make_fx_helper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] gm = make_fx(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2429, in wrapped
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return make_fx_tracer.trace(f, *args)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2356, in trace
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self._trace_inner(f, *args)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2318, in _trace_inner
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] t = dispatch_trace(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 53, in inner
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return disable_fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1303, in dispatch_trace
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1908, in trace
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] res = super().trace(root, concrete_args)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 868, in trace
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] (self.create_arg(fn(*args)),),
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1361, in wrapped
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] out = f(*tensors) # type:ignore[call-arg]
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "<string>", line 1, in <lambda>
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1593, in wrapped_fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return tuple(flat_fn(*args))
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] tree_out = fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] out = mod(*args[params_len:], **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self.call_module(mod, forward, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return Tracer.call_module(self, m, forward, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] ret_val = forward(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return _orig_module_call(mod, *args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self._call_impl(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return forward_call(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1906, in forward
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] tree_out = mod(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self.call_module(mod, forward, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return Tracer.call_module(self, m, forward, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] ret_val = forward(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return _orig_module_call(mod, *args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self._call_impl(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return forward_call(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 747, in forward
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] a = x.item()
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1409, in __torch_function__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1479, in __torch_function__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 1066, in __torch_function__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 962, in handler
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return torch._library.utils.handle_dispatch_mode(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_library/utils.py", line 286, in handle_dispatch_mode
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 28, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1534, in __torch_dispatch__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return proxy_call(self, func, self.pre_dispatch, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 994, in proxy_call
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] out = func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 841, in __call__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self._op(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 28, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1376, in __torch_dispatch__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self.dispatch(func, types, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2096, in dispatch
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self._cached_dispatch_impl(func, types, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1498, in _cached_dispatch_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return self._dispatch_impl(func, types, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2725, in _dispatch_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] op_impl_out = op_impl(self, func, *args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 169, in dispatch_to_op_implementations_dict
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 651, in local_scalar_dense
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] r = fake_mode.shape_env.create_unbacked_symint()
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] return retlog(fn(*args, **kwargs))
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]
def forward(self, arg0_1: "i64[]", arg1_1: "f32[60]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:747 in forward, code: a = x.item()
item: "Sym(u0)" = torch.ops.aten.item.default(arg0_1); arg0_1 = item = None
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
export(Foo(), inps, strict=False)
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
ep = _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
export_artifact = export_func(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1793, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1922, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1706, in _make_fx_helper
gm = make_fx(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2429, in wrapped
return make_fx_tracer.trace(f, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2356, in trace
return self._trace_inner(f, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2318, in _trace_inner
t = dispatch_trace(
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1303, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1908, in trace
res = super().trace(root, concrete_args)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 868, in trace
(self.create_arg(fn(*args)),),
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1361, in wrapped
out = f(*tensors) # type:ignore[call-arg]
File "<string>", line 1, in <lambda>
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1593, in wrapped_fn
return tuple(flat_fn(*args))
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
tree_out = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
out = mod(*args[params_len:], **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
ret_val = forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1906, in forward
tree_out = mod(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
ret_val = forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 748, in forward
b = torch.cat([y for y in range(a)], dim=0)
File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 449, in __index__
return self.node.int_()
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 468, in int_
return self.guard_int("", 0) # NB: uses Python backtrace
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 518, in guard_int
r = self.evaluate()
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7233, in evaluate_sym_node
return self.evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7333, in evaluate_expr
return self._inner_evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
return retlog(fn(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7356, in _inner_evaluate_expr
return self._evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7574, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: none)
Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
For these errors, some basic options you have are:
- Avoid unnecessary int() cast calls, in this case the int(a) in the return statement.
- Use torch._check() calls; unfortunately, all you may be able to do in this case is specialize (with torch._check(a == 60)).
- Rewrite the offending code at a higher level. For example, the list comprehension is semantically a repeat() op, which doesn’t involve an int() cast. The following rewrite avoids data-dependent errors:
class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.unsqueeze(0).repeat(a, 1)
        return b + a

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)
I1015 19:14:15.918000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.924000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.924000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.927000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 >= 0 [guard added] (_meta_registrations.py:4109 in meta_repeat), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V1015 19:14:15.928000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[0, int_oo] (update)
V1015 19:14:15.929000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 >= 0 == True [statically known]
I1015 19:14:15.932000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate Eq(u0, 0) due to data dependency, it was assumed to be False with no runtime assertions (utils/_stats.py:28 in wrapper)
I1015 19:14:15.932000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.938000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate 60*u0 < 2 due to data dependency, it was assumed to be False with no runtime assertions (_prims_common/__init__.py:310 in is_contiguous)
I1015 19:14:15.938000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.939000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate Eq(u0, 1) due to data dependency, it was assumed to be False with no runtime assertions (_prims_common/__init__.py:276 in check_contiguous_sizes_strides)
I1015 19:14:15.939000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.946000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.947000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.947000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 60 None
V1015 19:14:15.947000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.947000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.949000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 >= 0 == True [statically known]
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "i64[]", y: "f32[60]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
#
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
ge: "Sym(u0 >= 0)" = item >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:770 in forward, code: b = y.unsqueeze(0).repeat(a, 1)
unsqueeze: "f32[1, 60]" = torch.ops.aten.unsqueeze.default(y, 0); y = None
repeat: "f32[u0, 60]" = torch.ops.aten.repeat.default(unsqueeze, [item, 1]); unsqueeze = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:771 in forward, code: return b + a
add: "f32[u0, 60]" = torch.ops.aten.add.Tensor(repeat, item); repeat = item = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {u0: VR[0, int_oo]}
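The following is a minimal eager-mode sanity check, separate from export. It shows that the repeat() rewrite computes the same values as building a Python list of a copies of y and stacking them. The count of 4 is an illustrative stand-in for the value that x.item() would return.

```python
import torch

y = torch.randn(60)
a = 4  # illustrative stand-in for the value returned by x.item()

# List-based formulation: range(a) needs a concrete Python int.
stacked = torch.stack([y for _ in range(a)], dim=0)

# Higher-level formulation: a single repeat() op that takes `a` directly.
repeated = y.unsqueeze(0).repeat(a, 1)

print(tuple(repeated.shape))
print(torch.equal(stacked, repeated))
```

Under export, only the second form can stay a single op whose output size is the unbacked symbol u0, as in the f32[u0, 60] shape in the graph above.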
Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: torch._check_is_size(), guard_size_oblivious(), or real-tensor tracing, as starters.
For more in-depth guides, please refer to the Export Programming Model, or Dealing with GuardOnDataDependentSymNode errors.
Custom Ops#
torch.export can export PyTorch programs with custom operators. Please refer to this page on how to author a custom operator in either C++ or Python.
The following is an example of registering a custom operator in Python to be used by torch.export. The important thing to note is that the custom op must have a FakeTensor kernel (registered here with register_fake). During tracing, export uses this kernel to compute output shapes and dtypes without running the real implementation.
@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(x: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)

@custom_op.register_fake
def custom_op_meta(x):
    # Returns an empty tensor with the same shape as the expected output
    return torch.empty_like(x)
Here is an example of exporting a program with the custom op.
class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
print(exported_custom_op_example)
print(exported_custom_op_example.module()(torch.randn(3, 3)))
I1015 19:14:16.027000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:16.038000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:16.038000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] 3 None
V1015 19:14:16.038000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[1] 3 None
V1015 19:14:16.039000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 3 None
V1015 19:14:16.039000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[1] 1 None
V1015 19:14:16.039000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:812 in forward, code: x = torch.sin(x)
sin: "f32[3, 3]" = torch.ops.aten.sin.default(x); x = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:813 in forward, code: x = torch.ops.my_custom_library.custom_op(x)
custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin); sin = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:814 in forward, code: x = torch.cos(x)
cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op); custom_op = None
return (cos,)
Graph signature:
# inputs
x: USER_INPUT
# outputs
cos: USER_OUTPUT
Range constraints: {}
custom_op called!
tensor([[1.0000, 0.5618, 0.9935],
[0.9409, 0.9454, 0.8364],
[0.5766, 1.0000, 1.0000]])
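Note that in the ExportedProgram, the custom operator is included in the graph as a single node, torch.ops.my_custom_library.custom_op.default. Also note that "custom_op called!" is printed only once, by the call through .module(). During export, the FakeTensor kernel was used instead of the real implementation.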
IR/Decompositions#
torch.export produces a graph containing only ATen operators, which are the basic unit of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.
By default, export produces the most generic IR, which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not mutate or alias its inputs. You can find a list of all ATen operators here, and you can inspect whether an operator mutates its inputs by checking op._schema.is_mutable, for example:
print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
False
True
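is_mutable only covers mutation. Aliasing shows up separately in the schema, as alias annotations on the returns. The following is a rough sketch of a functionality check that looks at both. The helper name looks_functional is hypothetical, and the check relies only on schema annotations.

```python
import torch

def looks_functional(op):
    """Hypothetical helper: True if `op` neither mutates nor aliases its inputs,
    judged only from its schema annotations."""
    schema = op._schema
    if schema.is_mutable:  # some argument is annotated as written to, e.g. Tensor(a!)
        return False
    # A return annotated with alias info, e.g. Tensor(a), is a view of an input.
    return all(ret.alias_info is None for ret in schema.returns)

results = {
    "add.Tensor": looks_functional(torch.ops.aten.add.Tensor),
    "add_.Tensor": looks_functional(torch.ops.aten.add_.Tensor),
    "view.default": looks_functional(torch.ops.aten.view.default),
}
print(results)
```

aten.view is not mutable, yet it is still non-functional because its output aliases its input.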
This generic IR can be used to train in eager PyTorch Autograd. This IR can be reached more explicitly through the API torch.export.export_for_training, which was introduced in PyTorch 2.5. As of PyTorch 2.6, calling torch.export.export should produce the same graph. (As the FutureWarning in the output below shows, export_for_training has since been deprecated in favor of torch.export.export.)
class DecompExample(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph)
/var/lib/workspace/intermediate_source/torch_export_tutorial.py:862: FutureWarning:
`torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. Please use `torch.export.export` instead, which is functionally equivalent.
I1015 19:14:16.050000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:16.083000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:16.084000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] 1 None
V1015 19:14:16.084000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[1] 1 None
V1015 19:14:16.084000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[2] 3 None
V1015 19:14:16.085000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[3] 3 None
V1015 19:14:16.085000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 9 None
V1015 19:14:16.085000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[1] 9 None
V1015 19:14:16.085000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[2] 3 None
V1015 19:14:16.086000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[3] 1 None
V1015 19:14:16.086000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
graph():
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
%p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
%b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
%b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
%b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
%x : [num_users=1] = placeholder[target=x]
%conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
%add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
%batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
return (batch_norm,)
We can then lower this exported program to an operator set that only contains functional ATen operators through the API run_decompositions. This API decomposes the ATen operators into the ones specified in the decomposition table and functionalizes the graph. By specifying an empty table, we only perform functionalization, without any additional decompositions. This results in an IR that contains ~2000 operators (instead of the 3000 operators above), and is ideal for inference cases.
ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph)
graph():
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
%p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
%b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
%b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
%b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
%x : [num_users=1] = placeholder[target=x]
%conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
%_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
return (getitem_3, getitem_4, add, getitem)
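As we can see, the previously mutable operator torch.ops.aten.add_.Tensor has now been replaced with torch.ops.aten.add.Tensor, a functional operator. Likewise, batch_norm has been replaced with _native_batch_norm_legit_functional. This variant returns the updated running statistics as extra graph outputs instead of mutating the buffers in place.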
We can also further lower this exported program to an operator set that only contains the Core ATen Operator Set, a collection of only ~180 operators. This IR is well suited to backends that do not want to reimplement all ATen operators.
from torch.export import default_decompositions
core_aten_decomp_table = default_decompositions()
core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
print(core_aten_ep.graph)
graph():
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
%p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
%b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
%b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
%b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
%x : [num_users=1] = placeholder[target=x]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
%_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%convolution, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
return (getitem_3, getitem_4, add, getitem)
We now see that torch.ops.aten.conv2d.default has been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more “core” operator, as operations like conv1d and conv2d can be implemented using the same op.
We can also specify our own decomposition behaviors:
my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph)
graph():
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
%p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
%b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
%b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
%b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
%x : [num_users=1] = placeholder[target=x]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, 2), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
%_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%mul, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
return (getitem_3, getitem_4, add, getitem)
Notice that instead of torch.ops.aten.conv2d.default being decomposed into torch.ops.aten.convolution.default, it is now decomposed into torch.ops.aten.convolution.default and torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.
ExportDB#
torch.export will only ever export a single computation graph from a PyTorch program. Because of this requirement, there will be Python or PyTorch features that are not compatible with torch.export, which will require users to rewrite parts of their model code. We have seen examples of this earlier in the tutorial, for example, rewriting if-statements using cond.
ExportDB is the standard reference that documents supported and unsupported Python/PyTorch features for torch.export. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with torch.export.
Examples are also tagged by category so that they can be more easily searched.
For example, let's use ExportDB to get a better understanding of how the predicate works in the cond operator.
We can look at the example called cond_predicate, which has a torch.cond tag. The example code looks like:
import torch
from torch import cond


def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
      - ``torch.Tensor`` with a single element
      - boolean expression

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
More generally, ExportDB can be used as a reference when one of the following occurs:
- Before attempting torch.export, you know ahead of time that your model uses some tricky Python/PyTorch features and you want to know if torch.export covers that feature.
- When attempting torch.export, there is a failure and it's unclear how to work around it.
ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by torch.export.
Running the Exported Program#
As torch.export is only a graph-capturing mechanism, calling the artifact produced by torch.export eagerly is equivalent to running the eager module. To optimize the execution of the Exported Program, we can pass this exported artifact to backends such as Inductor through torch.compile, AOTInductor, or TensorRT.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
tensor([[ 1.1126, 0.1263, 0.8522],
[ 0.2973, -1.2118, 1.0370]], device='cuda:0',
grad_fn=<AddmmBackward0>)
tensor([[ 1.1126, 0.1263, 0.8522],
[ 0.2973, -1.2118, 1.0370]], device='cuda:0',
grad_fn=<CompiledFunctionBackward>)
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a PT2 archive using ``AOTInductor``
with torch.no_grad():
    pt2_path = torch._inductor.aoti_compile_and_package(ep)

# Load and run the compiled .pt2 package in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
res = aoti_compiled(inp)
Conclusion#
We introduced torch.export, the new PyTorch 2.X way to export single computation graphs from PyTorch programs. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.