Skip to content

Commit 2dc6d20

Browse files
author
PhySimdev
committed
docs: finished docstrings
1 parent f2af200 commit 2dc6d20

File tree

3 files changed

+116
-28
lines changed

3 files changed

+116
-28
lines changed

BarnesHut.py

Lines changed: 102 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,17 @@
66

77

88
class Node():
9+
"""Stores data for Octree nodes."""
10+
911
def __init__(self, middle, dimension):
12+
"""Method that sets up a node.
13+
14+
Method that sets up and declares variables for an octree node.
15+
16+
Args:
17+
middle: Position of center of the position.
18+
dimension: Length of sides of a node.
19+
"""
1020
self.particle = None
1121
self.middle = middle
1222
self.dimension = dimension
@@ -26,6 +36,15 @@ def get_subnode(self, quad):
2636
return (self.subnodes[quad[0]][quad[1]][quad[2]])
2737

2838
def create_subnode(self, quad):
39+
"""Method that creates a subnode.
40+
41+
Method that determines the middle and dimension of the subnode
42+
of a specific quadrant of the node. Initializes a subnode and adds
43+
that subnode to the nodes.
44+
45+
Args:
46+
quad: Quadrant of node.
47+
"""
2948
dimension = self.dimension / 2
3049

3150
x, y, z = 1, 1, 1
@@ -36,9 +55,9 @@ def create_subnode(self, quad):
3655
if quad[2] == 1:
3756
z = -1
3857

39-
middle = [self.middle[0] + ((dimension / 2) * x), # quad[0] == 0: value 1, right
40-
self.middle[1] + ((dimension / 2) * y), # quad[1] == 0: value 1, front
41-
self.middle[2] + ((dimension / 2) * z)] # quad[2] == 0: value 1, top
58+
middle = [self.middle[0] + ((dimension / 2) * x), # value 1, right
59+
self.middle[1] + ((dimension / 2) * y), # value 1, front
60+
self.middle[2] + ((dimension / 2) * z)] # value 1, top
4261
node = Node(middle, dimension)
4362
self.subnodes[quad[0]][quad[1]][quad[2]] = node
4463
self.nodes.append(node)
@@ -54,6 +73,11 @@ def get_quad(self, point):
5473
return [x, y, z]
5574

5675
def get_corners(self):
76+
"""Method that gets corners of a node.
77+
78+
Method that get corners of a node for visualization. Iterates through
79+
the top and bottom for front and back for right and left.
80+
"""
5781
if self.corners is None:
5882
self.corners = []
5983
for x in [1, -1]: # right or left
@@ -67,13 +91,22 @@ def get_corners(self):
6791

6892
def in_bounds(self, point):
6993
val = False
70-
if point[0] <= self.middle[0] + (self.dimension / 2) and point[0] >= self.middle[0] - (self.dimension / 2) and\
71-
point[1] <= self.middle[1] + (self.dimension / 2) and point[1] >= self.middle[1] - (self.dimension / 2) and\
72-
point[2] <= self.middle[2] + (self.dimension / 2) and point[2] >= self.middle[2] - (self.dimension / 2):
94+
if point[0] <= self.middle[0] + (self.dimension / 2) and\
95+
point[0] >= self.middle[0] - (self.dimension / 2) and\
96+
point[1] <= self.middle[1] + (self.dimension / 2) and\
97+
point[1] >= self.middle[1] - (self.dimension / 2) and\
98+
point[2] <= self.middle[2] + (self.dimension / 2) and\
99+
point[2] >= self.middle[2] - (self.dimension / 2):
73100
val = True
74101
return val
75102

76103
def compute_mass_distribution(self):
104+
"""Method that calculates the mass distribution.
105+
106+
Method that calculates the mass distribution of the node based on
107+
the mass posistions of the subnode weighted by weights of
108+
the subnodes.
109+
"""
77110
if self.particle is not None:
78111
self.center_of_mass = np.array([*self.particle.position])
79112
self.mass = self.particle.mass
@@ -91,41 +124,61 @@ def compute_mass_distribution(self):
91124

92125

93126
class Octree():
94-
def __init__(self, particles, root_node, theta):
127+
"""Handles setup and calculations of the Barnes-Hut octree."""
128+
129+
def __init__(self, particles, root_node, theta, node_type):
130+
"""Method that sets up an octree.
131+
132+
Method that sets up the variables for the octree. Calls functions
133+
for creation of the octree.
134+
135+
Args:
136+
particles: List of particles that are inserted.
137+
root_node: Root node of the octree.
138+
theta: Theta that determines the accuracy of the simulations.
139+
"""
95140
self.theta = theta
96141
self.root_node = root_node
97142
self.particles = particles
143+
self.node_type = node_type
98144
for particle in self.particles:
99145
self.insert_to_node(self.root_node, particle)
100146

