torch.dynamic-value
constrain_as_size_example
Original source code:
# mypy: allow-untyped-defs
import torch
class ConstrainAsSizeExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide hint so that we
    can trace further. Please look at torch._check and torch._check_is_size APIs.
    torch._check_is_size is used for values that NEED to be used for constructing
    tensor.
    """
    def forward(self, x):
        a = x.item()
        torch._check_is_size(a)
        torch._check(a <= 5)
        return torch.zeros((a, 5))
# Sample positional inputs for the export call below: one 0-dim integer tensor.
example_args = (torch.tensor(4),)
# Labels attached to this example.
tags = {"torch.dynamic-value", "torch.escape-hatch"}
model = ConstrainAsSizeExample()
# Export the module using the sample inputs.
torch.export.export(model, example_args)
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]"):
                 item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
             #
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None
                 ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 5)" = item <= 5
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None
                 zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False);  item = None
            return (zeros,)
Graph signature:
    # inputs
    x: USER_INPUT
    # outputs
    zeros: USER_OUTPUT
Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}
constrain_as_value_example
Original source code:
# mypy: allow-untyped-defs
import torch
class ConstrainAsValueExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide hint so that we
    can trace further. Please look at torch._check and torch._check_is_size APIs.
    torch._check is used for values that don't need to be used for constructing
    tensor.
    """
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a <= 5)
        if a < 6:
            return y.sin()
        return y.cos()
# Sample positional inputs for the export call below: a 0-dim integer tensor
# and a random 5x5 tensor.
example_args = (torch.tensor(4), torch.randn(5, 5))
# Labels attached to this example.
tags = {"torch.dynamic-value", "torch.escape-hatch"}
model = ConstrainAsValueExample()
# Export the module using the sample inputs.
torch.export.export(model, example_args)
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[5, 5]"):
                 item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 5)" = item <= 5;  item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None
                 sin: "f32[5, 5]" = torch.ops.aten.sin.default(y);  y = None
            return (sin,)
Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    # outputs
    sin: USER_OUTPUT
Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}