diff options
| author | Alexander M Pickering <amp215@pitt.edu> | 2025-02-01 02:24:13 -0600 |
|---|---|---|
| committer | Alexander M Pickering <amp215@pitt.edu> | 2025-02-01 02:24:13 -0600 |
| commit | 61bdb4fef88c1e83787dbb023b51d8d200844e3a (patch) | |
| tree | 6d905b6f61a0e932b1ace9771c714a80e0388af0 /a2.scala | |
| download | mscbio2046-master.tar.gz mscbio2046-master.tar.bz2 mscbio2046-master.zip | |
Diffstat (limited to 'a2.scala')
| -rw-r--r-- | a2.scala | 55 |
1 file changed, 55 insertions, 0 deletions
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import scala.io.Source
import org.apache.spark.rdd._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.distributed._
import java.io._

/**
 * Assignment 2: impute missing matrix entries with a rank-10 truncated SVD.
 *
 * Usage: Assign2 <datafile> <missingfile> <outfile>
 *   datafile    — CSV of known entries, one "row,col,value" per line
 *   missingfile — CSV of coordinates to impute, "row,col" (value ignored)
 *   outfile     — destination; one "row,col,imputedValue" per line
 */
object Assign2 {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("proj2")
    val sc = new SparkContext(conf)
    val datafile = args(0)
    val missingfile = args(1)
    val outfile = args(2)

    val output = new BufferedWriter(new FileWriter(new File(outfile)))
    try {
      // Parse the known entries into MatrixEntry(row, col, value).
      val data = sc.textFile(datafile)
        .map(_.split(","))
        .map(f => MatrixEntry(f(0).toLong, f(1).toLong, f(2).toDouble))

      // Coordinates whose values must be predicted (placeholder value 0).
      val missingdata = sc.textFile(missingfile)
        .map(_.split(","))
        .map(f => MatrixEntry(f(0).toLong, f(1).toLong, 0))

      val rowmatrix = new CoordinateMatrix(data).toRowMatrix

      // Rank-10 truncated SVD; computeU = true so A can be reconstructed
      // as U * S * V^T (the low-rank approximation used for imputation).
      val svd = rowmatrix.computeSVD(10, true)
      val s = Matrices.diag(svd.s)
      val approx = svd.U.multiply(s).multiply(svd.V.transpose)

      // Key each reconstructed row by its row index and broadcast as a Map,
      // so the driver-side lookup uses the index itself rather than relying
      // on the positional order of a collected array.
      val rowsById = sc.broadcast(
        approx.rows.zipWithIndex.map(_.swap).collect().toMap)

      // Look up each missing coordinate in the reconstructed matrix.
      val odata = missingdata.map(e => (e.i, e.j, rowsById.value(e.i)(e.j.toInt)))
      odata.collect().foreach { case (i, j, v) => output.write(i + "," + j + "," + v + "\n") }
      output.flush()

      // TODO: distributed matrix factorization — the cluster runs 26
      // quad-core machines, so the SVD could be split into 26 pieces.
    } finally {
      // Close the writer and shut Spark down even if an action above failed.
      output.close()
      sc.stop()
    }
  }
}
