5#include "Math/SimpleMath.h"
7#include <thrust/sort.h>
// Constructor of LinearBVH. Only the signature is visible in this excerpt;
// the body (if any) lies in lines that are not shown here.
13 template<typename TDataType>
14 LinearBVH<TDataType>::LinearBVH()
// Destructor of LinearBVH. Only the signature is visible in this excerpt.
// NOTE(review): confirm whether it is expected to call release().
18 template<typename TDataType>
19 LinearBVH<TDataType>::~LinearBVH()
// Clears the member arrays of the BVH. Only the clear() calls on
// mSortedObjectIds and mFlags are visible in this excerpt.
// NOTE(review): confirm that the other members resized in construct()
// (mCenters, mMortonCodes, mSortedAABBs, mAllNodes) are cleared as well.
23 template<typename TDataType>
24 void LinearBVH<TDataType>::release()
29 mSortedObjectIds.clear();
30 mFlags.clear(); //Flags used for calculating bounding box
// Kernel: writes the midpoint of every bounding box into `center`.
// One thread handles one element; the parameter list is not fully visible in
// this excerpt, but the body reads `aabb` and writes `center`.
34 template<typename Coord, typename AABB>
35 __global__ void LBVH_CalculateCenter(
// Flat 1D global thread index.
39 int tId = threadIdx.x + (blockIdx.x * blockDim.x);
// Threads beyond the number of boxes have nothing to do.
41 if (tId >= aabb.size()) return;
// Midpoint of the two corners v0 and v1 of the box.
43 center[tId] = Real(0.5) * (aabb[tId].v0 + aabb[tId].v1);
46 // Expands a 10-bit integer into 30 bits
47 // by inserting 2 zeros after each bit.
// Each step multiplies v so that a shifted copy is added to itself, then
// masks away the bits that do not belong to the spread-out pattern. After the
// last mask (0x49249249) only every third bit position can be set.
// The return statement is not visible in this excerpt.
48 __device__ uint expandBits(uint v)
50 v = (v * 0x00010001u) & 0xFF0000FFu;
51 v = (v * 0x00000101u) & 0x0F00F00Fu;
52 v = (v * 0x00000011u) & 0xC30C30C3u;
53 v = (v * 0x00000005u) & 0x49249249u;
57 // Calculates a 30-bit Morton code for the
58 // given 3D point located within the unit cube [0,1].
// Inputs outside [0,1] are not rejected: each coordinate is scaled by 1024
// and clamped to [0, 1023] before being truncated to an unsigned integer.
59 template<typename Real>
60 __device__ uint morton3D(Real x, Real y, Real z)
// Quantize every axis to 10 bits (1024 cells per axis).
62 x = min(max(x * Real(1024), Real(0)), Real(1023));
63 y = min(max(y * Real(1024), Real(0)), Real(1023));
64 z = min(max(z * Real(1024), Real(0)), Real(1023));
// Spread the 10 bits of each axis so that two free bits follow every bit.
65 uint xx = expandBits((uint)x);
66 uint yy = expandBits((uint)y);
67 uint zz = expandBits((uint)z);
// Interleave: x takes the highest bit of every triple, then y, then z.
68 return xx * 4 + yy * 2 + zz;
// Kernel: computes a Morton code for every center point.
// Only part of the parameter list is visible in this excerpt; the body also
// reads `center`, `orgin` and `L`, which are declared in lines not shown.
// NOTE(review): the stores into `morton` and `objectId` are not visible here,
// nor is any post-processing of m64 — confirm against the full file.
71 template<typename Real, typename Coord>
72 __global__ void LBVH_CalculateMortonCode(
73 DArray<uint64> morton,
74 DArray<uint> objectId,
// Flat 1D global thread index; one thread per center.
79 uint tId = threadIdx.x + (blockIdx.x * blockDim.x);
81 if (tId >= center.size()) return;
// Normalize the center relative to `orgin` by the extent `L` before encoding.
83 Coord scaled = (center[tId] - orgin) / L;
// morton3D returns a 30-bit code in a uint; it is widened to 64 bits here.
85 uint64 m64 = morton3D(scaled.x, scaled.y, scaled.z);
89 //printf("morton %u: %u; %llu \n", tId, morton3D(scaled.x, scaled.y, scaled.z), m64);
// Device helper: searches the sorted Morton codes in [first, last] for the
// position at which the highest differing bit changes.
// The remaining parameters, the loop header and the return statement are not
// visible in this excerpt.
// NOTE(review): this helper works on 32-bit codes (uint, __clz), whereas the
// kernels in this file use uint64 codes with __clzll, and no call to it is
// visible in this excerpt — confirm whether it is still in use.
95 __device__ int findSplit(uint* sortedMortonCodes,
99 // Identical Morton codes => split the range in the middle.
101 uint firstCode = sortedMortonCodes[first];
102 uint lastCode = sortedMortonCodes[last];
104 if (firstCode == lastCode)
105 return (first + last) >> 1;
107 // Calculate the number of highest bits that are the same
108 // for all objects, using the count-leading-zeros intrinsic.
110 int commonPrefix = __clz(firstCode ^ lastCode);
112 // Use binary search to find where the next bit differs.
113 // Specifically, we are looking for the highest object that
114 // shares more than commonPrefix bits with the first one.
116 int split = first; // initial guess
117 int step = last - first;
121 step = (step + 1) >> 1; // exponential decrease
122 int newSplit = split + step; // proposed new position
// Accept the proposal only if it still shares a longer prefix with the first
// code than the whole range does.
126 uint splitCode = sortedMortonCodes[newSplit];
127 int splitPrefix = __clz(firstCode ^ splitCode);
128 if (splitPrefix > commonPrefix)
129 split = newSplit; // accept proposal
// Kernel: builds the node links (left / right / parent) of a binary radix
// tree over the sorted Morton codes and copies the leaf bounding boxes.
//
// Index layout as used by the visible code: internal nodes are addressed by
// the thread index i (threads with i >= N - 1 return before building one),
// and the leaf for sorted position k is stored at k + N - 1.
//
// NOTE(review): the structure (direction d, delta_min, len_max doubling,
// binary searches, split gamma) looks like the radix-tree construction by
// Karras (2012) — confirm. Several lines (one parameter, an early bounds
// check, loop bodies, declarations of len_max / len / j / s) are not visible
// in this excerpt.
136 template<typename Node, typename AABB>
137 __global__ void LBVH_ConstructBinaryRadixTree(
138 DArray<Node> bvhNodes,
139 DArray<AABB> sortedAABBs,
141 DArray<uint64> mortonCodes,
142 DArray<uint> sortedObjectIds)
// Flat 1D global thread index and number of objects.
144 int i = threadIdx.x + (blockIdx.x * blockDim.x);
145 int N = sortedObjectIds.size();
149// printf("Num: %d \n", N);
151// printf("sorted morton %d: %llu \n", i, mortonCodes[i]);
// Store the box of the i-th sorted object in its leaf slot.
// NOTE(review): this needs i < N; the guard is presumably in a line not
// visible here — confirm.
153 sortedAABBs[i + N - 1] = aabbs[sortedObjectIds[i]];
// Only N - 1 internal nodes exist; the remaining threads are done.
155 if (i >= N - 1) return;
157 //Calculate the length of the longest common prefix between i and j, note i should be in the range of [0, N-1]
// Returns -1 when _j is out of range so that any valid prefix compares larger.
158 auto delta = [&](int _i, int _j) -> int {
159 if (_j < 0 || _j >= N) return -1;
160 return __clzll(mortonCodes[_i] ^ mortonCodes[_j]);
163// printf("Test CLZ: %d \n", __clzll(mortonCodes[1]));
// Direction of the range: towards the neighbour sharing the longer prefix.
165 int d = delta(i, i + 1) - delta(i, i - 1) > 0 ? 1 : -1;
167// printf("%u %d \n", i, d);
169// printf("delta: %d \n", delta(0, 1));
171 // Compute upper bound for the length of the range
// delta_min is the prefix length shared with the neighbour on the other side;
// every element inside the range must share strictly more than this.
172 int delta_min = delta(i, i - d);
174// printf("delta_min %d %d: %d \n", i, i - d, delta_min);
177 while (delta(i, i + len_max * d) > delta_min)
182 // Find the other end using binary search
184 for (int t = len_max / 2; t > 0; t = t / 2)
186 if (delta(i, i + (len + t) * d) > delta_min)
194 // Find the split position using binary search
// delta_node is the common prefix length of the whole range [i, j].
195 int delta_node = delta(i, j);
200// printf("len: %d \n", len);
// The step t is halved with rounding up and forced to 0 after the t == 1 pass.
203 for (int t = (len + 1) / 2; t > 0; t = t == 1 ? 0 : (t + 1) / 2)
205 if (delta(i, i + (s + t) * d) > delta_node)
210// printf("s: %d; t: %d \n", s, t);
// Split position; minimum(d, 0) subtracts one when the range runs backwards.
214 int gamma = i + s * d + minimum(d, (int)0);
216// printf("i-j: %d %d; Gamma: %d \n", i, j, gamma);
220// printf("21 22 23 24 dir: %d; %llu; %llu; %llu; %llu \n", d, mortonCodes[21], mortonCodes[22], mortonCodes[23], mortonCodes[24]);
221// printf("0 41 42 dir: %llu; %llu; %llu \n", mortonCodes[0], mortonCodes[41], mortonCodes[42]);
224 //printf("Gamma: %u \n", gamma);
226 //Output child pointers
// A child that coincides with an end of the range is a leaf and is therefore
// addressed with the leaf offset (N - 1); otherwise it is an internal node.
227 int left_idx = minimum(i, j) == gamma ? gamma + N - 1 : gamma;
228 int right_idx = maximum(i, j) == gamma + 1 ? gamma + N : gamma + 1;
230// printf("i: %d, j: %d Left: %d; Right: %d; \n", i, j, left_idx, right_idx);
232 bvhNodes[i].left = left_idx;
233 bvhNodes[i].right = right_idx;
// Link both children back to this internal node.
235 bvhNodes[left_idx].parent = i;
236 bvhNodes[right_idx].parent = i;
// Kernel: propagates bounding boxes from the leaves up to the root.
// Each thread starts at the parent of leaf i + N - 1 and walks the parent
// links. `flags` holds one counter per visited node: the first thread to
// arrive at a node sets the flag, the second one merges the two child boxes.
// The `flags` parameter, the bounds check and the branch on `old` are in
// lines not visible in this excerpt.
// NOTE(review): no memory fence (e.g. __threadfence()) is visible between the
// store to sortedAABBs[idx] and the next atomicCAS; confirm that the thread
// arriving second is guaranteed to observe the sibling's box.
239 template<typename Node, typename AABB>
240 __global__ void LBVH_CalculateBoundingBox(
241 DArray<AABB> sortedAABBs,
242 DArray<Node> bvhNodes,
245 uint i = threadIdx.x + (blockIdx.x * blockDim.x);
246 uint N = flags.size();
250 //Output AABBs of leaf nodes
251// auto v0 = sortedAABBs[i + N - 1].v0;
252// auto v1 = sortedAABBs[i + N - 1].v1;
253// printf("%d: idx, %f %f %f; %f %f %f \n", i + N - 1, v0.x, v0.y, v0.z, v1.x, v1.y, v1.z);
// Start from the parent of this thread's leaf.
255 int idx = bvhNodes[i + N - 1].parent;
256 while (idx != EMPTY) // stop once the parent link is EMPTY (NOTE(review): presumably the root's parent — confirm)
258 //printf("Left: %u; Right: %u, \n", idx->left->idx, idx->right->idx);
// Atomically set the flag from 0 to 1; `old` tells which thread arrived first.
259 const int old = atomicCAS(flags.begin() + idx, 0, 1);
262 // this is the first thread entered here.
263 // wait for the other thread coming from the other child node.
267 // here, the flag has already been 1. it means that this
268 // thread is the 2nd thread. merge the AABBs of both children.
270 const int l_idx = bvhNodes[idx].left;
271 const int r_idx = bvhNodes[idx].right;
272 const AABB l_aabb = sortedAABBs[l_idx];
273 const AABB r_aabb = sortedAABBs[r_idx];
274 sortedAABBs[idx] = l_aabb.merge(r_aabb);
276 //Output AABBs of internal nodes
277// auto v0 = sortedAABBs[idx].v0;
278// auto v1 = sortedAABBs[idx].v1;
279// printf("%d: idx, %f %f %f; %f %f %f \n", idx, v0.x, v0.y, v0.z, v1.x, v1.y, v1.z);
281 // look the next parent...
282 idx = bvhNodes[idx].parent;
284 //printf("BB %d, \n", idx);
// Kernel: resets every entry of bvhNodes to a default-constructed Node.
// One thread handles one node; threads past the end of the array return.
288 template<typename Node>
289 __global__ void LBVH_InitialAllNodes(
290 DArray<Node> bvhNodes)
292 uint i = threadIdx.x + (blockIdx.x * blockDim.x);
293 if (i >= bvhNodes.size()) return;
295 bvhNodes[i] = Node();
// Builds the BVH for the given bounding boxes. Visible steps, in order:
//   1. resize the working arrays,
//   2. compute the box centers (LBVH_CalculateCenter),
//   3. derive a cubic domain (origin, edge length L) from the center bounds,
//   4. compute Morton codes and sort the object ids by them,
//   5. reset all nodes, build the radix tree, then compute the node boxes.
// The kernel launch arguments and several statements are in lines that are
// not visible in this excerpt.
// NOTE(review): `num` is unsigned, so `2 * num - 1` wraps around for an empty
// input unless this is guarded in a line not shown — confirm.
299 template<typename TDataType>
300 void LinearBVH<TDataType>::construct(const DArray<AABB>& aabb)
302 uint num = aabb.size();
// Only reallocate when the number of objects changed.
304 if (mCenters.size() != num){
305 mCenters.resize(num);
306 mMortonCodes.resize(num);
307 mSortedObjectIds.resize(num);
// Array size for a tree over num leaves: 2 * num - 1 entries.
310 mSortedAABBs.resize(2 * num - 1);
311 mAllNodes.resize(2 * num - 1);
315 LBVH_CalculateCenter,
// Component-wise minimum and maximum over all centers.
319 Reduction<Coord> mReduce;
320 Coord v_min = mReduce.minimum(mCenters.begin(), mCenters.size());
321 Coord v_max = mReduce.maximum(mCenters.begin(), mCenters.size());
// Use the largest axis extent as the edge length of the cubic domain.
323 Real L = std::max(v_max[0] - v_min[0], std::max(v_max[1] - v_min[1], v_max[2] - v_min[2]));
324 L = L < REAL_EPSILON ? Real(1) : L; //To avoid being divided by zero
// Lower corner of a cube with edge L centered on the midpoint of the bounds.
326 Coord origin = Real(0.5) * (v_min + v_max) - Real(0.5) * L;
329 LBVH_CalculateMortonCode,
// Sort the Morton codes on the device and permute the object ids with them.
338 thrust::sort_by_key(thrust::device, mMortonCodes.begin(), mMortonCodes.begin() + mMortonCodes.size(), mSortedObjectIds.begin());
340// std::cout << "Sort: " << timer.getElapsedTime() << std::endl;
342 cuExecute(mAllNodes.size(),
343 LBVH_InitialAllNodes,
348 LBVH_ConstructBinaryRadixTree,
355// std::cout << "Construct: " << timer.getElapsedTime() << std::endl;
357// CArray<Node> hArray;
358// hArray.assign(mAllNodes);
363 LBVH_CalculateBoundingBox,
368// std::cout << "BoundingBox: " << timer.getElapsedTime() << std::endl;
// Traverses the BVH and counts the leaves whose box overlaps queryAABB and
// whose object id is greater than queryId.
// The declarations of `buffer`, `stack`, `ret` and `idx`, the loop header and
// the return statement are not visible in this excerpt; presumably `ret` is
// returned — confirm.
// NOTE(review): the `objId > queryId` filter presumably reports each
// overlapping pair only once — confirm against the callers.
371 template<typename TDataType>
372 GPU_FUNC uint LinearBVH<TDataType>::requestIntersectionNumber(const AABB& queryAABB, const int queryId) const
374 // Allocate traversal stack from thread-local memory,
375 // and push NULL to indicate that there are no postponed nodes.
// NOTE(review): presumably binds the stack to a 64-entry buffer — confirm
// that deeper trees cannot overflow it.
379 stack.reserve(buffer, 64);
381 uint N = mSortedObjectIds.size();
383 // Traverse nodes starting from the root.
388 // Check each child node for overlap.
389 int idxL = mAllNodes[idx].left;
390 int idxR = mAllNodes[idx].right;
391 bool overlapL = queryAABB.checkOverlap(getAABB(idxL));
392 bool overlapR = queryAABB.checkOverlap(getAABB(idxR));
394 // Query overlaps a leaf node => report collision.
// A leaf index maps back to its sorted position by subtracting N - 1.
395 if (overlapL && mAllNodes[idxL].isLeaf()) {
396 int objId = mSortedObjectIds[idxL - N + 1];
397 if(objId > queryId) ret++;
400 if (overlapR && mAllNodes[idxR].isLeaf()) {
401 int objId = mSortedObjectIds[idxR - N + 1];
402 if (objId > queryId) ret++;
405 // Query overlaps an internal node => traverse.
406 bool traverseL = (overlapL && !mAllNodes[idxL].isLeaf());
407 bool traverseR = (overlapR && !mAllNodes[idxR].isLeaf());
// Nothing to descend into: continue with a postponed node, or finish when
// the stack is empty. The removal of the top entry is not visible here.
409 if (!traverseL && !traverseR) {
410 idx = !stack.empty() ? stack.top() : EMPTY; // pop
// Otherwise descend into one child and postpone the right one if both
// need to be visited.
415 idx = (traverseL) ? idxL : idxR;
416 if (traverseL && traverseR)
417 stack.push(idxR); // push
419 } while (idx != EMPTY);
// Traverses the BVH in the same way as requestIntersectionNumber, but takes
// an output list `ids` instead of returning a count.
// The declarations of `buffer`, `stack` and `idx`, the loop header and the
// statements that use `objId` are not visible in this excerpt.
// NOTE(review): presumably the ids of overlapping objects are inserted into
// `ids`, filtered by queryId as in the counting variant — confirm.
424 template<typename TDataType>
425 GPU_FUNC void LinearBVH<TDataType>::requestIntersectionIds(List<int>& ids, const AABB& queryAABB, const int queryId) const
427 // Allocate traversal stack from thread-local memory,
428 // and push NULL to indicate that there are no postponed nodes.
// NOTE(review): presumably binds the stack to a 64-entry buffer — confirm
// that deeper trees cannot overflow it.
432 stack.reserve(buffer, 64);
434 uint N = mSortedObjectIds.size();
436 // Traverse nodes starting from the root.
441 // Check each child node for overlap.
442 int idxL = mAllNodes[idx].left;
443 int idxR = mAllNodes[idx].right;
444 bool overlapL = queryAABB.checkOverlap(getAABB(idxL));
445 bool overlapR = queryAABB.checkOverlap(getAABB(idxR));
447 // Query overlaps a leaf node => report collision.
// A leaf index maps back to its sorted position by subtracting N - 1.
448 if (overlapL && mAllNodes[idxL].isLeaf()) {
449 int objId = mSortedObjectIds[idxL - N + 1];
454 if (overlapR && mAllNodes[idxR].isLeaf()) {
455 int objId = mSortedObjectIds[idxR - N + 1];
460 // Query overlaps an internal node => traverse.
461 bool traverseL = (overlapL && !mAllNodes[idxL].isLeaf());
462 bool traverseR = (overlapR && !mAllNodes[idxR].isLeaf());
// Nothing to descend into: continue with a postponed node, or finish when
// the stack is empty. The removal of the top entry is not visible here.
464 if (!traverseL && !traverseR) {
465 idx = !stack.empty() ? stack.top() : EMPTY; // pop
// Otherwise descend into one child and postpone the right one if both
// need to be visited.
470 idx = (traverseL) ? idxL : idxR;
471 if (traverseL && traverseR)
472 stack.push(idxR); // push
474 } while (idx != EMPTY);
// NOTE(review): presumably instantiates LinearBVH for the supported data
// types; the macro definition is not visible in this excerpt — confirm.
477 DEFINE_CLASS(LinearBVH);