
Problem composing per_samplify with functools.partial #78

@SamDuffield

Description

from functools import partial

import torch
from posteriors import per_samplify

from tests.scenarios import TestModel  # small test model from this repo

model = TestModel()
params = dict(model.named_parameters())
batch_inputs = torch.randn(3, 10)
batch_labels = torch.randint(2, (3,)).unsqueeze(-1)
batch_spec = {"inputs": batch_inputs, "labels": batch_labels}

def log_likelihood(params, batch):
    output = torch.func.functional_call(model, params, batch["inputs"])
    return -torch.nn.BCEWithLogitsLoss()(output, batch["labels"].float())

log_likelihood_per_sample = per_samplify(log_likelihood)

Works fine:

log_likelihood_per_sample(params, batch_spec)

Throws ValueError:

partial(log_likelihood_per_sample, batch=batch_spec)(params)
# ValueError: vmap(f_per_sample, in_dims=(None, 0), ...)(<inputs>): in_dims is
# not compatible with the structure of `inputs`. in_dims has structure
# TreeSpec(tuple, None, [*, *]) but inputs has structure
# TreeSpec(tuple, None, [TreeSpec(dict, ['linear.weight', 'linear.bias'], [*, *])]).
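The error suggests the cause: `per_samplify` wraps the function with `torch.vmap(..., in_dims=(None, 0))`, which expects both `params` and `batch` as positional arguments. `functools.partial(f, batch=...)` moves `batch` into keyword arguments, so vmap sees only one positional input while `in_dims` still has two entries. A minimal sketch of the mismatch (not posteriors' actual implementation; `f`, `params`, and `batch` here are illustrative stand-ins), plus a closure-based workaround:

```python
import torch
from functools import partial

# Stand-in for a per_samplify-style wrapper: vmap with in_dims=(None, 0)
# maps over the batch, which must be the second *positional* argument.
def f(params, batch):
    return (params * batch).sum(-1)

params = torch.randn(5)
batch = torch.randn(3, 5)

f_per_sample = torch.vmap(f, in_dims=(None, 0))

ok = f_per_sample(params, batch)  # shape (3,): one value per sample

# partial(..., batch=batch) binds batch as a keyword argument, so vmap
# receives a single positional input while in_dims=(None, 0) still has
# two entries -> structure mismatch -> ValueError.
try:
    partial(f_per_sample, batch=batch)(params)
    raised = False
except ValueError:
    raised = True

# Workaround: bind batch with a closure so both arguments stay positional.
bound = lambda p: f_per_sample(p, batch)
ok2 = bound(params)
```

The closure keeps the call signature positional, which is why `lambda p: log_likelihood_per_sample(p, batch_spec)` would work where `partial` does not.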

Labels: bug (Something isn't working)