Skip to content

Commit 3fb91ca

Browse files
authored
Update utils.py
1 parent f276dd3 commit 3fb91ca

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

cnn/utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torchvision.transforms as transforms
66
from torch.autograd import Variable
77

8-
8+
# 用于计算平均值
99
class AvgrageMeter(object):
1010

1111
def __init__(self):
@@ -21,7 +21,7 @@ def update(self, val, n=1):
2121
self.cnt += n
2222
self.avg = self.sum / self.cnt
2323

24-
24+
# 求top-k精度
2525
def accuracy(output, target, topk=(1,)):
2626
maxk = max(topk)
2727
batch_size = target.size(0)
@@ -36,7 +36,7 @@ def accuracy(output, target, topk=(1,)):
3636
res.append(correct_k.mul_(100.0/batch_size))
3737
return res
3838

39-
39+
# 数据增强:Cutout,生成一个边长为length的正方形遮掩(越过边界的话就变成矩形了)
4040
class Cutout(object):
4141
def __init__(self, length):
4242
self.length = length
@@ -58,7 +58,7 @@ def __call__(self, img):
5858
img *= mask
5959
return img
6060

61-
61+
# 用于CIFAR的数据增强操作
6262
def _data_transforms_cifar10(args):
6363
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
6464
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
@@ -78,27 +78,27 @@ def _data_transforms_cifar10(args):
7878
])
7979
return train_transform, valid_transform
8080

81-
81+
# 统计参数量(M)
8282
def count_parameters_in_MB(model):
8383
return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
8484

85-
85+
# 保存checkpoint,同时如果是最好模型的话也会copy一下
8686
def save_checkpoint(state, is_best, save):
8787
filename = os.path.join(save, 'checkpoint.pth.tar')
8888
torch.save(state, filename)
8989
if is_best:
9090
best_filename = os.path.join(save, 'model_best.pth.tar')
9191
shutil.copyfile(filename, best_filename)
9292

93-
93+
# 保存模型
9494
def save(model, model_path):
9595
torch.save(model.state_dict(), model_path)
9696

97-
97+
# 载入模型
9898
def load(model, model_path):
9999
model.load_state_dict(torch.load(model_path))
100100

101-
101+
# 随机丢弃路径,来自FractalNet
102102
def drop_path(x, drop_prob):
103103
if drop_prob > 0.:
104104
keep_prob = 1.-drop_prob
@@ -107,7 +107,7 @@ def drop_path(x, drop_prob):
107107
x.mul_(mask)
108108
return x
109109

110-
110+
# 创建文件夹,copy文件的一些操作
111111
def create_exp_dir(path, scripts_to_save=None):
112112
if not os.path.exists(path):
113113
os.mkdir(path)

0 commit comments

Comments
 (0)