Skip to content

Commit 22708f1

Browse files
authored
Update train_search.py
1 parent 4514474 commit 22708f1

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

cnn/train_search.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr):
165165
input_search = Variable(input_search, requires_grad=False).cuda()
166166
target_search = Variable(target_search, requires_grad=False).cuda(async=True)
167167

168+
### 优化部分:DARTS是交替优化的,第一步先优化alpha,第二步再优化w
169+
168170
# 更新架构权重alpha,unrolled为True时就是用论文的公式进行alpha的更新
169171
architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
170172

0 commit comments

Comments
 (0)