| author | Alexander M Pickering <amp215@pitt.edu> | 2025-02-01 02:24:13 -0600 |
|---|---|---|
| committer | Alexander M Pickering <amp215@pitt.edu> | 2025-02-01 02:24:13 -0600 |
| commit | 61bdb4fef88c1e83787dbb023b51d8d200844e3a (patch) | |
| tree | 6d905b6f61a0e932b1ace9771c714a80e0388af0 | |
| mode | file | lines added |
|---|---|---|
| -rw-r--r-- | a1.scala | 99 |
| -rw-r--r-- | a2.scala | 55 |
| -rw-r--r-- | mod.py | 62 |
| -rw-r--r-- | task.py | 236 |
| -rw-r--r-- | task_o.py | 258 |
5 files changed, 710 insertions, 0 deletions
diff --git a/a1.scala b/a1.scala
new file mode 100644
index 0000000..3817c64
--- /dev/null
+++ b/a1.scala
@@ -0,0 +1,99 @@
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.SparkConf
+import org.apache.spark.rdd._
+import org.bdgenomics.adam.rdd.ADAMContext._
+import collection.JavaConverters._
+import scala.io.Source
+import scala.util.Random
+
+object Assign1 {
+  def main(args: Array[String])
+  {
+    val panelfile = args(0)
+    val adamfile = args(1)
+    val conf = new SparkConf().setAppName("Assign1")
+    val sc = new SparkContext(conf)
+
+    //... means "put something here"
+    val biggroups = Source.fromFile(panelfile).getLines().drop(1).map((str)=>str.drop(8).take(3)).toList.groupBy(identity).mapValues(_.size).filter(_._2>90)
+    val individualpeople = Source.fromFile(panelfile).getLines().drop(1).map((str)=>(str.take(7),str.drop(8).take(3))).toMap.filter(p=>biggroups.isDefinedAt(p._2)).keySet
+
+    println("Populations with more than 90 individuals: " + biggroups.size)
+    println("Individuals from these populations: " + individualpeople.size)
+
+    val data = sc.loadGenotypes(adamfile)
+
+    def convertAlleles
+      (x: java.util.List[org.bdgenomics.formats.avro.GenotypeAllele]) = {
+      x.asScala.map(_.toString)
+    }
+
+    def distance(a: Iterable[Double], b: Iterable[Double]) = {
+      Math.sqrt(a.zip(b).map(r=>(r._1-r._2)*(r._1-r._2)).fold(0.0)(_+_))
+    }
+
+    def sumof(a: Iterable[Double], b: Iterable[Double]) = {
+      a.zip(b).map(r=>(r._1+r._2)).fold(0.0)(_+_)
+    }
+
+    def addlists(a: Iterable[Double], b: Iterable[Double]) = {
+      a.zip(b).map(r=>r._1+r._2)}
+
+    def centeroid(a: Iterable[Iterable[Double]]): Iterable[Double] = {
+      val numelements = a.size
+      a.tail.fold(a.head)((c,d)=>addlists(c,d)).map(r=>r/numelements)
+    }
+
+    val cdata = data.rdd.map(r=>(r.contigName,r.start,r.end,r.sampleId,convertAlleles(r.alleles)))
+
+    val varients = data.rdd.map(r=>(r.contigName,r.start,r.end))
+
+    val ids = data.rdd.map(r=>(r.sampleId,convertAlleles(r.alleles)))
+
+    val copies = varients.zip(ids).groupBy(_._1).map(r=>(r._1->r._2.size))
+
+    val tpeople = data.rdd.map(r=>(r.sampleId)).distinct()
+
+    val npeople = tpeople.count
+
+    val gvarients = copies.filter(_._2 == npeople)
+
+    val indbypeople = cdata.map(r=>((r._4)->(r._5->(r._1,r._2,r._3))))
+
+    val dcc = gvarients.count()
+
+    println("Total variants: " + varients.distinct().count())
+    println("Variants with right number of samples: " + dcc)
+
+    val idsg = ids.groupBy(_._1)
+
+    val people = idsg.map(r=>(r._1->r._2.map(c=>c._2.count(_=="Alt").toDouble)))
+
+    var partitions = people.takeSample(false,21,0).map(r=>r._2)
+
+    var i = 0
+    var distancetoclusters = people.map(r=>r->partitions.map(t=>t->distance(r._2,t)))
+    var closestclusterto = distancetoclusters.map(r=>r._1->r._2.minBy(_._2))
+    var npartitions = closestclusterto.map(r=>r._1->r._2._1).groupBy(_._2).map(r=>r._1->r._2.map(c=>c._1._2)).map(r=>centeroid(r._2)).collect()
+
+    while(i < 10){
+      var ndistancetoclusters = people.map(r=>r->npartitions.map(t=>t->distance(r._2,t)))
+      closestclusterto = ndistancetoclusters.map(r=>r._1->r._2.reduce((a,b)=>if(a._2 < b._2) a else b))
+      npartitions = closestclusterto.map(r=>r._1->r._2._1).groupBy(_._2).map(r=>r._1->r._2.map(c=>c._1._2)).map(r=>centeroid(r._2)).collect()
+      i = i + 1
+    }
+
+    //One last clustering to put things in their final place
+    val finaldistancetoclusters = people.map(r=>r->npartitions.map(t=>t->distance(r._2,t)))
+    val finalclosestclusterto = finaldistancetoclusters.map(r=>r._1->r._2.reduce((l,r)=>if(l._2 < r._2) l else r))
+    val finalclusters = finalclosestclusterto.map(r=>r._2._1->r._1._1).groupBy(_._1).map(r=>r._1->r._2.map(t=>t._2))
+
+    println("Clusters:")
+    finalclusters.foreach(r=>println(r._2.fold("")(_+" "+_)))
+
+    println("Number of final clusters: " + finalclusters.count())
+
+    System.exit(0) //you shouldn't need this, but sometimes spark hangs without it
+  }
+}
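a1.scala implements a Lloyd-style k-means over per-sample alternate-allele counts: sample 21 starting centroids, repeatedly assign each sample to its nearest centroid by Euclidean distance, then recompute each centroid as the mean of its members. For reference, a minimal NumPy sketch of the same update loop (function names and the toy data are illustrative, not part of the assignment code):

```python
import numpy as np

def kmeans(points, k=21, iters=10, seed=0):
    """Lloyd-style k-means matching the assign/recompute loop in a1.scala."""
    rng = np.random.RandomState(seed)
    # Initialise centroids by sampling k rows without replacement,
    # like people.takeSample(false, 21, 0).
    centroids = points[rng.choice(len(points), size=k, replace=False)]
    for _ in range(iters):
        # Distance from every point to every centroid, then nearest centroid.
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        assign = dists.argmin(axis=1)
        # Each centroid becomes the mean of its members; an empty cluster
        # keeps its old centroid.
        for j in range(k):
            members = points[assign == j]
            if len(members):
                centroids[j] = members.mean(axis=0)
    return assign, centroids

# Toy usage: rows play the role of per-sample alternate-allele count vectors.
X = np.random.RandomState(1).rand(100, 50)
labels, cents = kmeans(X, k=5, iters=10)
```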
diff --git a/a2.scala b/a2.scala
new file mode 100644
index 0000000..8d724a6
--- /dev/null
+++ b/a2.scala
@@ -0,0 +1,55 @@
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.SparkConf
+import scala.io.Source
+import org.apache.spark.rdd._
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.distributed._
+import java.io._
+
+object Assign2 {
+
+  def main(args: Array[String])
+  {
+    val conf = new SparkConf().setAppName("proj2")
+    val sc = new SparkContext(conf)
+    val datafile = args(0)
+    val missingfile = args(1)
+    val outfile = args(2)
+
+    val ofile = new File(outfile)
+    val output = new BufferedWriter(new FileWriter(ofile))
+
+    val file = sc.textFile(datafile).cache()
+    val data = file.map(x=>(x.split(","))).map(x=>MatrixEntry(x(0).toLong,x(1).toLong,x(2).toDouble))
+
+    val missingfiletext = sc.textFile(missingfile).cache()
+    val missingdata = missingfiletext.map(x=>x.split(",")).map(x=>MatrixEntry(x(0).toLong,x(1).toLong,0))
+
+    val cm = new CoordinateMatrix(data)
+    val rowmatrix = cm.toRowMatrix
+    val numrows = rowmatrix.numRows
+    val numcols = rowmatrix.numCols
+    val indexedMatrix = rowmatrix.rows.zipWithIndex.map(_.swap)
+
+    val svd = rowmatrix.computeSVD(10,true)
+    val features = svd.s.size
+
+    val s = org.apache.spark.mllib.linalg.Matrices.diag(svd.s)
+    val A = svd.U.multiply(s).multiply(svd.V.transpose)
+    val idA = A.rows.zipWithIndex.map(_.swap)
+    val idA2 = sc.broadcast(idA.collect())
+
+    val odata = missingdata.map(x=>(x.i,x.j,idA2.value.apply(x.i.toInt)._2.apply(x.j.toInt)))
+    //val output = new BufferedWriter(new FileWriter(new File(outfile)))
+    odata.collect().foreach(x=>output.write(x._1+","+x._2+","+x._3+"\n"))
+    output.flush()
+    //distributed matrix factorization
+    //The cluster we run on uses 26 quad-core machines, so split the svd up into 26 pieces.
+
+    //output.write(x._1+","+x._2+","+x._3+"\n") //need to write out values to missing coordinates
+
+    output.close()
+    System.exit(0)
+  }
+}
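a2.scala fills in missing matrix entries by computing a rank-10 SVD with Spark MLlib and reading the requested coordinates back out of the reconstruction U * diag(s) * V^T. A minimal NumPy sketch of that rank-k completion idea, assuming the matrix fits in memory (the Spark version exists precisely because the real one may not) and ignoring how missing entries are encoded in the input file:

```python
import numpy as np

def complete_with_svd(A, missing, rank=10):
    """Rebuild A from its top-`rank` singular triplets and read off
    the requested (i, j) entries, as a2.scala does with MLlib."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    A_hat = U[:, :rank] @ np.diag(s[:rank]) @ Vt[:rank, :]
    return [(i, j, A_hat[i, j]) for i, j in missing]

# Toy usage: a low-rank matrix and a few coordinates to predict.
rng = np.random.RandomState(0)
A = rng.rand(20, 5) @ rng.rand(5, 30)
print(complete_with_svd(A, [(0, 0), (3, 7)], rank=5))
```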
diff --git a/mod.py b/mod.py
new file mode 100644
--- /dev/null
+++ b/mod.py
@@ -0,0 +1,62 @@
+import csv
+import sys
+from datetime import date, timedelta
+from google.cloud import bigquery as bq
+
+class Fetcher:
+    '''Provides batches of images'''
+    #TODO TODO - you probably want to modify this to implement data augmentation
+    def __init__(self, stockfile):
+        self.startyear = 1974
+        self.nextyear = 1975
+        self.current = date(self.startyear, 12, 10)
+        self.curend = date(self.nextyear, 12, 10)
+        self.cache = {}
+        self.stocks = {}
+        self.qclient = bq.Client()
+        #Load stock data, it's small enough to keep it all in memory
+        with open(stockfile) as csvfile:
+            dialect = csv.Sniffer().sniff(csvfile.read(1024))
+            csvfile.seek(0)
+            reader = csv.reader(csvfile, dialect)
+            first = True
+            for row in reader:
+                if first:
+                    first = False
+                    continue
+                tdate = row[0]
+                tdate = int(tdate.replace("-", ""))
+                diff = float(row[4]) - float(row[1])
+                self.stocks[tdate] = diff
+        print("Loaded " + stockfile + ".")
+
+
+    def load_next(self):
+        #Load current event data 1 year at a time
+        print("I want to get stocks[" + str(self.current) + "]")
+        start_date = date(1974, 12, 10)
+        for n in range(364):
+            delt = start_date + timedelta(n)
+            rep = str(delt).replace("-", "")
+
+        #Implement a cache for mysql
+        events = []
+        stockchange = 0
+        sys.exit(0)
+        x_batch = []
+        y_batch = []
+        for i in xrange(batchsize):
+            label, files = self.examples[(self.current + i) % len(self.examples)]
+            label = label.flatten()
+            # If you are getting an error reading the image, you probably have
+            # the legacy PIL library installed instead of Pillow
+            # You need Pillow
+            channels = [misc.imread(file_io.FileIO(f, 'r')) for f in files]
+            x_batch.append(np.dstack(channels))
+            y_batch.append(label)
+
+        self.current = (self.current + batchsize) % len(self.examples)
+        return np.array(x_batch), np.array(y_batch)
+
+f = Fetcher("DOW.csv")
+f.load_next()
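mod.py's constructor boils the stock CSV down to a dict keyed by integer yyyymmdd dates whose values are close minus open, and load_next walks a year of dates with timedelta. A standalone sketch of those two pieces, assuming (as row[4] - row[1] suggests) that column 1 is the open and column 4 the close; the "DOW.csv" name comes from the Fetcher("DOW.csv") call at the bottom of mod.py:

```python
import csv
from datetime import date, timedelta

def load_stock_diffs(path):
    """Map an integer yyyymmdd key to close-minus-open for each row of a
    DOW-style CSV (date, open, high, low, close, ...), skipping the header."""
    stocks = {}
    with open(path) as csvfile:
        reader = csv.reader(csvfile)
        next(reader)                     # skip the header row
        for row in reader:
            key = int(row[0].replace("-", ""))
            stocks[key] = float(row[4]) - float(row[1])
    return stocks

def date_window(start, days):
    """Integer yyyymmdd keys for `days` consecutive dates starting at `start`."""
    return [int(str(start + timedelta(n)).replace("-", "")) for n in range(days)]

# stocks = load_stock_diffs("DOW.csv")
# keys = date_window(date(1974, 12, 10), 364)
```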
diff --git a/task.py b/task.py
new file mode 100644
--- /dev/null
+++ b/task.py
@@ -0,0 +1,236 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Trains and Evaluates the MNIST network using a feed dictionary."""
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import tempfile
+import time
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
+
+
+# Basic model parameters as external flags.
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
+flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
+
+flags.DEFINE_integer('hidden1', 1024, 'Number of units in hidden layer 1.')
+flags.DEFINE_integer('hidden2', 1024, 'Number of units in hidden layer 2.')
+flags.DEFINE_integer('hidden3', 1024, 'Number of units in hidden layer 3.')
+flags.DEFINE_integer('hidden4', 1024, 'Number of units in hidden layer 4.')
+flags.DEFINE_integer('hidden5', 1024, 'Number of units in hidden layer 5.')
+flags.DEFINE_integer('hidden6', 1024, 'Number of units in hidden layer 6.')
+flags.DEFINE_integer('hidden7', 1024, 'Number of units in hidden layer 7.')
+flags.DEFINE_integer('hidden8', 1024, 'Number of units in hidden layer 8.')
+flags.DEFINE_integer('hidden9', 1024, 'Number of units in hidden layer 9.')
+flags.DEFINE_integer('hidden10', 1024, 'Number of units in hidden layer 10.')
+
+flags.DEFINE_integer('batch_size', 100, 'Batch size. '
+                     'Must divide evenly into the dataset sizes.')
+flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
+flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
+                     'for unit testing.')
+
+
+def placeholder_inputs(batch_size):
+  """Generate placeholder variables to represent the input tensors.
+  These placeholders are used as inputs by the rest of the model building
+  code and will be fed from the downloaded data in the .run() loop, below.
+  Args:
+    batch_size: The batch size will be baked into both placeholders.
+  Returns:
+    events_placeholder: Events placeholder.
+    stock_placeholder: Stock-change placeholder.
+  """
+  # Note that the shapes of the placeholders match the shapes of the full
+  # image and label tensors, except the first dimension is now batch_size
+  # rather than the full size of the train or test data sets.
+  events_placeholder = tf.placeholder(tf.string, shape=(batch_size,
+                                                        21))
+  stock_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
+  return events_placeholder, stock_placeholder
+
+
+def fill_feed_dict(data_set, images_pl, labels_pl):
+  """Fills the feed_dict for training the given step.
+  A feed_dict takes the form of:
+  feed_dict = {
+      <placeholder>: <tensor of values to be passed for placeholder>,
+      ....
+  }
+  Args:
+    data_set: The set of images and labels, from input_data.read_data_sets()
+    images_pl: The images placeholder, from placeholder_inputs().
+    labels_pl: The labels placeholder, from placeholder_inputs().
+  Returns:
+    feed_dict: The feed dictionary mapping from placeholders to values.
+  """
+  # Create the feed_dict for the placeholders filled with the next
+  # `batch size` examples.
+  events_feed, stock_feed = data_set.next_batch(FLAGS.batch_size,
+                                                FLAGS.fake_data)
+  feed_dict = {
+      images_pl: events_feed,
+      labels_pl: stock_feed,
+  }
+  return feed_dict
+
+
+def do_eval(sess,
+            eval_correct,
+            events_placeholder,
+            stock_placeholder,
+            data_set):
+  """Runs one evaluation against the full epoch of data.
+  Args:
+    sess: The session in which the model has been trained.
+    eval_correct: The Tensor that returns the number of correct predictions.
+    events_placeholder: The events placeholder.
+    stock_placeholder: The stock-change placeholder.
+    data_set: The set of images and labels to evaluate, from
+      input_data.read_data_sets().
+  """
+  # And run one epoch of eval.
+  true_count = 0  # Counts the number of correct predictions.
+  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
+  num_examples = steps_per_epoch * FLAGS.batch_size
+  for step in xrange(steps_per_epoch):
+    feed_dict = fill_feed_dict(data_set,
+                               events_placeholder,
+                               stock_placeholder)
+    true_count += sess.run(eval_correct, feed_dict=feed_dict)
+  precision = true_count / num_examples
+  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
+        (num_examples, true_count, precision))
+
+
+def run_training():
+  """Train MNIST for a number of steps."""
+  # Get the sets of images and labels for training, validation, and
+  # test on MNIST.
+  data_sets = input_data.read_data_sets(tempfile.mkdtemp(), FLAGS.fake_data)
+
+  # Tell TensorFlow that the model will be built into the default Graph.
+  with tf.Graph().as_default():
+    # Generate placeholders for the images and labels.
+    images_placeholder, labels_placeholder = placeholder_inputs(
+        FLAGS.batch_size)
+
+    # Build a Graph that computes predictions from the inference model.
+    logits = mnist.inference(images_placeholder,
+                             FLAGS.hidden1,
+                             FLAGS.hidden2)
+
+    # Add to the Graph the Ops for loss calculation.
+    loss = mnist.loss(logits, labels_placeholder)
+
+    # Add to the Graph the Ops that calculate and apply gradients.
+    train_op = mnist.training(loss, FLAGS.learning_rate)
+
+    # Add the Op to compare the logits to the labels during evaluation.
+    eval_correct = mnist.evaluation(logits, labels_placeholder)
+
+    # Build the summary operation based on the TF collection of Summaries.
+    summary_op = tf.merge_all_summaries()
+
+    # Add the variable initializer Op.
+    init = tf.initialize_all_variables()
+
+    # Create a saver for writing training checkpoints.
+    saver = tf.train.Saver()
+
+    # Create a session for running Ops on the Graph.
+    sess = tf.Session()
+
+    # Instantiate a SummaryWriter to output summaries and the Graph.
+    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
+
+    # And then after everything is built:
+
+    # Run the Op to initialize the variables.
+    sess.run(init)
+
+    # Start the training loop.
+    for step in xrange(FLAGS.max_steps):
+      start_time = time.time()
+
+      # Fill a feed dictionary with the actual set of images and labels
+      # for this particular training step.
+      feed_dict = fill_feed_dict(data_sets.train,
+                                 images_placeholder,
+                                 labels_placeholder)
+
+      # Run one step of the model.  The return values are the activations
+      # from the `train_op` (which is discarded) and the `loss` Op.  To
+      # inspect the values of your Ops or variables, you may include them
+      # in the list passed to sess.run() and the value tensors will be
+      # returned in the tuple from the call.
+      _, loss_value = sess.run([train_op, loss],
+                               feed_dict=feed_dict)
+
+      duration = time.time() - start_time
+
+      # Write the summaries and print an overview fairly often.
+      if step % 100 == 0:
+        # Print status to stdout.
+        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
+        # Update the events file.
+        summary_str = sess.run(summary_op, feed_dict=feed_dict)
+        summary_writer.add_summary(summary_str, step)
+        summary_writer.flush()
+
+      # Save a checkpoint and evaluate the model periodically.
+      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+        checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
+        saver.save(sess, checkpoint_file, global_step=step)
+        # Evaluate against the training set.
+        print('Training Data Eval:')
+        do_eval(sess,
+                eval_correct,
+                images_placeholder,
+                labels_placeholder,
+                data_sets.train)
+        # Evaluate against the validation set.
+        print('Validation Data Eval:')
+        do_eval(sess,
+                eval_correct,
+                images_placeholder,
+                labels_placeholder,
+                data_sets.validation)
+        # Evaluate against the test set.
+        print('Test Data Eval:')
+        do_eval(sess,
+                eval_correct,
+                images_placeholder,
+                labels_placeholder,
+                data_sets.test)
+
+
+def main(_):
+  run_training()
+
+
+if __name__ == '__main__':
+  tf.app.run()
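task.py defines flags hidden1 through hidden10, but only hidden1 and hidden2 ever reach mnist.inference, which builds a two-layer net. If the intent is a deeper stack driven by those flags, one way to wire it is sketched below using slim.fully_connected (which task_o.py already imports); inference_from_flags is a hypothetical helper under TF 1.x assumptions, not part of the tutorial code:

```python
import tensorflow as tf
import tensorflow.contrib.slim as slim

def inference_from_flags(inputs, flags_obj, n_layers=10, n_outputs=10):
    """Hypothetical helper: build a fully connected stack whose widths come
    from the hidden1..hidden10 flags instead of hard-coding two layers."""
    net = inputs
    for i in range(1, n_layers + 1):
        width = getattr(flags_obj, 'hidden%d' % i)
        net = slim.fully_connected(net, width, scope='hidden%d' % i)
    # Linear output layer; the loss op is expected to apply the softmax itself.
    return slim.fully_connected(net, n_outputs, activation_fn=None, scope='logits')

# Inside run_training() this could replace the mnist.inference call:
#   logits = inference_from_flags(images_placeholder, FLAGS)
```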
diff --git a/task_o.py b/task_o.py
new file mode 100644
index 0000000..108f3d8
--- /dev/null
+++ b/task_o.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python
+# Copyright 2016 Google Inc. All Rights Reserved.
+# Modifications by dkoes.
+# More modifications by Alex P.
+
+"""This is based on:
+
+https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/mnist/deployable/trainer/task.py
+It includes support for training and prediction on the Google Cloud ML service.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os.path
+import subprocess
+import tempfile
+import time
+import sys
+import csv
+from google.cloud import bigquery as bq
+from sklearn import preprocessing
+from scipy import misc
+
+import numpy as np
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+from tensorflow.python.lib.io import file_io
+
+query_client = bq.Client()
+
+# Basic model parameters as external flags.
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_integer('max_steps', 1000, 'Number of steps to run trainer.')
+flags.DEFINE_integer('batch_size', 20, 'Batch size.')
+flags.DEFINE_string('train_data_db', '[mscbiofin:eventdata.datamore4]', 'BigQuery table containing the training event data')
+flags.DEFINE_string('start_date', '19741210', 'The start date for training')
+flags.DEFINE_string('end_date', '20161210', 'The end date for training')
+
+flags.DEFINE_integer('hidden1', 1024, 'Number of units in hidden layer 1.')
+flags.DEFINE_integer('hidden2', 1024, 'Number of units in hidden layer 2.')
+flags.DEFINE_integer('hidden3', 1024, 'Number of units in hidden layer 3.')
+flags.DEFINE_integer('hidden4', 1024, 'Number of units in hidden layer 4.')
+flags.DEFINE_integer('hidden5', 1024, 'Number of units in hidden layer 5.')
+flags.DEFINE_integer('hidden6', 1024, 'Number of units in hidden layer 6.')
+flags.DEFINE_integer('hidden7', 1024, 'Number of units in hidden layer 7.')
+flags.DEFINE_integer('hidden8', 1024, 'Number of units in hidden layer 8.')
+flags.DEFINE_integer('hidden9', 1024, 'Number of units in hidden layer 9.')
+flags.DEFINE_integer('hidden10', 1024, 'Number of units in hidden layer 10.')
+
+flags.DEFINE_string('train_output_dir', 'data', 'Directory to put the training data.')
+flags.DEFINE_string('model_dir', 'model', 'Directory to put the model into.')
+
+# Feel free to add additional flags to assist in setting hyperparameters
+
+# Get labels by running sql queries.
+
+# Open the financial data and hold it in memory.
+
+def read_training_list():
+  """
+  Read <train_data_dir>/TRAIN, which contains paths and labels in
+  the format: label, channel1 file, channel2 file, channel3 file
+  Returns:
+    List with all filenames in file image_list_file
+  """
+  image_list_file = FLAGS.train_data_dir + '/TRAIN'
+  f = file_io.FileIO(image_list_file, 'r')  # this can read files from the cloud
+  filenames = []
+  labels = []
+  n_classes = len(labelmap)
+  for line in f:
+    label, c1, c2, c3 = line.rstrip().split(' ')
+    # convert labels into onehot encoding
+    onehot = np.zeros(n_classes)
+    onehot[labelmap[label]] = 1.0
+    labels.append(onehot)
+    # create absolute paths for image files
+    filenames.append([FLAGS.train_data_dir + '/' + c for c in (c1, c2, c3)])
+
+  return zip(labels, filenames), n_classes
+
+
+class Fetcher:
+  '''Provides batches of images'''
+  #TODO TODO - you probably want to modify this to implement data augmentation
+  def __init__(self, stockfile):
+    self.current = 0
+    self.cache = {}
+    self.stocks = {}
+    for row in csv.reader(stockfile, delimiter=','):
+      date = row[0]
+      date = int(date.replace("-", ""))
+      diff = float(row[4]) - float(row[1])
+      self.stocks[date] = diff
+
+  def load_next(self):
+    print("I want to get stocks[" + str(self.current) + "]")
+    #Implement a cache for mysql
+    events = []
+    stockchange = 0
+    sys.exit(0)
+    x_batch = []
+    y_batch = []
+    for i in xrange(batchsize):
+      label, files = self.examples[(self.current + i) % len(self.examples)]
+      label = label.flatten()
+      # If you are getting an error reading the image, you probably have
+      # the legacy PIL library installed instead of Pillow
+      # You need Pillow
+      channels = [misc.imread(file_io.FileIO(f, 'r')) for f in files]
+      x_batch.append(np.dstack(channels))
+      y_batch.append(label)
+
+    self.current = (self.current + batchsize) % len(self.examples)
+    return np.array(x_batch), np.array(y_batch)
+
+
+def network(inputs):
+  '''Define the network'''
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      activation_fn=tf.nn.relu,
+                      weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
+                      weights_regularizer=slim.l2_regularizer(0.0005)):
+    net = tf.reshape(inputs, [-1, 512, 512, 3])
+    net = slim.conv2d(net, 32, [3, 3], scope='conv1')
+    net = slim.max_pool2d(net, [4, 4], scope='pool1')
+    net = slim.conv2d(net, 64, [3, 3], scope='conv2')
+    net = slim.max_pool2d(net, [4, 4], scope='pool2')
+    net = slim.flatten(net)
+    net = slim.fully_connected(net, 64, scope='fc')
+    net = slim.fully_connected(net, 13, activation_fn=None, scope='output')
+  return net
+
+def run_training():
+
+  # Read the training data
+  examples, n_classes = read_training_list()  #TODO: Replace this
+  np.random.seed(42)  # shuffle the same way each time for consistency
+  np.random.shuffle(examples)
+
+  fetcher = Fetcher()
+
+  # Tell TensorFlow that the model will be built into the default Graph.
+  with tf.Graph().as_default():
+    # Generate placeholders for the images and labels and mark as input.
+
+    x = tf.placeholder(tf.float32, shape=(None, 512, 512, 3))
+    y_ = tf.placeholder(tf.float32, shape=(None, n_classes))
+
+    # See "Using instance keys": https://cloud.google.com/ml/docs/how-tos/preparing-models
+    # for why we have keys_placeholder
+    keys_placeholder = tf.placeholder(tf.int64, shape=(None,))
+
+    # IMPORTANT: Do not change the input map
+    inputs = {'key': keys_placeholder.name, 'image': x.name}
+    tf.add_to_collection('inputs', json.dumps(inputs))
+
+    # Build the network
+    net = network(x)
+
+    # Add to the Graph the Ops for loss calculation.
+    loss = slim.losses.softmax_cross_entropy(net, y_)
+    tf.scalar_summary(loss.op.name, loss)  # keep track of value for TensorBoard
+
+    # To be able to extract the id, we need to add the identity function.
+    keys = tf.identity(keys_placeholder)
+
+    # The prediction will be the index in logits with the highest score.
+    # We also use a softmax operation to produce a probability distribution
+    # over all possible digits.
+    # DO NOT REMOVE OR CHANGE VARIABLE NAMES - used when predicting with a model
+    prediction = tf.argmax(net, 1)
+    scores = tf.nn.softmax(net)
+
+    # Mark the outputs.
+    outputs = {'key': keys.name,
+               'prediction': prediction.name,
+               'scores': scores.name}
+    tf.add_to_collection('outputs', json.dumps(outputs))
+
+    # Add to the Graph the Ops that calculate and apply gradients.
+    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
+
+
+    # Build the summary operation based on the TF collection of Summaries.
+    summary_op = tf.merge_all_summaries()
+
+    # Add the variable initializer Op.
+    init = tf.initialize_all_variables()
+
+    # Create a saver for writing training checkpoints.
+    saver = tf.train.Saver()
+
+    # Create a session for running Ops on the Graph.
+    sess = tf.Session()
+
+    # Instantiate a SummaryWriter to output summaries and the Graph.
+    summary_writer = tf.train.SummaryWriter(FLAGS.train_output_dir, sess.graph)
+
+    # And then after everything is built:
+
+    # Run the Op to initialize the variables.
+    sess.run(init)
+
+    # Start the training loop.
+    for step in xrange(FLAGS.max_steps):
+      start_time = time.time()
+
+      # Fill a feed dictionary with the actual set of images and labels
+      # for this particular training step.
+      images, labels = fetcher.load_batch(FLAGS.batch_size)
+      feed_dict = {x: images, y_: labels}
+
+      # Run one step of the model.  The return values are the activations
+      # from the `train_op` (which is discarded) and the `loss` Op.  To
+      # inspect the values of your Ops or variables, you may include them
+      # in the list passed to sess.run() and the value tensors will be
+      # returned in the tuple from the call.
+      _, loss_value = sess.run([train_op, loss],
+                               feed_dict=feed_dict)
+
+      duration = time.time() - start_time
+
+      # Write the summaries and print an overview fairly often.
+      if step % 1 == 0:
+        # Print status to stdout.
+        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
+        sys.stdout.flush()
+        # Update the events file.
+        summary_str = sess.run(summary_op, feed_dict=feed_dict)
+        summary_writer.add_summary(summary_str, step)
+        summary_writer.flush()
+
+
+    # Export the model so that it can be loaded and used later for predictions.
+    file_io.create_dir(FLAGS.model_dir)
+    saver.save(sess, os.path.join(FLAGS.model_dir, 'export'))
+
+    # make world readable for submission to evaluation server
+    if FLAGS.model_dir.startswith('gs://'):
+      subprocess.call(['gsutil', 'acl', 'ch', '-u', 'AllUsers:R', FLAGS.model_dir])
+
+    # You probably want to implement some sort of model evaluation here
+    # TODO TODO TODO
+
+def main(_):
+  run_training()
+
+
+if __name__ == '__main__':
+  tf.app.run()
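task_o.py's training loop calls fetcher.load_batch(FLAGS.batch_size), but the Fetcher class only sketches load_next and stops at sys.exit(0), so the batch interface is still open. One hypothetical shape it could take, pairing per-day event features with the sign of the next day's Dow move from the stocks dict built in __init__; the examples list, its (date, features) layout, and the two-class up/down label are assumptions rather than the author's implementation, and the checked-in network() ends in 13 logits, so the label width would have to match whichever class scheme is finally chosen:

```python
import numpy as np

def load_batch(fetcher, batch_size):
    """Hypothetical Fetcher.load_batch body: fetcher.examples is assumed to be
    a list of (yyyymmdd, feature_vector) pairs and fetcher.stocks the
    date -> (close - open) dict built in __init__."""
    x_batch, y_batch = [], []
    for i in range(batch_size):
        day, features = fetcher.examples[(fetcher.current + i) % len(fetcher.examples)]
        diff = fetcher.stocks.get(day, 0.0)
        # Two-class label: did the Dow close up (or flat) versus down that day?
        label = np.array([1.0, 0.0]) if diff >= 0 else np.array([0.0, 1.0])
        x_batch.append(np.asarray(features, dtype=np.float32))
        y_batch.append(label)
    fetcher.current = (fetcher.current + batch_size) % len(fetcher.examples)
    return np.array(x_batch), np.array(y_batch)
```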
