diff options
| author | Alexander M Pickering <amp215@pitt.edu> | 2025-02-01 02:24:13 -0600 |
|---|---|---|
| committer | Alexander M Pickering <amp215@pitt.edu> | 2025-02-01 02:24:13 -0600 |
| commit | 61bdb4fef88c1e83787dbb023b51d8d200844e3a (patch) | |
| tree | 6d905b6f61a0e932b1ace9771c714a80e0388af0 /task.py | |
| download | mscbio2046-61bdb4fef88c1e83787dbb023b51d8d200844e3a.tar.gz mscbio2046-61bdb4fef88c1e83787dbb023b51d8d200844e3a.tar.bz2 mscbio2046-61bdb4fef88c1e83787dbb023b51d8d200844e3a.zip | |
Diffstat (limited to 'task.py')
| -rw-r--r-- | task.py | 236 |
1 files changed, 236 insertions, 0 deletions
@@ -0,0 +1,236 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Trains and Evaluates the MNIST network using a feed dictionary.""" +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import tempfile +import time + +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.examples.tutorials.mnist import mnist + + +# Basic model parameters as external flags. +flags = tf.app.flags +FLAGS = flags.FLAGS +flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') +flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.') + +flags.DEFINE_integer('hidden1', 1024, 'Number of units in hidden layer 1.') +flags.DEFINE_integer('hidden2', 1024, 'Number of units in hidden layer 2.') +flags.DEFINE_integer('hidden3', 1024, 'Number of units in hidden layer 3.') +flags.DEFINE_integer('hidden4', 1024, 'Number of units in hidden layer 4.') +flags.DEFINE_integer('hidden5', 1024, 'Number of units in hidden layer 5.') +flags.DEFINE_integer('hidden6', 1024, 'Number of units in hidden layer 6.') +flags.DEFINE_integer('hidden7', 1024, 'Number of units in hidden layer 7.') +flags.DEFINE_integer('hidden8', 1024, 'Number of units in hidden layer 8.') +flags.DEFINE_integer('hidden9', 1024, 'Number of units in hidden layer 9.') +flags.DEFINE_integer('hidden10', 1024, 'Number of units in hidden layer 10.') + +flags.DEFINE_integer('batch_size', 100, 'Batch size. ' + 'Must divide evenly into the dataset sizes.') +flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.') +flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data ' + 'for unit testing.') + + +def placeholder_inputs(batch_size): + """Generate placeholder variables to represent the input tensors. + These placeholders are used as inputs by the rest of the model building + code and will be fed from the downloaded data in the .run() loop, below. + Args: + batch_size: The batch size will be baked into both placeholders. + Returns: + images_placeholder: Images placeholder. + labels_placeholder: Labels placeholder. + """ + # Note that the shapes of the placeholders match the shapes of the full + # image and label tensors, except the first dimension is now batch_size + # rather than the full size of the train or test data sets. + events_placeholder = tf.placeholder(tf.sring, shape=(batch_size, + 21)) + stock_placeholder = tf.placeholder(tf.int32, shape=(batch_size)) + return events_placeholder, stock_placeholder + + +def fill_feed_dict(data_set, images_pl, labels_pl): + """Fills the feed_dict for training the given step. + A feed_dict takes the form of: + feed_dict = { + <placeholder>: <tensor of values to be passed for placeholder>, + .... + } + Args: + data_set: The set of images and labels, from input_data.read_data_sets() + images_pl: The images placeholder, from placeholder_inputs(). + labels_pl: The labels placeholder, from placeholder_inputs(). + Returns: + feed_dict: The feed dictionary mapping from placeholders to values. + """ + # Create the feed_dict for the placeholders filled with the next + # `batch size` examples. + events_feed, stock_feed = data_set.next_batch(FLAGS.batch_size, + FLAGS.fake_data) + feed_dict = { + images_pl: events_feed, + labels_pl: stock_feed, + } + return feed_dict + + +def do_eval(sess, + eval_correct, + events_placeholder, + stock_placeholder, + data_set): + """Runs one evaluation against the full epoch of data. + Args: + sess: The session in which the model has been trained. + eval_correct: The Tensor that returns the number of correct predictions. + images_placeholder: The images placeholder. + labels_placeholder: The labels placeholder. + data_set: The set of images and labels to evaluate, from + input_data.read_data_sets(). + """ + # And run one epoch of eval. + true_count = 0 # Counts the number of correct predictions. + steps_per_epoch = data_set.num_examples // FLAGS.batch_size + num_examples = steps_per_epoch * FLAGS.batch_size + for step in xrange(steps_per_epoch): + feed_dict = fill_feed_dict(data_set, + events_placeholder, + stock_placeholder) + true_count += sess.run(eval_correct, feed_dict=feed_dict) + precision = true_count / num_examples + print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' % + (num_examples, true_count, precision)) + + +def run_training(): + """Train MNIST for a number of steps.""" + # Get the sets of images and labels for training, validation, and + # test on MNIST. + data_sets = input_data.read_data_sets(tempfile.mkdtemp(), FLAGS.fake_data) + + # Tell TensorFlow that the model will be built into the default Graph. + with tf.Graph().as_default(): + # Generate placeholders for the images and labels. + images_placeholder, labels_placeholder = placeholder_inputs( + FLAGS.batch_size) + + # Build a Graph that computes predictions from the inference model. + logits = mnist.inference(images_placeholder, + FLAGS.hidden1, + FLAGS.hidden2) + + # Add to the Graph the Ops for loss calculation. + loss = mnist.loss(logits, labels_placeholder) + + # Add to the Graph the Ops that calculate and apply gradients. + train_op = mnist.training(loss, FLAGS.learning_rate) + + # Add the Op to compare the logits to the labels during evaluation. + eval_correct = mnist.evaluation(logits, labels_placeholder) + + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.merge_all_summaries() + + # Add the variable initializer Op. + init = tf.initialize_all_variables() + + # Create a saver for writing training checkpoints. + saver = tf.train.Saver() + + # Create a session for running Ops on the Graph. + sess = tf.Session() + + # Instantiate a SummaryWriter to output summaries and the Graph. + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) + + # And then after everything is built: + + # Run the Op to initialize the variables. + sess.run(init) + + # Start the training loop. + for step in xrange(FLAGS.max_steps): + start_time = time.time() + + # Fill a feed dictionary with the actual set of images and labels + # for this particular training step. + feed_dict = fill_feed_dict(data_sets.train, + images_placeholder, + labels_placeholder) + + # Run one step of the model. The return values are the activations + # from the `train_op` (which is discarded) and the `loss` Op. To + # inspect the values of your Ops or variables, you may include them + # in the list passed to sess.run() and the value tensors will be + # returned in the tuple from the call. + _, loss_value = sess.run([train_op, loss], + feed_dict=feed_dict) + + duration = time.time() - start_time + + # Write the summaries and print an overview fairly often. + if step % 100 == 0: + # Print status to stdout. + print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) + # Update the events file. + summary_str = sess.run(summary_op, feed_dict=feed_dict) + summary_writer.add_summary(summary_str, step) + summary_writer.flush() + + # Save a checkpoint and evaluate the model periodically. + if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps: + checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint') + saver.save(sess, checkpoint_file, global_step=step) + # Evaluate against the training set. + print('Training Data Eval:') + do_eval(sess, + eval_correct, + images_placeholder, + labels_placeholder, + data_sets.train) + # Evaluate against the validation set. + print('Validation Data Eval:') + do_eval(sess, + eval_correct, + images_placeholder, + labels_placeholder, + data_sets.validation) + # Evaluate against the test set. + print('Test Data Eval:') + do_eval(sess, + eval_correct, + images_placeholder, + labels_placeholder, + data_sets.test) + + +def main(_): + run_training() + + +if __name__ == '__main__': + tf.app.run() |
