/*
 * pagerank_contributions.cpp
 *
 *  Created on: Jan 17, 2014
 *      Author: peter
 */

#include <limits>

#include "Snap.h"
#include "pqueue.h"
#include "pqueue_double.h"
#include "pagerank_algorithms.h"

//const bool USE_STATIC_SQRT_BALANCING = true;

/* Returns a hash table from node id u to an estimate of pi(u, target), the
 * probability that a random walk from u terminates at target.
 * The estimate is accurate to additive error epsilon.
 *
 * Reverse local push driven by a priority queue keyed on unpushed residual
 * mass: each round takes the node w with the largest residual, pushes
 * (1 - teleportProb) * residual(w) / outDeg(u) to every in-neighbor u of w
 * (adding the same amount to u's contribution and to u's residual), and
 * resets w's residual to 0. Stops once the largest residual is at most
 * epsilon * teleportProb.
 *
 * stepCount, if non-NULL, is incremented once per in-edge processed. */
THash<TInt, TFlt> computeContributions(PNGraph g, TNGraph::TNodeI target,
		double teleportProb, double epsilon, long long *stepCount) {
	const int INITIAL_SIZE = 20;
	THash<TInt, TFlt> contributions;
	pqueue_t *pq = make_pqueue_double(INITIAL_SIZE);
	// Every node_t handed to pq is recorded here so it can be deleted at the end.
	THash<TInt, node_t*> idToPQNode;

	node_t* pqNodeTarget = new node_t();
	idToPQNode.AddDat(target.GetId(), pqNodeTarget);
	pqNodeTarget->pri = teleportProb;
	pqNodeTarget->val = target.GetId();
	pqueue_insert(pq, pqNodeTarget);
	contributions.AddDat(target.GetId(), teleportProb);

	// Processed nodes are never removed, only reset to priority 0, so pq is
	// never empty here. The loop relies on pqueue_peek returning the node with
	// the largest priority (residual).
	while(get_pri(pqueue_peek(pq)) > epsilon * teleportProb) {
		node_t* wPQNode = (node_t *) pqueue_peek(pq);
		int wId = wPQNode->val;
		double wPri = wPQNode->pri;
		if (TRACE) printf("Popping node %d with pri %f\n", wId, wPri);
		TNGraph::TNodeI w = g->GetNI(wId);
		pqueue_change_priority(pq, 0.0, wPQNode);
		for (int i = 0; i < w.GetInDeg(); i++) {
			int uId = w.GetInNId(i);
			TNGraph::TNodeI u = g->GetNI(uId);
			// u has an edge to w, so u.GetOutDeg() >= 1.
			double change = (1.0 - teleportProb) * wPri / u.GetOutDeg();
			if (! contributions.IsKey(uId)) {
				if (TRACE) printf("Adding %d with initial contribution %f\n", uId, change);
				contributions.AddDat(uId, change);
				node_t* pqNodeU = new node_t();
				idToPQNode.AddDat(uId, pqNodeU);
				pqNodeU->pri = change;
				pqNodeU->val = uId;
				pqueue_insert(pq, pqNodeU);
			} else {
				node_t* pqNodeU = idToPQNode.GetDat(uId);
				contributions.GetDat(uId) += change;
				if (TRACE) printf("Increasing priority of %d to %f\n", uId, pqNodeU->pri + change);
				pqueue_change_priority(pq, pqNodeU->pri + change, pqNodeU);
			}
			if (stepCount != NULL)
				(*stepCount)++;
		}
	}

	// Release the queue allocated by make_pqueue_double first (it only holds
	// pointers to the node_ts), then free all the node_ts themselves.
	pqueue_free(pq);
	for (THashKeyDatI<TInt, node_t *> i = idToPQNode.BegI(); !i.IsEnd(); i++) {
		delete i.GetDat();
	}
	return contributions;
}

