// Table of 27 integer triples placed in constant memory; the first entry is (0, 0, 0).
// NOTE(review): presumably the 3x3x3 neighbour-cell offsets (one (dx, dy, dz) per row)
// used when visiting the cells around a grid cell -- confirm against the code that indexes it.
8 __constant__ int offset[27][3] = { 0, 0, 0,
// Default constructor of GridHash<TDataType>.
// NOTE(review): release() and construct() test counter/index/m_scan/m_reduce against
// nullptr, so these members are expected to start out as nullptr -- confirm they are
// initialised that way (here or in the class declaration).
37 template<typename TDataType>
38 GridHash<TDataType>::GridHash()
// Destructor of GridHash<TDataType>.
// NOTE(review): kernels in this file receive GridHash objects by value, so copies are
// destroyed on the host after every launch; device buffers must therefore not be freed
// here unless ownership is tracked -- confirm that freeing is left to release().
42 template<typename TDataType>
43 GridHash<TDataType>::~GridHash()
// Defines the region covered by the grid and allocates the per-cell device buffers.
//
// _h  : NOTE(review): presumably the cell spacing that ds is derived from -- confirm.
// _lo : lower corner of the region that must be covered.
// _hi : upper corner of the region that must be covered.
//
// Writes lo, hi, nx, ny, nz, counter, index and m_reduce.
47 template<typename TDataType>
48 void GridHash<TDataType>::setSpace(Real _h, Coord _lo, Coord _hi)
// Move the lower corner outwards by padding * ds so the grid starts before _lo.
54 lo = _lo - padding * ds;
// Extent of the requested region measured in units of ds (fractional per axis).
56 Coord nSeg = (_hi - _lo) / ds;
// Round each axis up so that a partially covered cell still counts as a whole cell.
58 nx = (int)ceil(nSeg[0]);
59 ny = (int)ceil(nSeg[1]);
60 nz = (int)ceil(nSeg[2]);
62 //To avoid values less or equal to zero
// Recompute the upper corner from lo and the cell counts, so that the stored extent
// is exactly nx x ny x nz cells of size ds.
70 hi = lo + Coord((Real)nx, (Real)ny, (Real)nz) * ds;
// One int per cell (num entries) for each of the two per-cell arrays.
// NOTE(review): presumably num == nx * ny * nz -- confirm. Also confirm that buffers
// from an earlier setSpace() call are freed before these allocations; otherwise a
// second call leaks the previous counter/index arrays.
76 cuSafeCall(cudaMalloc((void**)&counter, num * sizeof(int)));
77 cuSafeCall(cudaMalloc((void**)&index, num * sizeof(int)));
// The reduction helper is created for exactly num elements, so it is rebuilt here.
// NOTE(review): presumably an existing m_reduce is destroyed in this branch -- confirm.
79 if (m_reduce != nullptr)
84 m_reduce = Reduction<int>::Create(num);
// Counting pass: adds 1 to hash.index[cell] for the cell each position falls into.
//
// Launch layout: 1D grid of 1D blocks, one thread per element of pos (only the .x
// components of threadIdx/blockIdx/blockDim are used). Threads whose id is not below
// pos.size() return immediately, so the grid may be larger than the particle count.
//
// hash : grid description passed by value; its index pointer is written through.
// pos  : particle positions, read only.
//
// hash.index must hold zeros on entry for the result to be a per-cell particle count.
87 template<typename TDataType>
88 __global__ void K_CalculateParticleNumber(GridHash<TDataType> hash, const DArray<typename TDataType::Coord> pos)
// Flat global thread id == particle id.
90 int pId = threadIdx.x + (blockIdx.x * blockDim.x);
91 if (pId >= pos.size()) return;
// Cell that contains this particle.
// NOTE(review): confirm the value returned by getIndex() is validated before it is
// used as an array index (positions outside the grid).
93 int gId = hash.getIndex(pos[pId]);
// Several particles can share a cell, hence the atomic increment.
96 atomicAdd(&(hash.index[gId]), 1);
// Fill pass: writes every particle id into the slice of hash.ids that belongs to the
// particle's cell.
//
// Launch layout: 1D grid of 1D blocks, one thread per element of pos; threads whose id
// is not below pos.size() return immediately.
//
// hash : grid description passed by value; counter and ids are written, index is read.
// pos  : particle positions.
//
// Expects hash.index[cell] to be the start of the cell's slice in hash.ids and
// hash.counter[cell] to be zero on entry. The position of an id inside its slice is
// the order in which the atomic increments happen to be served.
99 template<typename TDataType>
100 __global__ void K_ConstructHashTable(GridHash<TDataType> hash, DArray<typename TDataType::Coord> pos)
// Flat global thread id == particle id.
102 int pId = threadIdx.x + (blockIdx.x * blockDim.x);
103 if (pId >= pos.size()) return;
// Cell that contains this particle.
// NOTE(review): confirm the value returned by getIndex() is validated before it is
// used as an array index (positions outside the grid).
105 int gId = hash.getIndex(pos[pId]);
// atomicAdd returns the previous counter value, which reserves a unique slot for this
// particle among all particles of the same cell.
109 int index = atomicAdd(&(hash.counter[gId]), 1);
// Slice start of the cell plus the reserved slot.
110 hash.ids[hash.index[gId] + index] = pId;
// Builds the cell -> particle-id table for the given positions:
//   1. count the particles per cell into index,
//   2. sum the counts to obtain particle_num,
//   3. turn the counts into slice start offsets with an exclusive scan,
//   4. (re)allocate ids and let every particle write its id into its cell's slice.
//
// pos : particle positions on the device.
//
// NOTE(review): both kernels rely on counter and index being zero on entry (see
// clear()) -- confirm they are reset before the first launch.
113 template<typename TDataType>
114 void GridHash<TDataType>::construct(const DArray<Coord>& pos)
// Number of blocks. Assuming pos.size() and BLOCK_SIZE are both integral, the division
// truncates first, so this evaluates to pos.size() / BLOCK_SIZE + 1: always enough
// blocks, with one surplus block when the size is an exact multiple. The surplus
// threads are harmless because both kernels bounds-check their particle id.
118 dim3 pDims = int(ceil(pos.size() / BLOCK_SIZE + 0.5f));
// Pass 1: per-cell particle counts accumulate in index.
// NOTE(review): the launch is not followed by an error check (cudaGetLastError)
// before the reduction below consumes its result.
120 K_CalculateParticleNumber << <pDims, BLOCK_SIZE >> > (*this, pos);
// Sum of all per-cell counts.
121 particle_num = m_reduce->accumulate(index, num);
// The scan helper is created on first use.
123 if (m_scan == nullptr)
125 m_scan = new Scan<int>();
// In-place exclusive scan: index[c] changes from "count in cell c" to "first slot of
// cell c in ids".
127 m_scan->exclusive(index, num);
// ids is resized to the current particle count on every call.
// NOTE(review): this is a device free + malloc per construct(); consider reusing the
// buffer when the particle count has not grown.
131 cuSafeCall(cudaFree(ids));
133 cuSafeCall(cudaMalloc((void**)&ids, particle_num * sizeof(int)));
135 // std::cout << "Particle number: " << particle_num << std::endl;
// Pass 2: scatter the particle ids into their cell slices.
137 K_ConstructHashTable << <pDims, BLOCK_SIZE >> > (*this, pos);
// Resets the per-cell arrays: every entry of counter and index (num ints each) is set
// to zero. Arrays that have not been allocated (nullptr) are skipped. ids is not touched.
141 template<typename TDataType>
142 void GridHash<TDataType>::clear()
144 if (counter != nullptr)
// A byte-wise fill with 0 yields integer zeros.
145 cuSafeCall(cudaMemset(counter, 0, num * sizeof(int)));
147 if (index != nullptr)
148 cuSafeCall(cudaMemset(index, 0, num * sizeof(int)));
// Frees the resources held by the grid: the device arrays counter, ids and index, and
// the m_scan / m_reduce helpers.
// NOTE(review): confirm each pointer is set back to nullptr after it is freed;
// otherwise a second release(), or a later clear()/construct(), would operate on a
// dangling pointer because all of them only test against nullptr.
151 template<typename TDataType>
152 void GridHash<TDataType>::release()
154 if (counter != nullptr)
155 cuSafeCall(cudaFree(counter));
// Particle-id table allocated by construct().
158 cuSafeCall(cudaFree(ids));
160 if (index != nullptr)
161 cuSafeCall(cudaFree(index));
// Helper objects; m_scan is created with new in construct() and m_reduce through
// Reduction<int>::Create() in setSpace().
// NOTE(review): presumably both are destroyed inside these branches -- confirm that
// the deallocation matches the way each one was created.
163 if (m_scan != nullptr)
166 if (m_reduce != nullptr)
// NOTE(review): the member functions above are templates defined in this source file,
// so the types that are used elsewhere need explicit instantiation; presumably this
// project macro emits those instantiations for GridHash -- confirm in its definition.
170 DEFINE_CLASS(GridHash);