101147
def insert_to_node(self, node, particle):
148+
"""Recursive method that inserts particles into the octree.
149+
150+
Recursive method that inserts bodies into the octree.
151+
Checks if particle is in the current node to prevent bounds issues.
152+
Determines the appropriate child node and gets that subnode.
153+
If that subnode is empty insert the particle and stop.
154+
If the child node point is a point node (one particle) turn it into
155+
a regional node by inserting both particles into it.
156+
If the child node is a regional node insert the particle.
157+
158+
Args:
159+
node: Quadrant of node.
160+
particle: Simulation body.
161+
"""
102162
# check if point is in cuboid of present node
103163
if not node.in_bounds(particle.position) and not np.array_equal(particle.position, self.root_node.middle):
104164
print("error particle not in bounds")
105165
print(f"middle: {node.middle}, dimension: {node.dimension}, particle position: {particle.position}, type: {type(particle)}")
106166
return
107167

108-
# determine the appropriate child node
109168
quad = node.get_quad(particle.position)
110169
if node.get_subnode(quad) is None:
111170
node.create_subnode(quad)
112171
subnode = node.get_subnode(quad)
113172

114-
# if subnode is empty, insert point, stop insertion
115173
if subnode.particle is None and len(subnode.nodes) == 0: # case empty node
116174
subnode.insert_particle(particle)
117175

118-
# If the child node is a point node, replace it with a region node.
119-
# Call insert for the point that just got replaced.
120-
# Set current node as the newly formed regionnode.
121176
elif subnode.particle is not None: # case point node
122177
old_particle = subnode.particle
123178
subnode.insert_particle(None)
124179
self.insert_to_node(subnode, old_particle)
125180
self.insert_to_node(subnode, particle)
126181

127-
# If the child node is a point node, replace it with a region node.
128-
# Call insert for the point that just got replaced. Set current node as the newly formed region node.
129182
elif subnode.particle is None and len(subnode.nodes) >= 1: # case region node
130183
self.insert_to_node(subnode, particle)
131184

@@ -143,15 +196,28 @@ def update_forces_collisions(self):
143196
del self.collision_dic[particle]
144197

145198
def calc_forces(self, node, particle):
146-
# Gravitational force and collision between two particles
199+
"""Method that calculates the force on an octree particle.
200+
201+
Method that calculates the force on an octree particle by iterating
202+
through the octree.
203+
If the node is a point node that doesnt hold the body itelf, directly
204+
calculate the forces.
205+
If its a regional node and the dimension/distance ratio is smaller
206+
than theta and the center of mass position is not the same as the
207+
particle position, calculate the force between the node and
208+
the particle.
209+
210+
Args:
211+
node: Quadrant of node.
212+
particle: Simulation body.
213+
"""
147214
if node.particle is not None and node.particle != particle:
148215
force, e_pot, distance = self.gravitational_force(particle, node, np.array([]), np.array([]))
149216
particle.force -= force
150217
particle.e_pot -= e_pot
151218
if distance < particle.radius + node.particle.radius and particle.mass > node.particle.mass:
152219
self.collision_dic[particle].append(node.particle)
153220

154-
# find regional node were particle is not the center of mass (subnodes particle is particle)
155221
elif node.particle is None and not np.array_equal(particle.position, node.center_of_mass):
156222
distance = np.array([*particle.position]) - np.array([*node.center_of_mass])
157223
r = math.sqrt(np.dot(distance, distance))
@@ -165,6 +231,22 @@ def calc_forces(self, node, particle):
165231
self.calc_forces(subnode, particle)
166232

167233
def gravitational_force(self, particle, node, distance_vec, distance_val): # can be ragional or point node
234+
"""Method that calculates the force between two particles.
235+
236+
Method that calculates the force acted on the particle by
237+
another particle or a node. Only calculates the distance and vector
238+
did not have to be calculated for theta.
239+
240+
Args:
241+
particle: Simulation body.
242+
node: Node of the octree.
243+
distance_vec: Vector of distance between the bodies.
244+
distance_val: Magnitude of distance betweent the bodies
245+
246+
Returns:
247+
The force, potential energy and distance between two bodies or
248+
the body and the node.
249+
"""
168250
force = np.array([0., 0., 0.])
169251
if len(distance_vec) == 0 and len(distance_val) == 0:
170252
distance = np.array([*particle.position]) - np.array([*node.center_of_mass])
@@ -179,14 +261,15 @@ def gravitational_force(self, particle, node, distance_vec, distance_val): # ca
179261
force = (distance / distance_mag) * force_mag
180262
return force, e_pot, distance_mag
181263

182-
def get_all_nodes(self, node, lst): # all point nodes, could include regional nodes
264+
def get_all_nodes(self, node, lst):
183265

184266
if node.particle is None and len(node.nodes) >= 1 or node.particle is not None:
185267
if len(node.nodes) >= 1:
186-
# lst.append(node.get_corners()) # if regional are shown aswell
268+
if self.node_type == "regional" or self.node_type == "both":
269+
lst.append(node.get_corners())
187270
for subnode in node.nodes:
188271
self.get_all_nodes(subnode, lst)
189-
if node.particle is not None:
272+
if node.particle is not None and (self.node_type == "point" or self.node_type == "both"):
190273
lst.append(node.get_corners())
191274

192275

Simulation.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self, theta=1, rc=0, absolute_pos=True, focus_index=0):
2121
rc: Restitution coefficient for collisions.
2222
absolute_pos: Bool value to determine type of movement.
2323
focus_index: Index of the list focus_options form 0 to 2.
24+
node_type: String that determines what nodes are displayed.
2425
"""
2526
self.restitution_coefficient = rc
2627
self.focus_options = ["none", "body", "cm"]
@@ -36,7 +37,7 @@ def __init__(self, theta=1, rc=0, absolute_pos=True, focus_index=0):
3637
self.total_e = 0
3738
self.cm_pos = np.array([0, 0, 0])
3839
self.cm_velo = None
39-
40+
4041
if focus_index >= 0 and focus_index < len(self.focus_options):
4142
self.focus_index = focus_index
4243
else:
@@ -100,7 +101,7 @@ def clear_trail(self):
100101
for body in self.bodies:
101102
body.trail = []
102103

103-
def calculate(self, timestep, draw_box):
104+
def calculate(self, timestep, draw_box, node_type):
104105
"""Method that calculates a simulation physics step.
105106
106107
Method that calls functions for physics calculations.
@@ -115,7 +116,7 @@ def calculate(self, timestep, draw_box):
115116
"""
116117
if self.first:
117118
self.first = False
118-
self.update_interactions()
119+
self.update_interactions(node_type)
119120
for body in self.bodies:
120121
body.acceleration = body.force / body.mass
121122

@@ -124,7 +125,7 @@ def calculate(self, timestep, draw_box):
124125
body.update_velocity(timestep)
125126
body.update_position(timestep)
126127

127-
self.update_interactions()
128+
self.update_interactions(node_type)
128129

129130
self.tree_nodes = []
130131
if draw_box:
@@ -185,7 +186,7 @@ def get_data(self):
185186

186187
return body_data, self.tree_nodes, system_data, self.cm_pos - default_pos
187188

188-
def update_interactions(self):
189+
def update_interactions(self, node_type):
189190
center = self.get_focus_pos()
190191

191192
largest_val = 0
@@ -200,7 +201,7 @@ def update_interactions(self):
200201

201202
dimension = math.sqrt(((furthest_body.position[largest_index] - center[largest_index]) * 2.5)**2)
202203
root = Node(center, dimension)
203-
self.tree = Octree(self.bodies, root, self.theta)
204+
self.tree = Octree(self.bodies, root, self.theta, node_type)
204205
root.compute_mass_distribution()
205206
self.tree.update_forces_collisions()
206207
self.compute_collisions(self.tree.collision_dic)

main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self):
6767
self.planet_size_factor = 80
6868
self.min_body_size = .5
6969
self.path_planet_color = True
70-
self.draw_box = False
70+
self.draw_box = True
7171
self.draw_trail = True
7272
self.min_trail_size = .5
7373
self.trail_length = 100000
@@ -84,6 +84,10 @@ def __init__(self):
8484
self.max_frame = 3600
8585

8686
self.draw_rot = True
87+
self.node_list_index= 1
88+
self.node_list = ["regional", "point", "both"]
89+
self.node_type = self.node_list[self.node_list_index]
90+
self.grid_thickness = 2 * 10**9
8791
self.rot_cube_pos = [.8, -.8] # zwischen -1,1, scale of screen
8892
self.rot_cube_scale = 80
8993
self.rot_cube_scolor = "grey"
@@ -484,7 +488,7 @@ def draw_cube(self, qube):
484488
pos2 = qube[line[1]][0]
485489
if pos1 and pos2 is not None:
486490
f = (qube[line[0]][1] + qube[line[1]][1]) / 2
487-
self.pointer.pensize(self.path_size * f / 2)
491+
self.pointer.pensize(self.grid_thickness * f / 2)
488492
self.pointer.goto(pos1)
489493
self.pointer.down()
490494
self.pointer.goto(pos2)
@@ -728,7 +732,7 @@ def update_program(self):
728732
self.frame_count += 1
729733
t1 = time.time()
730734
for system in self.simulations:
731-
system.calculate(self.timestep, self.draw_box)
735+
system.calculate(self.timestep, self.draw_box, self.node_type)
732736
t2 = time.time()
733737
self.physics_time = t2 - t1
734738

0 commit comments

Comments
 (0)