#include "kdtree.h"
#include "utils.h"
#include <stack>
#include <algorithm>

/* Creates an empty tree: no nodes and no triangle buckets allocated yet. */
KDTree::KDTree()
{
	/* Node array. */
	nodes = NULL;
	totalNodes = 0;

	/* Leaf triangle lists. */
	triangleBuckets = NULL;
	triangleBucketsSize = 0;
	triangleBucketsCount = 0;
}

/*
 * Releases the node array and the triangle bucket array. Both are grown with
 * realloc() elsewhere in this file, hence free(); free() accepts the NULL
 * pointers the constructor starts with.
 */
KDTree::~KDTree()
{
	free(nodes);
	free(triangleBuckets);
}

/*
 * Appends a batch of geometry to the tree's vertex and triangle lists.
 *
 * vertices   vertices of the batch, copied after the ones already stored.
 * triangles  3*n vertex indices relative to this batch; each is offset by the
 *            number of vertices stored before the call.
 * n          number of triangles (index triples) in `triangles`.
 *
 * Returns the index the first triangle of this batch received in the
 * tree's triangle list. The tree itself is not rebuilt here.
 */
int KDTree::feedTriangles(const std::vector<m::Vector3>& vertices, int* triangles, int n)
{
	/* Positions at which the new data will start. */
	const int firstVertex = (int)this->vertices.size();
	const int firstTriangle = (int)this->triangles.size();

	/* Append the vertices. */

	this->vertices.reserve(this->vertices.size() + vertices.size());
	this->vertices.insert(this->vertices.end(), vertices.begin(), vertices.end());

	/* Append the triangles, rebasing their indices onto the stored vertices. */

	this->triangles.reserve(this->triangles.size() + n);

	const int* corner = triangles;
	for (int count = 0; count < n; count++, corner += 3)
	{
		Triangle tri;
		tri.a = firstVertex + corner[0];
		tri.b = firstVertex + corner[1];
		tri.c = firstVertex + corner[2];
		this->triangles.push_back(tri);
	}

	return firstTriangle;
}

void KDTree::update()
{
	if (!vertices.size())
		return;

	/* Scene bounding box. */

	sceneMin = vertices[0];
	sceneMax = vertices[0];

	for (unsigned int i = 1; i < vertices.size(); i++)
	{
		sceneMin = vecmin(sceneMin, vertices[i]);
		sceneMax = vecmax(sceneMax, vertices[i]);
	}

	printf("SCENE MIN: %f %f %f\n", sceneMin.x, sceneMin.y, sceneMin.z);
	printf("SCENE MAX: %f %f %f\n", sceneMax.x, sceneMax.y, sceneMax.z);

	/* Build. */

	int* tris = new int [triangles.size()];
	for (unsigned int i = 0; i < triangles.size(); i++)
		tris[i] = (int)i;

	bboxes.reserve(triangles.size());

	for (unsigned int i = 0; i < triangles.size(); i++)
	{
		const Triangle& tri = triangles[i];
		m::Vector3 triMin = vecmin(vertices[tri.a], vecmin(vertices[tri.b], vertices[tri.c]));
		m::Vector3 triMax = vecmax(vertices[tri.a], vecmax(vertices[tri.b], vertices[tri.c]));
		bboxes.push_back(std::pair<m::Vector3, m::Vector3>(triMin, triMax));
	}

	int root = build(tris, (int)triangles.size(), sceneMin, sceneMax);
	assert(root == 0);

	printf("tree built.\n");
	bboxes.clear();

	delete [] tris;
}

/*
 * Reserves the next slot of the node array and returns its index.
 *
 * Whenever the node count is a multiple of 4096 the array is reallocated to
 * hold the current count plus 16384 nodes, so there is always room for the
 * slot handed out. The slot's contents are left for the caller to fill in.
 */
int KDTree::newNode()
{
	const int growthCheckMask = (1 << 12) - 1;
	const int growthAmount = 1 << 14;

	const int index = totalNodes;

	if ((index & growthCheckMask) == 0)
	{
		nodes = (Node*) realloc(nodes, sizeof(Node) * (index + growthAmount));
		assert(nodes);
	}

	totalNodes = index + 1;
	return index;
}

/*
 * Makes sure the triangle bucket array can take n more entries after the
 * ones already in use. When it cannot, the array is enlarged by n plus 1024
 * entries of slack.
 */
void KDTree::allocateBuckets(int n)
{
	/* Enough room already. */
	if (triangleBucketsCount + n <= triangleBucketsSize)
		return;

	triangleBucketsSize = triangleBucketsSize + n + 1024;
	triangleBuckets = (int*) realloc(triangleBuckets, triangleBucketsSize * sizeof(int));
	assert(triangleBuckets);
}

/*
 * Recursively builds the subtree for the n triangle indices in `tris`, which
 * belong to the axis-aligned cell [min, max]. Returns the index of the
 * subtree's root node, or -1 when n is zero. Relies on `bboxes` holding the
 * bounding box (min, max) of every triangle.
 *
 * Node encoding, as written by this function:
 *   dir bits 0-1  split axis (0, 1, 2), or 3 for a leaf
 *   dir bit 2     interior node has a left child; that child is the node
 *                 at index + 1
 *   dir bits 4+   leaf only: start of its triangle list in triangleBuckets
 *                 (the list ends with -1)
 *   pos           split plane position (interior nodes)
 *   right         index of the right child, or -1
 *
 * Split choice:
 *   - more than 1024 triangles: the middle of the cell's longest axis;
 *   - otherwise: a sweep over the triangle bounding-box edges on each axis,
 *     minimising (triangle count * cell surface area) summed over both
 *     children.
 * A leaf is emitted when no split was found, when there is a single
 * triangle, or when the split plane is not strictly inside the cell.
 */
int KDTree::build(int* tris, int n, const m::Vector3& min, const m::Vector3& max)
{
	if (!n)
		return -1;

	int bestAxis = -1;
	float bestPos = 0.0f;

	if (n > 1024)
	{
		/* Spatial median of the longest axis. */
		m::Vector3 dim = max - min;
		if (dim.x > dim.y)
			bestAxis = (dim.x > dim.z) ? 0 : 2;
		else
			bestAxis = (dim.y > dim.z) ? 1 : 2;
		bestPos = .5f * (min[bestAxis] + max[bestAxis]);
	}
	else
	{
		float bestCost = 1e24f;

		/*
		 * Sweep events: (coordinate, 2 * triangle) for the low edge of a
		 * bounding box and (coordinate, 2 * triangle + 1) for its high edge.
		 */
		std::vector<std::pair<float, int> > sorted;
		sorted.reserve(2 * n);

		for (int axis = 0; axis < 3; axis++)
		{
			sorted.clear();

			for (int i = 0; i < n; i++)
			{
				sorted.push_back(std::pair<float, int>(bboxes[tris[i]].first[axis], tris[i] << 1));
				sorted.push_back(std::pair<float, int>(bboxes[tris[i]].second[axis], (tris[i] << 1) | 1));
			}

			std::sort(sorted.begin(), sorted.end());

			/* Boxes opened so far / boxes not yet closed. */
			int nl = 0;
			int nr = n;

			/* Cell extents along the two other axes. */
			float edge1 = max[(axis+1)%3] - min[(axis+1)%3];
			float edge2 = max[(axis+2)%3] - min[(axis+2)%3];
			float quadArea = 2 * edge1 * edge2;

			for (int i = 0; i < (int)sorted.size(); i++)
			{
				const std::pair<float, int>& p = sorted[i];

				/* Only planes strictly inside the cell are candidates. */
				if (p.first > min[axis] && p.first < max[axis])
				{
					/* Count-weighted surface area of the left and right child cells. */
					float c1 = nl * (quadArea + 2 * edge1 * (p.first - min[axis]) + 2 * edge2 * (p.first - min[axis]));
					float c2 = nr * (quadArea + 2 * edge1 * (max[axis] - p.first) + 2 * edge2 * (max[axis] - p.first));

					if (c1 + c2 < bestCost && nl > 0 && nr > 0 && nl < n && nr < n)
					{
						bestAxis = axis;
						bestPos = p.first;
						bestCost = c1 + c2;
					}
				}

				if (p.second & 1)
					nr--;
				else
					nl++;
			}

			assert(nl == n);
			assert(nr == 0);
		}
	}

	/*
	 * A plane on (or outside) the cell boundary cannot shrink the cell: one
	 * child would cover the same cell again and could receive the same
	 * triangles, so the recursion would never end. This happens in the median
	 * branch once the cell has zero extent or is too small for its midpoint
	 * to be representable. The negated form also rejects a NaN position.
	 */
	const bool canSplit = bestAxis != -1 && n > 1
		&& bestPos > min[bestAxis] && bestPos < max[bestAxis];

	if (!canSplit)
	{
		/* Leaf: store the triangle list, terminated by -1. */
		int leaf = newNode();
		nodes[leaf].dir = 3 | (triangleBucketsCount << 4);
		nodes[leaf].right = -1;

		allocateBuckets(n + 1);

		for (int i = 0; i < n; i++)
			triangleBuckets[triangleBucketsCount++] = tris[i];
		triangleBuckets[triangleBucketsCount++] = -1;

		return leaf;
	}

	/*
	 * Interior node. `nodes` may be reallocated by the recursive calls, so the
	 * node is always addressed through its index, never through a pointer.
	 */
	int node = newNode();
	nodes[node].dir = bestAxis;
	nodes[node].pos = bestPos;
	nodes[node].right = -1;

	/* Distribute the triangles; a box crossing the plane goes to both sides. */
	std::vector<int> leftTris;
	std::vector<int> rightTris;
	leftTris.reserve(n);
	rightTris.reserve(n);

	for (int i = 0; i < n; i++)
	{
		if (bboxes[tris[i]].first[bestAxis] >= bestPos)
			rightTris.push_back(tris[i]);
		else if (bboxes[tris[i]].second[bestAxis] <= bestPos)
			leftTris.push_back(tris[i]);
		else
		{
			leftTris.push_back(tris[i]);
			rightTris.push_back(tris[i]);
		}
	}

	m::Vector3 mid;

	/* Left child: must directly follow its parent in the node array. */
	mid = max;
	mid[bestAxis] = bestPos;
	int ret = build(leftTris.empty() ? NULL : &leftTris[0], (int)leftTris.size(), min, mid);
	assert(ret == -1 || ret == node+1);
	if (ret != -1)
		nodes[node].dir |= (1 << 2);

	/* Give the left list back before descending into the right subtree. */
	std::vector<int>().swap(leftTris);

	/* Right child. */
	mid = min;
	mid[bestAxis] = bestPos;
	ret = build(rightTris.empty() ? NULL : &rightTris[0], (int)rightTris.size(), mid, max);
	nodes[node].right = ret;

	return node;
}

/*
 * Entry of the traversal stack used by KDTree::intersect(): a node index
 * together with the ray parameter interval [minT, maxT] to examine in it.
 */
struct Item
{
	int node;
	float minT;
	float maxT;

	Item(int nodeIndex, float tBegin, float tEnd)
		: node(nodeIndex), minT(tBegin), maxT(tEnd)
	{
	}
};

/*
 * Finds the nearest triangle hit by the ray p + t * d.
 *
 * p, d        ray origin and direction.
 * _t, _u, _v  written only when a hit is found: the ray parameter of the hit
 *             and the two surface coordinates reported by triangleIntersect().
 *
 * Returns the index of the hit triangle in `triangles`, or -1 when the tree
 * is empty, the ray misses the scene bounding box, or nothing is hit.
 *
 * The tree is walked front to back along the ray. Every leaf is searched
 * completely and its nearest hit inside the leaf's ray interval is returned;
 * hits closer than 0.0001 to the origin are ignored and the interval test
 * uses the same tolerance.
 */
