Page 467 - 2020학년도 MDP과제발표회 자료집 (통신과) (3)
P. 467

train_writer  =  tf.summary.FileWriter(tensorboardDIR  +  '/train',  sess.graph)
             test_writer  =  tf.summary.FileWriter(tensorboardDIR  +  '/test')


             ‘’‘
             tf.train.Saver 를  이용하여  학습중인  모델을  저장할  수도  있고 미리  학습된  모델을  가져와서  사용할  수,
             도  있습니다.
             ‘’‘
             #  tf.train.Saver 를  이용해서  모델과  파라미터를  저장합니다.
             SAVER_DIR  =  "model"
             saver  =  tf.train.Saver()
             checkpoint_path  =  os.path.join(SAVER_DIR,  "model")
             ckpt  =  tf.train.get_checkpoint_state(SAVER_DIR)


             #  만약  저장된  모델과  파라미터가  있으면  이를  불러오고  (Restore)
             #  Restored  모델을  이용해서  테스트  데이터에  대한  정확도를  출력하고  프로그램을  종료합니다.
             if  ckpt  and  ckpt.model_checkpoint_path:
                     saver.restore(sess,  ckpt.model_checkpoint_path)
                     print('Accuracy:',  sess.run(accuracy,  feed_dict={X:  test_input,  Y:  test_label,  keep_prob:
             1.0}))
                     sess.close()
                     exit()


             ‘’‘
             training 을  시작합니다 한  스텝당 .       epochs 에서  정했던  만큼  학습을  진행합니다 이후  한  스텝이  끝날 .
             때  마다  saver 에  모델을  저장합니다.
             ‘’‘
             #  Training  시작
             print('Learning  started.')


             for  epoch  in  range(training_epochs):
                     avg_cost  =  0
                     total_batch  =  int(len(train_input)  /  batch_size)


                     for  i  in  range(total_batch):
                             start  =  ((i  +  1)  *  batch_size)  -  batch_size
                             end  =  ((i  +  1)  *  batch_size)
                             batch_xs  =  train_input[start:end]
                             batch_ys  =  train_label[start:end]
                             feed_dict  =  {X:  batch_xs,  Y:  batch_ys,  keep_prob:  0.5}
                             c,  _  =  sess.run([cost,  optimizer],  feed_dict=feed_dict)
                             avg_cost  +=  c  /  total_batch


                     if  epoch  %  1  ==  0:
                             saver.save(sess,  checkpoint_path,  global_step=epoch)



                                                                                      인천전자마이스터고등학교
                                                         -  487  -                       정보통신기기과         487
   462   463   464   465   466   467   468   469   470   471   472