/* Returns a hash table from node id u to an estimate of pi(u, target), the
 * same quantity computed by computeContributions, accurate to additive error
 * epsilon.
 *
 * Instead of a priority queue this variant keeps a FIFO work list. A node is
 * appended when its residual crosses teleportProb * epsilon from below.
 * Processing a node w resets its residual to 0 and pushes
 * (1 - teleportProb) * residual(w) / outDeg(u) to every in-neighbor u of w,
 * adding that amount to both u's contribution and u's residual. The loop ends
 * when the work list is empty.
 *
 * stepCount, if non-NULL, is incremented once per in-edge processed. */
THash<TInt, TFlt> computeContributionsWithoutPQueue(PNGraph g, TNGraph::TNodeI target,
		double teleportProb, double epsilon, long long *stepCount) {
	const double pushThreshold = teleportProb * epsilon;
	THash<TInt, TFlt> estimates;
	THash<TInt, TFlt> residual;
	std::queue<int> pending;

	const int targetId = target.GetId();
	estimates.AddDat(targetId, teleportProb);
	residual.AddDat(targetId, teleportProb);
	pending.push(targetId);

	while (!pending.empty()) {
		const int wId = pending.front();
		pending.pop();

		const double wMass = residual.GetDat(wId);
		if (TRACE) printf("Popping node %d with pri %f\n", wId, wMass);
		TNGraph::TNodeI w = g->GetNI(wId);
		residual.GetDat(wId) = 0.0;

		// Mass handed to each in-neighbor before dividing by its out-degree.
		const double pushFactor = (1.0 - teleportProb) * wMass;
		const int inDeg = w.GetInDeg();
		for (int j = 0; j < inDeg; j++) {
			const int uId = w.GetInNId(j);
			const double change = pushFactor / g->GetNI(uId).GetOutDeg();

			// Both tables always share the same key set.
			if (!estimates.IsKey(uId)) {
				estimates.AddDat(uId, 0.0);
				residual.AddDat(uId, 0.0);
			}

			// Enqueue only on the upward crossing of the threshold.
			const double before = residual.GetDat(uId);
			const bool crossesThreshold = before < pushThreshold && before + change >= pushThreshold;
			if (crossesThreshold) {
				pending.push(uId);
			}

			estimates.GetDat(uId) += change;
			residual.GetDat(uId) += change;
			if (TRACE)
				printf("Increasing priority of %d to %f\n", uId,
						(double) residual.GetDat(uId));
			if (stepCount != NULL)
				(*stepCount)++;
		}
	}

	return estimates;
}



/* Monte Carlo estimate of where a random walk from start ends.
 * Returns a hash table from node id v to the fraction of the nWalks walks that
 * terminated at v, an estimate of Pr(walk from start terminates at v).
 *
 * Before each step a uniform draw is taken; the walk stops if the draw does not
 * exceed teleportProb, or if the current node is in terminationSet (which may be
 * NULL, meaning no termination set). A node with no out-edges behaves as if it
 * had a self loop: the step is taken but the walk stays put.
 *
 * stepCount, if non-NULL, is incremented once per step taken. */
THash<TInt, TFlt> monteCarloPPR(PNGraph g, TNGraph::TNodeI start,
		double teleportProb, long long int nWalks, TRnd& rnd, const THashSet<TInt>* terminationSet, long long *stepCount) {
	THash<TInt, TFlt> estimates;
	const double weightPerWalk = 1.0 / nWalks;

	for (long long int walk = 0; walk < nWalks; walk++) {
		TNGraph::TNodeI cur = start;
		for (;;) {
			// The uniform draw comes first: the RNG is consumed even when cur
			// is already in terminationSet.
			if (!(rnd.GetUniDev() > teleportProb))
				break;
			if (terminationSet != NULL && terminationSet->IsKey(cur.GetId()))
				break;

			if (TRACE) printf("Walked to %d\n", cur.GetId());
			const int outDeg = cur.GetOutDeg();
			if (outDeg > 0) {
				cur = g->GetNI(cur.GetOutNId(rnd.GetUniDevInt(outDeg)));
			}
			// else: imaginary self loop on a sink node, cur is unchanged
			if (stepCount != NULL)
				(*stepCount)++;
		}

		if (TRACE) printf("Completed walk at %d\n", cur.GetId());
		const int endId = cur.GetId();
		if (!estimates.IsKey(endId)) {
			estimates.AddDat(endId, 0.0);
		}
		estimates.GetDat(endId) += weightPerWalk;
	}
	return estimates;
}

