Skip to content

When I execute the function pwdist_exact, I get an error #30

@xlcbingo1999

Description

@xlcbingo1999

When I try to compute the distance between two subsets that were randomly sampled from the EMNIST dataset, I use the 'exact' method and follow the format of example.py, but execution always ends up in the `except` branch inside the function pwdist_exact.

from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance
from torch.utils.data import DataLoader, SubsetRandomSampler, Dataset
from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Grayscale
import json

class CustomDataset(Dataset):
    """Subset view of ``dataset`` limited to the samples listed in ``indices``.

    Args:
        dataset: Underlying dataset. It must support integer indexing that
            yields an ``(x, y)`` pair and expose ``targets`` and ``classes``
            attributes. ``targets`` must support indexing with a list of
            ints and provide ``unique(return_counts=True)`` (e.g. a
            ``torch.Tensor``).
        indices: Iterable of positions into ``dataset``; each is cast to int.

    Attributes:
        targets: Labels of the selected samples only, in the order given by
            ``indices``, so that ``len(self.targets) == len(self)``.
        classes: The ``classes`` attribute of the underlying dataset,
            passed through unchanged.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = [int(i) for i in indices]
        # Keep ``targets`` aligned with this subset. Exposing the targets of
        # the *whole* dataset would make ``len(self.targets)`` differ from
        # ``len(self)``, so code that reads ``.targets`` directly would see
        # labels of samples that are not part of the subset.
        # NOTE(review): this is presumably what misleads OTDD when a subset
        # lacks some classes -- confirm against OTDD's target extraction.
        self.targets = dataset.targets[self.indices]
        self.classes = dataset.classes  # keep the classes attribute

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.indices)

    def __getitem__(self, item):
        """Return the ``(x, y)`` pair of the ``item``-th selected sample."""
        x, y = self.dataset[self.indices[item]]
        return x, y

    def get_class_distribution(self):
        """Return ``(labels, counts)`` for the labels present in the subset."""
        # ``self.targets`` already contains only the subset's labels.
        return self.targets.unique(return_counts=True)

# Location of the raw EMNIST data and of the JSON files that list, per
# sub-dataset key, the sample indices to use.
raw_data_path = '/mnt/linuxidc_client/dataset/Amazon_Review_split/EMNIST'
sub_train_config_path = '/mnt/linuxidc_client/dataset/Amazon_Review_split/sub_train_datasets_config.json'
sub_test_config_path = '/mnt/linuxidc_client/dataset/Amazon_Review_split/test_dataset_config.json'
train_id = 0
test_id = 0
dataset_name = "EMNIST"
sub_train_key = 'train_sub_{}'.format(train_id)
sub_test_key = 'test_sub_{}'.format(test_id)
BATCH_SIZE = 2048

# The configs are only read, so open them read-only ('r+' would also demand
# write permission). The ``with`` block closes the file on exit, so no
# explicit ``close()`` is needed.
with open(sub_train_config_path, 'r') as f:
    current_subtrain_config = json.load(f)
with open(sub_test_config_path, 'r') as f:
    current_subtest_config = json.load(f)

# Sorted sample indices of the selected train/test sub-datasets; the prints
# show the largest index and the number of indices as a sanity check.
real_train_index = sorted(current_subtrain_config[dataset_name][sub_train_key]["indexes"])
print("check last real_train_index: ", real_train_index[-1])
print(len(real_train_index))
real_test_index = sorted(current_subtest_config[dataset_name][sub_test_key]["indexes"])
print("check last real_test_index: ", real_test_index[-1])
print(len(real_test_index))

# Preprocessing applied to every EMNIST image: grayscale with 3 output
# channels, conversion to a tensor, then normalisation with fixed constants.
transform = Compose([Grayscale(3), ToTensor(), Normalize((0.1307,), (0.3081,))])

# Arguments shared by the train and test splits of EMNIST "bymerge".
emnist_kwargs = dict(
    root=raw_data_path,
    split="bymerge",
    download=False,
    transform=transform,
)
train_dataset = EMNIST(train=True, **emnist_kwargs)
test_dataset = EMNIST(train=False, **emnist_kwargs)
print("begin train: {} test: {}".format(train_id, test_id))
print("check all size: train[{}] and test[{}]".format(len(train_dataset), len(test_dataset)))

# Restrict each split to the configured indices, then batch it.
train_dataset = CustomDataset(train_dataset, real_train_index)
test_dataset = CustomDataset(test_dataset, real_test_index)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)
print("Finished split datasets!")

# Number of batches times the batch size; this can exceed the real sample
# count when the last batch is only partially filled.
print("check train_loader: {}".format(len(train_loader) * BATCH_SIZE))
print("check test_loader: {}".format(len(test_loader) * BATCH_SIZE))


# Options for the dataset distance between the train and test loaders.
otdd_options = dict(
    inner_ot_method='exact',
    debiased_loss=True,
    p=2,
    entreg=1e-1,
    device='cuda:3',
)
dist = DatasetDistance(train_loader, test_loader, **otdd_options)

# Compute the distance using at most 1000 samples and report it.
d = dist.distance(maxsamples=1000)
print(f'OTDD-EMNIST(train,test)={d:8.2f}')

It is worth noting that the label distribution of the EMNIST subset is not the same as that of the entire EMNIST dataset. In the subset, some labels have no instances at all (their count is 0).

The function pwdist_exact seems to return the correct result when I take evenly spaced samples instead. Here is the code:

# Variant: instead of the indices read from the JSON configs (kept commented
# out below), take evenly spaced indices -- every 10th index below 697932 for
# train and every 3rd index below 116323 for test.
# NOTE(review): 697932 / 116323 are presumably the full train / test sizes -- confirm.
# real_train_index = sorted(list(current_subtrain_config[dataset_name][sub_train_key]["indexes"]))
real_train_index = list(range(0, 697932, 10))
# real_test_index = sorted(list(current_subtest_config[dataset_name][sub_test_key]["indexes"])) 
real_test_index = list(range(0, 116323, 3))

You can download my sub_train_datasets_config.json and test_dataset_config.json from Google Drive. Link: https://drive.google.com/drive/folders/1r_vyLJ-RmuuNZqneBP3meexrEZvgc_Ce?usp=sharing

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions