Status: Closed
Labels: bug

Description
```python
from functools import partial

import torch
from posteriors import per_samplify  # assumed import path for per_samplify
from tests.scenarios import TestModel

model = TestModel()
params = dict(model.named_parameters())
batch_inputs = torch.randn(3, 10)
batch_labels = torch.randint(2, (3,)).unsqueeze(-1)
batch_spec = {"inputs": batch_inputs, "labels": batch_labels}


def log_likelihood(params, batch):
    output = torch.func.functional_call(model, params, batch["inputs"])
    return -torch.nn.BCEWithLogitsLoss()(output, batch["labels"].float())


log_likelihood_per_sample = per_samplify(log_likelihood)
```

Calling it with both arguments passed positionally works:

```python
log_likelihood_per_sample(params, batch_spec)
```

Binding `batch` by keyword with `functools.partial` raises a `ValueError`:

```python
partial(log_likelihood_per_sample, batch=batch_spec)(params)
# ValueError: vmap(f_per_sample, in_dims=(None, 0), ...)(<inputs>): in_dims is not compatible
# with the structure of `inputs`. in_dims has structure TreeSpec(tuple, None, [*, *]) but
# inputs has structure TreeSpec(tuple, None, [TreeSpec(dict, ['linear.weight', 'linear.bias'], [*, *])]).
```
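The error message suggests that the per-sample function is a `torch.vmap` wrapper with `in_dims=(None, 0)`, and `vmap` matches `in_dims` only against positional arguments. When `batch` is bound by keyword, `vmap` sees a single positional input `(params,)` that cannot line up with the two-entry `in_dims`. The minimal sketch below reproduces this with plain `torch.vmap` and no model. It assumes a toy function `f` standing in for `log_likelihood`. Both `f` and `vmap_accepting_kwargs` are hypothetical names, not part of the library. The sketch shows two possible workarounds: closing over the batch positionally, or moving keyword arguments into their positional slots with `inspect.signature(...).bind` before calling the vmapped function.

```python
import inspect
from functools import partial

import torch


def f(params, batch):
    # Hypothetical stand-in for log_likelihood: params is shared, batch is mapped over.
    return params * batch


# Same in_dims as in the error message: params unbatched, batch mapped over dim 0.
f_vmapped = torch.vmap(f, in_dims=(None, 0))


def vmap_accepting_kwargs(func, in_dims):
    """Hypothetical helper: vmap `func`, but first move keyword arguments
    into their positional slots so that in_dims lines up with them."""
    sig = inspect.signature(func)
    vmapped = torch.vmap(func, in_dims=in_dims)

    def wrapper(*args, **kwargs):
        # e.g. args=(p,), kwargs={"batch": b}  ->  bound.args == (p, b)
        bound = sig.bind(*args, **kwargs)
        return vmapped(*bound.args)

    return wrapper


params = torch.tensor(2.0)
batch = torch.arange(3.0)

# Positional call: in_dims=(None, 0) matches (params, batch).
print(f_vmapped(params, batch))  # tensor([0., 2., 4.])

# Keyword call via partial: vmap only sees (params,), so in_dims does not match.
try:
    partial(f_vmapped, batch=batch)(params)
except ValueError as err:
    print("ValueError:", err)

# Workaround 1: close over the batch and pass it positionally.
print((lambda p: f_vmapped(p, batch))(params))

# Workaround 2: bind keywords to positions before calling the vmapped function.
f_kw = vmap_accepting_kwargs(f, in_dims=(None, 0))
print(partial(f_kw, batch=batch)(params))
```

Binding through the wrapped function's signature keeps `in_dims` aligned with the parameter order no matter how the caller passes the arguments. The trade-off is one `inspect.signature` lookup when the wrapper is created.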