/* Estimates pi(start, target), the probability that a random walk from start
 * terminates at target. It combines a reverse local push from target
 * (computeContributionsWithoutPQueue) with forward Monte Carlo walks from
 * start (monteCarloPPR) that stop on a frontier around the target set.
 *
 * Parameters:
 *   teleportProb        per-step termination probability of the walk.
 *   threshold           PPR value of interest; it is split so that
 *                       thresholdFromStart * thresholdToTarget == threshold.
 *   approximationRatio  c (> 1 is required: c - 1 appears in denominators);
 *                       sets the reverse-push accuracy and, via
 *                       chernoffConstant, the number of walks.
 *   failureProbability  the number of walks scales with log(1 / failureProbability).
 *   verbose             print walk counts and frontier / target-set sizes.
 *   expandInNeighbors   if true, walks stop at in-neighbors of the target set
 *                       that are not themselves in it; otherwise they stop on
 *                       the target set itself.
 *   globalPR            optional PageRank indexed by node id, used to balance
 *                       forward vs. reverse work; if empty, the forward
 *                       threshold is sqrt(threshold).
 *   edgeCount           number of edges in g (see countEdges).
 *   forwardTime, reverseTime
 *                       if non-NULL, the wall-clock seconds spent in the
 *                       forward / reverse phase are ADDED to them.
 *   stepCount           if non-NULL, incremented per edge processed in the
 *                       reverse phase and per walk step in the forward phase.
 * Returns the estimate of pi(start, target). */
double fastPPR(PNGraph g, TNGraph::TNodeI start, TNGraph::TNodeI target,
		double teleportProb, double threshold, double approximationRatio,
		double failureProbability, TRnd& rnd, bool verbose, bool expandInNeighbors,
		const std::vector<double>& globalPR, double edgeCount,
		double* forwardTime, double* reverseTime, long long *stepCount) {
	// Experimental correction for frontier nodes adjacent to the target set; disabled.
	const bool lessBiasedEstimate = false;

	double startTime = wallClockTime();

	// ---- Parameter selection ----
	double toTargetConstant = 3.0; // TODO: Is 3.0 a good constant here?
	double c = approximationRatio;
	double chernoffConstant = std::max(3.0 * c / (c - 1.0), 4.0 * c / ((c - 1.0) * (c -1.0))) ;// * log(g->GetNodes());
	double adhocBalanceConstant = 3.0;// Increase this to do less forward work and more backward work

	double averageDegree = edgeCount / g->GetNodes();

	double targetPR = !globalPR.empty() ? globalPR.at(target.GetId()) : 1.0 / g->GetNodes(); // if PageRank is unknown, just use average PageRank
	double thresholdFromStart = globalPR.empty() /*USE_STATIC_SQRT_BALANCING*/ ?  sqrt(threshold) :
			adhocBalanceConstant * sqrt(threshold * chernoffConstant * log(1.0 / failureProbability) / ( c * toTargetConstant * averageDegree * g->GetNodes() * targetPR));
	long long nWalksFromStart = (long long) (chernoffConstant / thresholdFromStart * log(1.0 / failureProbability));

	double thresholdToTarget = threshold / thresholdFromStart;
	double epsilonToTarget = thresholdToTarget *(approximationRatio - 1.0) / toTargetConstant;

	// ---- Reverse phase: estimate pi(u, target) for nodes near target ----
	THash<TInt, TFlt> pprToTarget = computeContributionsWithoutPQueue(g, target, teleportProb, epsilonToTarget, stepCount);
	// Special Case: If the start node is inside the set of high-ppr-to-target nodes, we should just return the estimate we have
	if (pprToTarget.IsKey(start.GetId()) &&
			pprToTarget.GetDat(start.GetId()) + epsilonToTarget >= thresholdToTarget) {
		printf("start is in TargetSet: No random walks needed.\n");

		if (reverseTime != NULL) {
			*reverseTime += wallClockTime() - startTime;
		}
		return pprToTarget.GetDat(start.GetId());
	}

	// Target set: nodes whose estimated ppr-to-target is >= thresholdToTarget.
	// Frontier: the nodes at which forward walks are stopped.
	long long targetSetSize = 0;
	THashSet<TInt> frontier;
	for (THashKeyDatI<TInt, TFlt> i = pprToTarget.BegI(); !i.IsEnd(); i++) {
		if (i.GetDat() >= thresholdToTarget) {
			TNGraph::TNodeI t = g->GetNI(i.GetKey());
			if (expandInNeighbors) {
				// add t's in-neighbors to frontier (these need not be keys of pprToTarget)
				for (int j = 0; j < t.GetInDeg(); j++) {
					int fId = t.GetInNId(j);
					if (!pprToTarget.IsKey(fId)
							|| pprToTarget.GetDat(fId) < thresholdToTarget) {
						frontier.AddKey(fId);
					}
					if (stepCount != NULL)
						(*stepCount)++;
				}
			} else {
				// In this case, the frontier is exactly the target set
				frontier.AddKey(t.GetId());
			}
			targetSetSize++;
		}
	}
	if (lessBiasedEstimate) {
		for (THashSet<TInt>::TIter i = frontier.BegI(); i != frontier.EndI(); i++) {

			TNGraph::TNodeI f = g->GetNI(i.GetKey());
			for (int j = 0; j < f.GetOutDeg(); j++) {
				if (pprToTarget.IsKey(f.GetOutNId(j)) &&
						pprToTarget.GetDat(f.GetOutNId(j)) > thresholdToTarget  ) {
					// AddDat rather than GetDat: a frontier node need not already be a key.
					pprToTarget.AddDat(f.GetId()) += 1.0 / f.GetOutDeg() * epsilonToTarget / 2.0;
				}
			}
		}
	}

	if (reverseTime != NULL) {
		*reverseTime += wallClockTime() - startTime;
	}
	startTime = wallClockTime();

	// ---- Forward phase: random walks from start, stopped on the frontier ----
	if (verbose)
		printf("Taking %lld walks from start, via chernoffConstant %f\n", nWalksFromStart, chernoffConstant);

	THash<TInt, TFlt> probFromStart = monteCarloPPR(g, start, teleportProb, nWalksFromStart, rnd, &frontier, stepCount);

	if (verbose)
			printf("Node %d has frontier size %d target set size %lld\n", target.GetId(), frontier.Len(), targetSetSize);
	if (verbose)
			printf("Node %d has explored monte-carlo size %d\n", start.GetId(), probFromStart.Len());

	// Combine: sum over frontier endpoints of Pr(walk stops there) * ppr-to-target.
	double estimate = 0.0;
	for (THashKeyDatI<TInt, TFlt> i = probFromStart.BegI(); !i.IsEnd(); i++) {
		// A frontier node never reached by the reverse push has estimate 0, and
		// GetDat on a missing key fails, so check IsKey first.
		if (frontier.IsKey(i.GetKey()) && pprToTarget.IsKey(i.GetKey())) {
			estimate += i.GetDat() * (pprToTarget.GetDat(i.GetKey()) );
		}
	}

	if (forwardTime != NULL) {
		*forwardTime += wallClockTime() - startTime;
	}

	return estimate;
}

/* Samples a node by running one random walk: start at the node returned by
 * g->GetRndNI(rnd), and while a uniform draw exceeds teleportProb move to a
 * uniformly chosen out-neighbor (a sink node is treated as having a self
 * loop). Returns the node on which the walk stopped. */
TNGraph::TNodeI samplePageRank(PNGraph g, double teleportProb, TRnd& rnd) {
	TNGraph::TNodeI cur = g->GetRndNI(rnd);
	while (rnd.GetUniDev() > teleportProb) {
		const int deg = cur.GetOutDeg();
		// Sink nodes use an imaginary self loop and stay where they are.
		const int nextId = (deg > 0) ? cur.GetOutNId(rnd.GetUniDevInt(deg)) : cur.GetId();
		cur = g->GetNI(nextId);
		if (TRACE)
			printf("Walked to %d\n", nextId);
	}
	return cur;
}

/* Prints xs to stdout as a bracketed, comma-and-space separated list such as
 * "[1, 2.5]", followed by a newline. */
void printDoubles(const TVec<double>& xs) {
	printf("[");
	const int n = xs.Len();
	for (int k = 0; k < n; k++) {
		if (k > 0)
			printf(", ");
		printf("%g", xs[k]);
	}
	printf("]\n");
}

/* Prints m to stdout, one row per line. Every entry is followed by a tab, so
 * each line ends with a trailing tab before the newline. */
void printMatrixTSV(const std::vector<std::vector<double>  >& m) {
	typedef std::vector<std::vector<double> >::const_iterator RowIter;
	typedef std::vector<double>::const_iterator CellIter;
	for (RowIter row = m.begin(); row != m.end(); ++row) {
		for (CellIter cell = row->begin(); cell != row->end(); ++cell) {
			printf("%g\t", *cell);
		}
		printf("\n");
	}
}

/* Prints d to stdout as a Python list literal with a comma after every
 * element, e.g. "[1,2.5,]". No newline is printed. */
void printDoublesForPython(const std::vector<double> d) {
	printf("[");
	for (std::vector<double>::const_iterator it = d.begin(); it != d.end(); ++it) {
		printf("%g,", *it);
	}
	printf("]");
}


/* Prints m to stdout as a Python list-of-lists literal: "[" on its own line,
 * then each row (via printDoublesForPython) followed by ",\n", then an empty
 * line and a closing "]" line. */
void printMatrixForPython(const std::vector<std::vector<double>  >& m) {
	printf("[\n");
	const size_t rows = m.size();
	for (size_t row = 0; row < rows; ++row) {
		printDoublesForPython(m[row]);
		printf(",\n");
	}
	printf("\n]\n");
}


/* Returns the transpose of m, i.e. result[c][r] == m[r][c].
 *
 * m may be empty, in which case the result is empty. m may also be ragged:
 * the result has as many rows as the longest row of m, each of length
 * m.size(), and positions with no corresponding entry in m hold 0.0. */
std::vector<std::vector<double> > transposeMatrix(const std::vector<std::vector<double> >& m) {
	// Size by the longest row rather than m[0]: m[0] is undefined behavior on
	// an empty matrix, and a later row longer than m[0] would index past the
	// end of result.
	size_t nCols = 0;
	for (size_t r = 0; r < m.size(); r++) {
		nCols = std::max<size_t>(nCols, m[r].size());
	}
	std::vector<std::vector<double> > result(nCols, std::vector<double>(m.size(), 0.0));
	for (size_t r = 0; r < m.size(); r++) {
		for (size_t c = 0; c < m[r].size(); c++) {
			result[c][r] = m[r][c];
		}
	}
	return result;
}

/* Returns the arithmetic mean of xs. For an empty xs this evaluates 0.0 / 0,
 * which yields NaN. */
double mean(const TVec<double>& xs) {
	const int n = xs.Len();
	double total = 0.0;
	for (int k = 0; k < n; ++k)
		total += xs[k];
	return total / n;
}

/* Returns the given percentile (nominally in [0, 100]) of xs.
 *
 * NOTE: sorts xs in place with xs.Sort(true), so the caller's vector is
 * modified. The element index is Len * percentile / 100.0000001, truncated;
 * the slightly enlarged divisor keeps percentile == 100 from indexing one
 * past the end. Indices outside [0, Len - 1] (from percentiles outside
 * [0, 100] or NaN) are clamped to the first / last element. An empty xs
 * yields NaN instead of reading out of bounds. */
