55import torchvision .transforms as transforms
66from torch .autograd import Variable
77
8-
8+ # 用于计算平均值
99class AvgrageMeter (object ):
1010
1111 def __init__ (self ):
@@ -21,7 +21,7 @@ def update(self, val, n=1):
2121 self .cnt += n
2222 self .avg = self .sum / self .cnt
2323
24-
24+ # 求top-k精度
2525def accuracy (output , target , topk = (1 ,)):
2626 maxk = max (topk )
2727 batch_size = target .size (0 )
@@ -36,7 +36,7 @@ def accuracy(output, target, topk=(1,)):
3636 res .append (correct_k .mul_ (100.0 / batch_size ))
3737 return res
3838
39-
39+ # 数据增强:Cutout,生成一个边长为length的正方形遮掩(越过边界的话就变成矩形了)
4040class Cutout (object ):
4141 def __init__ (self , length ):
4242 self .length = length
@@ -58,7 +58,7 @@ def __call__(self, img):
5858 img *= mask
5959 return img
6060
61-
61+ # 用于CIFAR的数据增强操作
6262def _data_transforms_cifar10 (args ):
6363 CIFAR_MEAN = [0.49139968 , 0.48215827 , 0.44653124 ]
6464 CIFAR_STD = [0.24703233 , 0.24348505 , 0.26158768 ]
@@ -78,27 +78,27 @@ def _data_transforms_cifar10(args):
7878 ])
7979 return train_transform , valid_transform
8080
81-
81+ # 统计参数量(M)
8282def count_parameters_in_MB (model ):
8383 return np .sum (np .prod (v .size ()) for name , v in model .named_parameters () if "auxiliary" not in name )/ 1e6
8484
85-
85+ # 保存checkpoint,同时如果是最好模型的话也会copy一下
8686def save_checkpoint (state , is_best , save ):
8787 filename = os .path .join (save , 'checkpoint.pth.tar' )
8888 torch .save (state , filename )
8989 if is_best :
9090 best_filename = os .path .join (save , 'model_best.pth.tar' )
9191 shutil .copyfile (filename , best_filename )
9292
93-
93+ # 保存模型
9494def save (model , model_path ):
9595 torch .save (model .state_dict (), model_path )
9696
97-
97+ # 载入模型
9898def load (model , model_path ):
9999 model .load_state_dict (torch .load (model_path ))
100100
101-
101+ # 随机丢弃路径,来自FractalNet
102102def drop_path (x , drop_prob ):
103103 if drop_prob > 0. :
104104 keep_prob = 1. - drop_prob
@@ -107,7 +107,7 @@ def drop_path(x, drop_prob):
107107 x .mul_ (mask )
108108 return x
109109
110-
110+ # 创建文件夹,copy文件的一些操作
111111def create_exp_dir (path , scripts_to_save = None ):
112112 if not os .path .exists (path ):
113113 os .mkdir (path )
0 commit comments