Page 467 - 2020학년도 MDP과제발표회 자료집 (통신과) (3)
P. 467
train_writer = tf.summary.FileWriter(tensorboardDIR + '/train', sess.graph)
test_writer = tf.summary.FileWriter(tensorboardDIR + '/test')
‘’‘
tf.train.Saver 를 이용하여 학습중인 모델을 저장할 수도 있고 미리 학습된 모델을 가져와서 사용할 수,
도 있습니다.
‘’‘
# tf.train.Saver 를 이용해서 모델과 파라미터를 저장합니다.
SAVER_DIR = "model"
saver = tf.train.Saver()
checkpoint_path = os.path.join(SAVER_DIR, "model")
ckpt = tf.train.get_checkpoint_state(SAVER_DIR)
# 만약 저장된 모델과 파라미터가 있으면 이를 불러오고 (Restore)
# Restored 모델을 이용해서 테스트 데이터에 대한 정확도를 출력하고 프로그램을 종료합니다.
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
print('Accuracy:', sess.run(accuracy, feed_dict={X: test_input, Y: test_label, keep_prob:
1.0}))
sess.close()
exit()
‘’‘
training 을 시작합니다 한 스텝당 . epochs 에서 정했던 만큼 학습을 진행합니다 이후 한 스텝이 끝날 .
때 마다 saver 에 모델을 저장합니다.
‘’‘
# Training 시작
print('Learning started.')
for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(len(train_input) / batch_size)
for i in range(total_batch):
start = ((i + 1) * batch_size) - batch_size
end = ((i + 1) * batch_size)
batch_xs = train_input[start:end]
batch_ys = train_label[start:end]
feed_dict = {X: batch_xs, Y: batch_ys, keep_prob: 0.5}
c, _ = sess.run([cost, optimizer], feed_dict=feed_dict)
avg_cost += c / total_batch
if epoch % 1 == 0:
saver.save(sess, checkpoint_path, global_step=epoch)
인천전자마이스터고등학교
- 487 - 정보통신기기과 487