1#include "NeighborPointQuery.h"
3#include "Topology/GridHash.h"
4#include "Topology/LinearBVH.h"
5#include "Topology/SparseOctree.h"
// Neighbor-cell offsets kept in __constant__ (device constant) memory.
// The hash-grid kernels below add row c (c = 0..26) component-wise to a
// point's int3 cell index to enumerate the cells searched for candidates.
// NOTE(review): the initializer rows are not visible in this extract;
// presumably they enumerate {-1,0,1}^3 -- confirm.
11 __constant__ int offset_nq[27][3] = {
// Project macro whose definition is not visible here; presumably it emits
// the template-class registration boilerplate for NeighborPointQuery -- confirm.
41 IMPLEMENT_TCLASS(NeighborPointQuery, TDataType)
// Constructor: marks the inOther() input as optional and calls
// setRange(0, 100) on the SizeLimit variable.
43 template<typename TDataType>
44 NeighborPointQuery<TDataType>::NeighborPointQuery()
// inOther() may stay empty; the request* methods then fall back to inPosition().
47 this->inOther()->tagOptional(true);
// compute() treats SizeLimit <= 0 as "no limit" (dynamic neighbor lists).
49 this->varSizeLimit()->setRange(0, 100);
// Destructor. Its body is not visible in this extract.
52 template<typename TDataType>
53 NeighborPointQuery<TDataType>::~NeighborPointQuery()
// Dispatches the neighbor query on the selected spatial structure:
//   UNIFORM -> hash grid; variable-length lists when SizeLimit <= 0,
//              otherwise lists capped at SizeLimit entries;
//   BVH     -> requestNeighborIdsWithBVH();
//   OCTREE  -> requestNeighborIdsWithOctree().
// Only these three branches are visible; any other key does nothing here.
57 template<typename TDataType>
58 void NeighborPointQuery<TDataType>::compute()
60 auto sType = this->varSpatial()->currentKey();
62 if (sType == Spatial::UNIFORM)
// SizeLimit <= 0 means "no cap on the number of neighbors".
64 if (this->varSizeLimit()->getValue() <= 0) {
65 requestDynamicNeighborIds();
68 requestFixedSizeNeighborIds();
71 else if (sType == Spatial::BVH)
73 requestNeighborIdsWithBVH();
75 else if (sType == Spatial::OCTREE)
77 requestNeighborIdsWithOctree();
// Kernel, one thread per query point: pId indexes position_new, and threads
// with pId >= position_new.size() return immediately.
// For its query point, a thread visits the 27 hash cells obtained by
// offsetting the point's cell index with offset_nq. It computes the distance
// to every candidate stored in those cells (candidate ids index `position`)
// and writes the resulting per-point count to count[pId].
// NOTE(review): the `count` and radius parameters, the counter
// initialisation, the cell-validity check and the distance test are on lines
// missing from this extract. Presumably only candidates within the query
// radius are counted -- confirm against the full file.
81 template<typename Real, typename Coord, typename TDataType>
82 __global__ void K_CalNeighborSize(
84 DArray<Coord> position_new,
85 DArray<Coord> position,
86 GridHash<TDataType> hash,
89 int pId = threadIdx.x + (blockIdx.x * blockDim.x);
90 if (pId >= position_new.size()) return;
92 Coord pos_ijk = position_new[pId];
93 int3 gId3 = hash.getIndex3(pos_ijk);
96 for (int c = 0; c < 27; c++)
98 int cId = hash.getIndex(gId3.x + offset_nq[c][0], gId3.y + offset_nq[c][1], gId3.z + offset_nq[c][2]);
100 int totalNum = hash.getCounter(cId);
101 for (int i = 0; i < totalNum; i++) {
102 int nbId = hash.getParticleId(cId, i);
103 Real d_ij = (pos_ijk - position[nbId]).norm();
112 count[pId] = counter;
// Kernel, one thread per query point: pId indexes position_new, and threads
// with pId >= position_new.size() return immediately.
// It performs the same 27-cell hash-grid walk as K_CalNeighborSize, computing
// the distance from the query point to each candidate `position[nbId]`.
// Results go into the pre-sized list nbrIds[pId].
// NOTE(review): the radius parameter, the distance test and the
// list_i.insert(...) call are on lines missing from this extract.
// Presumably the accepted ids must match the count produced by
// K_CalNeighborSize, since nbrIds is resized from those counts -- confirm.
116 template<typename Real, typename Coord, typename TDataType>
117 __global__ void K_GetNeighborElements(
118 DArrayList<int> nbrIds,
119 DArray<Coord> position_new,
120 DArray<Coord> position,
121 GridHash<TDataType> hash,
124 int pId = threadIdx.x + (blockIdx.x * blockDim.x);
125 if (pId >= position_new.size()) return;
127 Coord pos_ijk = position_new[pId];
128 int3 gId3 = hash.getIndex3(pos_ijk);
130 List<int>& list_i = nbrIds[pId];
133 for (int c = 0; c < 27; c++)
135 int cId = hash.getIndex(gId3.x + offset_nq[c][0], gId3.y + offset_nq[c][1], gId3.z + offset_nq[c][2]);
137 int totalNum = hash.getCounter(cId);
138 for (int i = 0; i < totalNum; i++) {
139 int nbId = hash.getParticleId(cId, i);
140 Real d_ij = (pos_ijk - position[nbId]).norm();
// Builds variable-length neighbor lists with a uniform hash grid.
//  1. Query set: inOther() if non-empty, else inPosition(). Radius h comes
//     from inRadius().
//  2. Grid extent: the bounding box of inPosition(), clamped to the scene
//     graph's bounds and padded by h on each side. setSpace also receives h
//     as its first (spacing) argument.
//  3. The hash grid is built over inPosition().
//  4. A counting pass runs over other.size() query points, nbrIds is resized
//     from the counts, and a second pass fills the lists.
// NOTE(review): here list i belongs to point i of `other`, and its ids index
// `points`. The BVH and octree paths appear to build over `other` and
// allocate one list per point of inPosition(). With a non-empty inOther()
// the two conventions disagree -- confirm which one callers expect.
151 template<typename TDataType>
152 void NeighborPointQuery<TDataType>::requestDynamicNeighborIds()
155 auto& points = this->inPosition()->constData();
156 auto& other = this->inOther()->isEmpty() ? this->inPosition()->constData() : this->inOther()->constData();
157 auto h = this->inRadius()->getValue();
160 if (this->outNeighborIds()->isEmpty())
161 this->outNeighborIds()->allocate();
163 auto& nbrIds = this->outNeighborIds()->getData();
165 // Construct hash grid
166 Reduction<Coord> reduce;
167 Coord hiBound = reduce.maximum(points.begin(), points.size());
168 Coord loBound = reduce.minimum(points.begin(), points.size());
170 // To avoid particles running out of the simulation domain
171 auto scn = this->getSceneGraph();
174 auto loLimit = scn->getLowerBound();
175 auto hiLimit = scn->getUpperBound();
177 hiBound = hiBound.minimum(hiLimit);
178 loBound = loBound.maximum(loLimit);
181 GridHash<TDataType> hashGrid;
182 hashGrid.setSpace(h, loBound - Coord(h), hiBound + Coord(h));
184 hashGrid.construct(points);
// One count per query point; the kernel arguments are on lines missing from this extract.
186 DArray<uint> counter(other.size());
187 cuExecute(other.size(),
195 nbrIds.resize(counter);
197 cuExecute(other.size(),
198 K_GetNeighborElements,
// Device-side swap of two values through a temporary copy.
210 template <typename T> __device__ void inline swap_on_device(T& a, T& b) {
211 T c(a); a = b; b = c;
// Sift-up for a max-heap ordered by `vals` (the largest value ends up at
// index 0). `keys` is swapped in lockstep so each id stays paired with its value.
// child: index of the element just written at the end of the heap.
// NOTE(review): the loop header and the `child = parent` update are on
// lines missing from this extract.
214 template <typename T>
215 __device__ void heapify_up(int* keys, T* vals, int child)
217 int parent = (child - 1) / 2;
// Max-heap property: a child larger than its parent moves up.
220 if (vals[child] > vals[parent])
222 swap_on_device(vals[child], vals[parent]);
223 swap_on_device(keys[child], keys[parent]);
226 parent = (child - 1) / 2;
// Sift-down for the same max-heap layout, operating on the first `size`
// entries. At each step the current index j is compared with its children
// 2j+1 and 2j+2; only children with index < size are considered. The element
// is swapped with the larger child until j is already the largest, then the
// function returns. `keys` is permuted in lockstep with `vals`.
// NOTE(review): the initialisation of j (presumably from `node`) and of
// `largest` (presumably j), and the loop header, are on lines missing from
// this extract.
235 template <typename T>
236 __device__ void heapify_down(int* keys, T* vals, int node, int size) {
239 int left = 2 * j + 1;
240 int right = 2 * j + 2;
242 if (left<size && vals[left]>vals[largest]) {
245 if (right<size && vals[right]>vals[largest]) {
248 if (largest == j) return;
249 swap_on_device(vals[j], vals[largest]);
250 swap_on_device(keys[j], keys[largest]);
// In-place heap sort of a max-heap holding `size` entries. Each step moves
// the root (the current maximum) to the end of the active range, shrinks the
// range by one and sifts the new root down. `keys` follows `vals`.
// The enclosing loop is on lines missing from this extract. Assuming it runs
// until one element remains, `vals` ends in ascending order.
255 template <typename T>
256 __device__ void heap_sort(int* keys, T* vals, int size) {
258 swap_on_device(vals[0], vals[size - 1]);
259 swap_on_device(keys[0], keys[size - 1]);
260 heapify_down(keys, vals, 0, --size);
// Kernel computing neighbor lists capped at sizeLimit entries, one thread per
// query point. pId indexes position_new, and threads with
// pId >= position_new.size() return immediately.
// Each thread owns a private slice of length sizeLimit in heapIDs and
// heapDistance, starting at pId * sizeLimit. The slice is used as a bounded
// max-heap on distance:
//  - while fewer than sizeLimit candidates are stored, each new one is
//    appended and sifted up;
//  - after that, a candidate replaces the root only when it is closer than
//    the farthest neighbor kept so far.
// Finally the heap is sorted (ascending distance) and the ids are inserted
// into nbrIds[pId].
// NOTE(review): the radius/sizeLimit/heapIDs parameters, the radius test,
// the counter init/increment and the cell-validity check are on lines
// missing from this extract.
264 template<typename Real, typename Coord, typename TDataType>
265 __global__ void K_ComputeNeighborFixed(
266 DArrayList<int> nbrIds,
267 DArray<Coord> position_new,
268 DArray<Coord> position,
269 GridHash<TDataType> hash,
273 DArray<Real> heapDistance)
275 int pId = threadIdx.x + (blockIdx.x * blockDim.x);
276 if (pId >= position_new.size()) return;
278 //TODO: use shared memory for speedup
279 int* ids(heapIDs.begin() + pId * sizeLimit);// = new int[nbrLimit];
280 Real* distance(heapDistance.begin() + pId * sizeLimit);// = new Real[nbrLimit];
282 for (int i = 0; i < sizeLimit; i++) {
284 distance[i] = REAL_MAX;
287 Coord pos_ijk = position_new[pId];
288 int3 gId3 = hash.getIndex3(pos_ijk);
291 for (int c = 0; c < 27; c++)
293 int cId = hash.getIndex(gId3.x + offset_nq[c][0], gId3.y + offset_nq[c][1], gId3.z + offset_nq[c][2]);
295 int totalNum = hash.getCounter(cId);// min(hash.getCounter(cId), hash.npMax);
296 for (int i = 0; i < totalNum; i++) {
297 int nbId = hash.getParticleId(cId, i);
// Keep the distance in Real, not float: it is stored into and compared
// against the Real heap slice, matching the other hash-grid kernels.
298 Real d_ij = (pos_ijk - position[nbId]).norm();
// Heap not yet full: append and restore the max-heap property upward.
301 if (counter < sizeLimit)
304 distance[counter] = d_ij;
306 heapify_up(ids, distance, counter);
// Heap full: replace the farthest kept neighbor (root) only if this one is closer.
311 if (d_ij < distance[0])
316 heapify_down(ids, distance, 0, counter);
325 List<int>& list_i = nbrIds[pId];
// Sort the kept neighbors by ascending distance before emitting their ids.
327 heap_sort(ids, distance, counter);
328 for (int bId = 0; bId < counter; bId++)
330 list_i.insert(ids[bId]);
// Builds neighbor lists capped at SizeLimit entries per query point, using a
// uniform hash grid built over inPosition(). The query set is inOther() when
// non-empty, otherwise inPosition(). Grid bounds come from the bounding box of
// inPosition(), clamped to the scene graph's bounds and padded by h.
// Per query point, a SizeLimit-long scratch slice of ids/distances serves as
// the bounded heap for K_ComputeNeighborFixed.
334 template<typename TDataType>
335 void NeighborPointQuery<TDataType>::requestFixedSizeNeighborIds()
338 auto& points = this->inPosition()->constData();
339 auto& other = this->inOther()->isEmpty() ? this->inPosition()->constData() : this->inOther()->constData();
340 auto h = this->inRadius()->getValue();
343 if (this->outNeighborIds()->isEmpty())
344 this->outNeighborIds()->allocate();
346 auto& nbrIds = this->outNeighborIds()->getData();
// Size the output and scratch buffers by the QUERY set. K_ComputeNeighborFixed
// writes nbrIds[pId] and a sizeLimit-long heap slice for every
// pId < position_new.size(). As in requestDynamicNeighborIds (which launches
// over other.size()), the query set is `other`: hash ids index `points`.
// Sizing by inPosition() overran these buffers whenever inOther() held more
// points than inPosition().
348 uint numPt = other.size();
349 uint sizeLimit = this->varSizeLimit()->getValue();
351 nbrIds.resize(numPt, sizeLimit);
353 // Construct hash grid
354 Reduction<Coord> reduce;
355 Coord hiBound = reduce.maximum(points.begin(), points.size());
356 Coord loBound = reduce.minimum(points.begin(), points.size());
358 // To avoid particles running out of the simulation domain
359 auto scn = this->getSceneGraph();
362 auto loLimit = scn->getLowerBound();
363 auto hiLimit = scn->getUpperBound();
365 hiBound = hiBound.minimum(hiLimit);
366 loBound = loBound.maximum(loLimit);
369 GridHash<TDataType> hashGrid;
370 hashGrid.setSpace(h, loBound - Coord(h), hiBound + Coord(h));
372 hashGrid.construct(points);
// Per-query-point heap scratch: sizeLimit ids and distances each.
374 DArray<int> ids(numPt * sizeLimit);
375 DArray<Real> distance(numPt * sizeLimit);
377 K_ComputeNeighborFixed,
// Kernel, one thread per point of `position` (threads with
// pId >= position.size() return). It builds an axis-aligned box for
// position[pId] and stores it in boundingBox[pId].
// NOTE(review): the box construction lines, and the radius parameter they
// presumably use, are missing from this extract.
393 template<typename Real, typename Coord>
394 __global__ void NPQ_SetupAABB(
395 DArray<AABB> boundingBox,
396 DArray<Coord> position,
399 int pId = threadIdx.x + (blockIdx.x * blockDim.x);
400 if (pId >= position.size()) return;
403 Coord p = position[pId];
407 boundingBox[pId] = box;
// Kernel, one thread per point of `position` (threads with
// tId >= position.size() return). It queries the BVH with the tiny box
// [p - EPSILON, p + EPSILON] around the point and stores the result of
// bvh.requestIntersectionNumber in counter[tId].
// In effect this is a point-in-box test against the boxes the BVH was built
// from. No distance check is visible in this kernel.
410 template<typename Coord, typename TDataType>
411 __global__ void NPQ_RequestNeighborNumberBVH(
412 DArray<uint> counter,
413 DArray<Coord> position,
414 LinearBVH<TDataType> bvh)
416 int tId = threadIdx.x + (blockIdx.x * blockDim.x);
417 if (tId >= position.size()) return;
419 Coord p = position[tId];
421 typename LinearBVH<TDataType>::AABB aabb;
422 aabb.v0 = p - EPSILON;
423 aabb.v1 = p + EPSILON;
425 counter[tId] = bvh.requestIntersectionNumber(aabb);
428 //TODO: sort ids according to their distance to the center
// Kernel, one thread per point of `position` (threads with
// tId >= position.size() return). It uses the same EPSILON box as
// NPQ_RequestNeighborNumberBVH and asks the BVH to write the intersecting ids
// into idLists[tId], which is expected to be pre-sized from that kernel's counts.
429 template<typename Coord, typename TDataType>
430 __global__ void NPQ_RequestNeighborIdsBVH(
431 DArrayList<int> idLists,
432 DArray<Coord> position,
433 LinearBVH<TDataType> bvh)
435 int tId = threadIdx.x + (blockIdx.x * blockDim.x);
436 if (tId >= position.size()) return;
438 Coord p = position[tId];
440 typename LinearBVH<TDataType>::AABB aabb;
441 aabb.v0 = p - EPSILON;
442 aabb.v1 = p + EPSILON;
444 bvh.requestIntersectionIds(idLists[tId], aabb);
// Builds neighbor lists through a LinearBVH:
//  - one AABB per point of the target set `other` (numTar boxes). The
//    box-setup launch is on lines missing from this extract; presumably it is
//    NPQ_SetupAABB with radius h;
//  - the BVH is constructed over those boxes;
//  - one count per source point (numSrc = inPosition().size()): count,
//    resize neighborLists, then fill.
// NOTE(review): lists here appear to be per point of inPosition() with ids
// into `other`, the reverse of the hash-grid paths. There is also no visible
// distance filter, so ids are box hits rather than exact radius neighbors --
// confirm both against the full file.
447 template<typename TDataType>
448 void NeighborPointQuery<TDataType>::requestNeighborIdsWithBVH()
451 auto& points = this->inPosition()->constData();
452 auto& other = this->inOther()->isEmpty() ? this->inPosition()->constData() : this->inOther()->constData();
453 auto h = this->inRadius()->getValue();
455 uint numSrc = points.size();
456 uint numTar = other.size();
458 if (this->outNeighborIds()->isEmpty()) {
459 this->outNeighborIds()->allocate();
462 auto& neighborLists = this->outNeighborIds()->getData();
464 DArray<AABB> aabbs(numTar);
472 LinearBVH<TDataType> bvh;
473 bvh.construct(aabbs);
475 DArray<uint> counter(numSrc);
478 NPQ_RequestNeighborNumberBVH,
483 neighborLists.resize(counter);
486 NPQ_RequestNeighborIdsBVH,
// Kernel, one thread per entry of `count` (threads with tId >= count.size()
// return). It builds the box [p - radius, p + radius] around points[tId] and
// stores the result of octree.requestIntersectionNumberFromBottom in count[tId].
// NOTE(review): the `count` and `radius` parameters and the `aabb`
// declaration are on lines missing from this extract.
497 template<typename Coord, typename TDataType>
498 __global__ void CDBP_RequestIntersectionNumber(
500 DArray<Coord> points,
502 SparseOctree<TDataType> octree)
504 int tId = threadIdx.x + (blockIdx.x * blockDim.x);
505 if (tId >= count.size()) return;
507 Coord p = points[tId];
510 aabb.v0 = p - radius;
511 aabb.v1 = p + radius;
513 count[tId] = octree.requestIntersectionNumberFromBottom(aabb);
// Kernel, one thread per entry of `count` (threads with tId >= count.size()
// return). `count` must already hold exclusive offsets into the flat `ids` array.
// The thread writes the octree ids intersecting [p - radius, p + radius]
// around points[tId] into ids, starting at offset count[tId]. It then copies
// those n ids into lists[tId], where n is the gap to the next offset, or
// ids.size() minus the last offset for the final thread.
// NOTE(review): the count/ids/radius parameters and the aabb declaration are
// on lines missing from this extract. The method name spelling
// `reqeust...` must match SparseOctree's declaration, which is not visible here.
516 template<typename Coord, typename TDataType>
517 __global__ void CDBP_RequestIntersectionIds(
518 DArrayList<int> lists,
521 DArray<Coord> points,
523 SparseOctree<TDataType> octree)
525 int tId = threadIdx.x + (blockIdx.x * blockDim.x);
526 if (tId >= count.size()) return;
528 Coord p = points[tId];
529 int total_num = count.size();
532 aabb.v0 = p - radius;
533 aabb.v1 = p + radius;
535 octree.reqeustIntersectionIdsFromBottom(ids.begin() + count[tId], aabb);
// Number of ids this thread owns: distance to the next offset, or to the end of `ids` for the last entry.
537 int n = tId == total_num - 1 ? ids.size() - count[total_num - 1] : count[tId + 1] - count[tId];
539 List<int>& list_i = lists[tId];
541 for (int t = 0; t < n; t++)
543 list_i.insert(ids[count[tId] + t]);
// Kernel, one thread per entry of `counter` (threads with
// tId >= counter.size() return). For srcPoints[tId] it walks the candidate
// ids in lists[tId] (octree hits) and computes the distance r to each
// tarPoints[j].
// NOTE(review): the radius parameter, the lookup of j from list_i, the
// distance test and the store into counter[tId] are on lines missing from
// this extract. Presumably only candidates within the radius are counted -- confirm.
547 template<typename Real, typename Coord>
548 __global__ void CDBP_RequestNeighborSize(
549 DArray<uint> counter,
550 DArray<Coord> srcPoints,
551 DArray<Coord> tarPoints,
552 DArrayList<int> lists,
555 int tId = threadIdx.x + (blockIdx.x * blockDim.x);
556 if (tId >= counter.size()) return;
558 Coord p_i = srcPoints[tId];
560 List<int>& list_i = lists[tId];
561 int nbSize = list_i.size();
563 for (int ne = 0; ne < nbSize; ne++)
566 Real r = (p_i - tarPoints[j]).norm();
575 //TODO: sort ids according to their distance to the center
// Kernel, one thread per entry of `neighbors` (threads with
// tId >= neighbors.size() return). It uses the same candidate walk as
// CDBP_RequestNeighborSize over lists[tId] for srcPoints[tId], and writes
// accepted ids into the pre-sized list neighbors[tId].
// NOTE(review): the radius parameter, the lookup of j, the distance test and
// the neList_i.insert(...) call are on lines missing from this extract.
576 template<typename Real, typename Coord>
577 __global__ void CDBP_RequestNeighborIds(
578 DArrayList<int> neighbors,
579 DArray<Coord> srcPoints,
580 DArray<Coord> tarPoints,
581 DArrayList<int> lists,
584 int tId = threadIdx.x + (blockIdx.x * blockDim.x);
585 if (tId >= neighbors.size()) return;
587 Coord p_i = srcPoints[tId];
589 List<int>& list_i = lists[tId];
590 int nbSize = list_i.size();
592 List<int>& neList_i = neighbors[tId];
594 for (int ne = 0; ne < nbSize; ne++)
597 Real r = (p_i - tarPoints[j]).norm();
// Builds neighbor lists through a sparse octree.
//  1. The octree is built over `other`. Its origin is other's min corner,
//     the second setSpace argument is h, and its extent is the largest side
//     of other's bounding box.
//  2. One count per source point (inPosition()) of octree box intersections.
//  3. `lists` is resized from those counts. total_num is their sum, and the
//     counts are then turned in place into exclusive offsets.
//  4. Intersecting ids are gathered into the flat `ids` array and copied
//     into `lists`.
//  5. Candidates are filtered by distance: per-source counts go into
//     neighbor_counter, neighborLists is resized from them, then filled.
// NOTE(review): release of counter/lists/ids/octree is not visible in this
// extract; only neighbor_counter.clear() is.
606 template<typename TDataType>
607 void NeighborPointQuery<TDataType>::requestNeighborIdsWithOctree()
610 auto& points = this->inPosition()->constData();
611 auto& other = this->inOther()->isEmpty() ? this->inPosition()->constData() : this->inOther()->constData();
612 auto h = this->inRadius()->getValue();
614 uint numSrc = points.size();
615 uint numTar = other.size();
617 if (this->outNeighborIds()->isEmpty()) {
618 this->outNeighborIds()->allocate();
621 auto& neighborLists = this->outNeighborIds()->getData();
623 Reduction<Coord> m_reduce_coord;
624 auto min_v0 = m_reduce_coord.minimum(other.begin(), other.size());
625 auto max_v1 = m_reduce_coord.maximum(other.begin(), other.size());
627 SparseOctree<TDataType> octree;
// The octree spans a cube anchored at min_v0 whose side is the largest bounding-box extent of `other`.
628 octree.setSpace(min_v0, h, maximum(max_v1[0] - min_v0[0], maximum(max_v1[1] - min_v0[1], max_v1[2] - min_v0[2])));
629 octree.construct(other, 0);
631 DArray<uint> counter(numSrc);
634 CDBP_RequestIntersectionNumber,
640 DArrayList<int> lists;
641 lists.resize(counter);
// Total candidate count, then convert per-point counts to exclusive offsets into `ids`.
643 Reduction<uint> reduce;
644 uint total_num = reduce.accumulate(counter.begin(), counter.size());
647 scan.exclusive(counter.begin(), counter.size());
649 DArray<int> ids(total_num);
651 CDBP_RequestIntersectionIds,
659 DArray<uint> neighbor_counter(numSrc);
661 CDBP_RequestNeighborSize,
668 neighborLists.resize(neighbor_counter);
671 CDBP_RequestNeighborIds,
681 neighbor_counter.clear();
// Project macro whose definition is not visible here; presumably it
// explicitly instantiates NeighborPointQuery for the supported data types -- confirm.
685 DEFINE_CLASS(NeighborPointQuery);