double percentile(TVec<double>& xs, double percentile) {
	if (xs.Len() == 0) {
		return std::numeric_limits<double>::quiet_NaN();
	}
	xs.Sort(true);
	const int last = xs.Len() - 1;
	const double pos = xs.Len() * percentile / 100.0000001;
	// Clamp in floating point before converting: casting an out-of-range
	// double to int is undefined behavior.
	int i;
	if (!(pos > 0.0)) {
		i = 0;
	} else if (pos >= last) {
		i = last;
	} else {
		i = (int) pos;
	}
	return xs[i];
}

/* Computes the global PageRank of every node of g by iterating the PageRank
 * update until the L^1 change of one sweep is at most threshold.
 *
 * Each sweep visits nodes in g's iteration order and overwrites pi[v] in
 * place, so later nodes in the same sweep already see the updated values:
 *   pi[v] = teleportProb / |V| + sum over in-neighbors u of
 *           (1 - teleportProb) * pi[u] / outDeg(u).
 * Mass reaching sink nodes is not redistributed, so the total can be < 1.
 * Progress and a final sanity sum are printed to stdout.
 *
 * Returns a vector indexed by node id (size = largest id + 1); ids that are
 * not nodes of g hold 0. */
std::vector<double> globalPageRank(PNGraph g, double teleportProb,
		double threshold) {
	int largestId = 0;
	for (TNGraph::TNodeI node = g->BegNI(); node != g->EndNI(); node++) {
		if (node.GetId() > largestId)
			largestId = node.GetId();
	}
	std::vector<double> pi(largestId + 1, 0.0);

	const double nNodes = g->GetNodes();
	for (TNGraph::TNodeI node = g->BegNI(); node != g->EndNI(); node++) {
		pi[node.GetId()] = 1.0 / nNodes;
	}

	double delta = 1.0e100;
	while (delta > threshold) {
		delta = 0.0;
		for (TNGraph::TNodeI v = g->BegNI(); v != g->EndNI(); v++) {
			const int vId = v.GetId();
			double updated = teleportProb / nNodes; // arrival by teleporting
			const int inDeg = v.GetInDeg();
			for (int j = 0; j < inDeg; j++) {
				// arrival from in-neighbor u on the previous step
				TNGraph::TNodeI u = g->GetNI(v.GetInNId(j));
				updated += (1.0 - teleportProb) * pi[u.GetId()] / u.GetOutDeg();
			}
			delta += fabs(updated - pi[vId]);
			pi[vId] = updated;
		}
		printf("Completed iteration. L^1 change: %g\n", delta);
	}

	// Sanity check on the total mass.
	double total = 0.0;
	for (std::vector<double>::const_iterator p = pi.begin(); p != pi.end(); ++p) {
		total += *p;
	}
	printf("sum of pi (note that mass from sink nodes is destroyed): %g\n", total);
	return pi;
}

/* Returns the number of edges of g, computed as the sum of all out-degrees.
 * Counted manually in a double because the SNAP library's own edge count is
 * an int, which overflows on very large graphs. */
double countEdges(PNGraph g) {
	double total = 0;
	TNGraph::TNodeI node = g->BegNI();
	while (node != g->EndNI()) {
		total += node.GetOutDeg();
		node++;
	}
	return total;
}

/* Returns the current wall-clock time, in seconds with microsecond
 * resolution, as reported by gettimeofday. In this file it is used to measure
 * elapsed time by subtracting two readings. */
double wallClockTime() {
	// from http://stackoverflow.com/questions/588307/c-obtaining-milliseconds-time-on-linux-clock-doesnt-seem-to-work-properl/
	struct timeval now;
	gettimeofday(&now, NULL);
	const double micros = now.tv_usec;
	return now.tv_sec + micros / 1.0e6;
}
