author    Alexander M Pickering <amp215@pitt.edu>  2025-02-01 02:24:13 -0600
committer Alexander M Pickering <amp215@pitt.edu>  2025-02-01 02:24:13 -0600
commit    61bdb4fef88c1e83787dbb023b51d8d200844e3a (patch)
tree      6d905b6f61a0e932b1ace9771c714a80e0388af0
Initial commit (HEAD, master)
Diffstat (limited to 'a2.scala')
-rw-r--r--  a2.scala  |  54
1 file changed, 54 insertions(+), 0 deletions(-)
diff --git a/a2.scala b/a2.scala
new file mode 100644
index 0000000..8d724a6
--- /dev/null
+++ b/a2.scala
@@ -0,0 +1,54 @@
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkConf
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.distributed._
+import java.io._
+
+object Assign2 {
+
+  def main(args: Array[String]): Unit =
+  {
+    val conf = new SparkConf().setAppName("proj2")
+    val sc = new SparkContext(conf)
+    val datafile = args(0)    // known entries, one "row,col,value" triple per line
+    val missingfile = args(1) // coordinates to impute (only row and col are read)
+    val outfile = args(2)
+
+    val ofile = new File(outfile)
+    val output = new BufferedWriter(new FileWriter(ofile))
+
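+    // Parse the known entries into MatrixEntry(row, col, value) triples.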
+    val file = sc.textFile(datafile).cache()
+    val data = file.map(_.split(",")).map(x => MatrixEntry(x(0).toLong, x(1).toLong, x(2).toDouble))
+
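+    // The missing entries carry only coordinates, so the value slot is a placeholder 0.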
+    val missingfiletext = sc.textFile(missingfile).cache()
+    val missingdata = missingfiletext.map(_.split(",")).map(x => MatrixEntry(x(0).toLong, x(1).toLong, 0))
+
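+    // Keep row indices attached: CoordinateMatrix.toRowMatrix drops them and does not
+    // guarantee row ordering, so an IndexedRowMatrix is used instead.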
+    val cm = new CoordinateMatrix(data)
+    val irm = cm.toIndexedRowMatrix()
+
+    // Rank-10 truncated SVD; computeU = true materializes U for the reconstruction.
+    val svd = irm.computeSVD(10, computeU = true)
+
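+    // Low-rank reconstruction A ~ U * S * V^T; its entries at the missing
+    // coordinates serve as the imputed values.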
+    val s = Matrices.diag(svd.s)
+    val A = svd.U.multiply(s).multiply(svd.V.transpose)
+    val idA = A.rows.map(r => (r.index, r.vector))
+    val idA2 = sc.broadcast(idA.collect().toMap)
+
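+    // Look up the reconstructed value at each missing coordinate and write it out.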
+    val odata = missingdata.map(x => (x.i, x.j, idA2.value(x.i)(x.j.toInt)))
+    odata.collect().foreach(x => output.write(x._1 + "," + x._2 + "," + x._3 + "\n"))
+    output.flush()
+
+    // Distributed matrix factorization: the cluster we run on uses 26 quad-core
+    // machines, so the SVD work can be split into 26 pieces.
+
+    output.close()
+    System.exit(0)
+  }
+}
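
A quick way to sanity-check the reconstruction step above is to run the same pipeline in local mode on a tiny rank-1 matrix, where a rank-1 truncated SVD must reproduce every entry exactly. The following is a minimal sketch, not part of the commit; it assumes only Spark MLlib on the classpath, and the object name TinySvdDemo and the toy data are made up for illustration.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}

object TinySvdDemo {
  def main(args: Array[String]): Unit = {
    // Local-mode context; no cluster needed for the toy example.
    val sc = new SparkContext(new SparkConf().setAppName("svd-demo").setMaster("local[2]"))

    // A 3x3 rank-1 matrix with entry (i, j) = (i + 1) * (j + 1).
    val entries = for (i <- 0 to 2; j <- 0 to 2) yield MatrixEntry(i, j, (i + 1.0) * (j + 1.0))
    val irm = new CoordinateMatrix(sc.parallelize(entries)).toIndexedRowMatrix()

    // A rank-1 truncated SVD reconstructs a rank-1 matrix exactly, so the
    // "imputed" value at any coordinate should match the true product.
    val svd = irm.computeSVD(1, computeU = true)
    val a = svd.U.multiply(Matrices.diag(svd.s)).multiply(svd.V.transpose)
    val rows = a.rows.map(r => (r.index, r.vector)).collect().toMap

    println(f"reconstructed (2,2) = ${rows(2L)(2)}%.3f, expected 9.000")
    sc.stop()
  }
}

The same index-keyed Map lookup used in a2.scala (rows(i)(j)) retrieves the reconstructed value here, which is why the commit broadcasts a Map rather than relying on the positional order of a collected array.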