int KDTree::intersect(const m::Vector3& p, const m::Vector3& d, float& _t, float& _u, float& _v) const
{
	if (!totalNodes)
		return -1;

	/* Inverse directions; a large finite stand-in replaces 1/0. */

	float invDir[3];

	for (int i = 0; i < 3; i++)
		invDir[i] = (d[i] == 0.0f) ? 1e11f : 1.0f / d[i];

	/* Intersect scene bounding box. */

	float a, b;
	if (!aaboxIntersect(p, d, sceneMin, sceneMax, a, b))
		return -1;

	/* Subtrees still to visit, nearest on top. */

	std::stack<Item> stack;
	stack.push(Item(0, a, b));

	while (!stack.empty())
	{
		int nodeIndex = stack.top().node;
		float minT = stack.top().minT;
		float maxT = stack.top().maxT;
		stack.pop();

		/* Walk down from this node until a leaf has been searched or the path ends. */
		for (;;)
		{
			const Node& node = nodes[nodeIndex];
			assert(minT <= maxT);

			if ((node.dir & 3) == 3)
			{
				/* Leaf: test every triangle of its -1 terminated list, keep the nearest hit. */
				int hitTriangle = -1;
				float hitT = 1e10f;
				float hitU = 0.0f;
				float hitV = 0.0f;

				for (int i = node.dir >> 4; triangleBuckets[i] != -1; i++)
				{
					const int triIndex = triangleBuckets[i];
					const Triangle& tri = triangles[triIndex];
					float t, u, v;

					if (!triangleIntersect(p, d, vertices[tri.a], vertices[tri.b], vertices[tri.c], t, u, v))
						continue;

					assert(t >= 0.0f);

					/*
					 * Accept only hits inside this leaf's interval: a triangle
					 * can be listed in several leaves, and a hit outside the
					 * interval belongs to a leaf visited later.
					 */
					if (t > 0.0001f && t < hitT && t+0.0001f >= minT && t-0.0001f <= maxT)
					{
						hitTriangle = triIndex;
						hitT = t;
						hitU = u;
						hitV = v;
					}
				}

				if (hitTriangle != -1)
				{
					_t = hitT;
					_u = hitU;
					_v = hitV;
					return hitTriangle;
				}

				break;
			}

			/* Interior node. */
			const int axis = node.dir & 3;
			const float tPlane = (node.pos - p[axis]) * invDir[axis];

			/* The left child, when present, directly follows its parent. */
			const int leftChild = (node.dir & (1 << 2)) ? nodeIndex+1 : -1;
			const int rightChild = node.right;

			/*
			 * The near child is the one on the origin's side of the plane.
			 * With the origin exactly on the plane the direction decides, so
			 * that the ray is followed into the side it actually travels through.
			 */
			const bool leftIsNear = p[axis] < node.pos
				|| (p[axis] == node.pos && d[axis] <= 0.0f);
			const int nearChild = leftIsNear ? leftChild : rightChild;
			const int farChild = leftIsNear ? rightChild : leftChild;

			int next;

			if (tPlane > maxT || tPlane <= 0.0f)
			{
				/* Plane is beyond the interval or not ahead of the origin: near side only. */
				next = nearChild;
			}
			else if (tPlane < minT)
			{
				/* Plane is crossed before the interval starts: far side only. */
				next = farChild;
			}
			else
			{
				/* Plane is crossed inside the interval: near side first, far side later. */
				if (farChild != -1)
					stack.push(Item(farChild, tPlane, maxT));

				next = nearChild;
				maxT = tPlane;
			}

			if (next == -1)
				break;

			nodeIndex = next;
		}
	}

	return -1;
}

/*
 * Occlusion query: returns true when intersect() reports a hit for the ray
 * p + t * d and the reported ray parameter is not at or beyond maxT.
 */
bool KDTree::shadowIntersect(const m::Vector3& p, const m::Vector3& d, float maxT) const
{
	float hitT, hitU, hitV;
	const bool somethingHit = intersect(p, d, hitT, hitU, hitV) != -1;

	return somethingHit && !(hitT >= maxT);
}
