Skip to content

Commit f951bfc

Browse files
authored
nbody: add zig implementation using vectorized sqrt (#474)
1 parent 0c1d472 commit f951bfc

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

bench/algorithm/nbody/3.zig

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
const std = @import("std");
2+
const math = std.math;
3+
4+
const solar_mass = 4.0 * math.pi * math.pi;
5+
const year = 365.24;
6+
7+
const vec3 = @Vector(3, f64);
8+
9+
fn scale(v: anytype, f: f64) @TypeOf(v) {
10+
return v * @as(@TypeOf(v), @splat(f));
11+
}
12+
13+
fn lengthSq(v: vec3) f64 {
14+
return @reduce(.Add, v * v);
15+
}
16+
17+
fn length(v: vec3) f64 {
18+
return @sqrt(lengthSq(v));
19+
}
20+
21+
const Body = struct {
22+
pos: vec3,
23+
vel: vec3,
24+
mass: f64,
25+
};
26+
27+
fn offsetMomentum(bodies: []Body) void {
28+
@setFloatMode(.optimized);
29+
var pos: vec3 = @splat(0);
30+
for (bodies[1..]) |b| pos += scale(@as(vec3, b.vel), b.mass);
31+
bodies[0].vel = -scale(pos, 1.0 / solar_mass);
32+
}
33+
34+
fn allPairs(comptime n: usize) [n * (n - 1)/2][2]u32 {
35+
var res: [n * (n - 1)/2][2]u32 = undefined;
36+
var k: usize = 0;
37+
for (0..n - 1) |i| for (i + 1..n) |j| {
38+
res[k] = .{@intCast(i), @intCast(j)};
39+
k += 1;
40+
};
41+
return res;
42+
}
43+
44+
fn advance(comptime n: usize, bodies: *[n]Body, dt: f64) void {
45+
@setFloatMode(.optimized);
46+
const pairs = comptime allPairs(n);
47+
var dp: [pairs.len]vec3 = undefined;
48+
var distSq: @Vector(pairs.len, f64) = undefined;
49+
inline for (pairs, 0..) |p, i| {
50+
const d = bodies[p[0]].pos - bodies[p[1]].pos;
51+
dp[i] = d;
52+
distSq[i] = lengthSq(dp[i]);
53+
}
54+
const mag = @as(@Vector(pairs.len, f64), @splat(dt)) / (distSq * @sqrt(distSq));
55+
56+
inline for (pairs, 0..) |p, i| {
57+
bodies[p[0]].vel -= scale(dp[i], bodies[p[1]].mass * mag[i]);
58+
bodies[p[1]].vel += scale(dp[i], bodies[p[0]].mass * mag[i]);
59+
}
60+
61+
inline for (bodies) |*body| body.pos += scale(body.vel, dt);
62+
}
63+
64+
fn energy(bodies: []const Body) f64 {
65+
@setFloatMode(.optimized);
66+
var e: f64 = 0.0;
67+
for (bodies, 0..) |bi, i| {
68+
e += 0.5 * lengthSq(bi.vel) * bi.mass;
69+
for (bodies[i + 1 ..]) |bj| {
70+
e -= bi.mass * bj.mass / length(bi.pos - bj.pos);
71+
}
72+
}
73+
return e;
74+
}
75+
76+
var solar_bodies = [_]Body{
77+
// Sun
78+
Body{
79+
.pos = @splat(0),
80+
.vel = @splat(0),
81+
.mass = solar_mass,
82+
},
83+
// Jupiter
84+
Body{
85+
.pos = .{ 4.84143144246472090, -1.16032004402742839, -0.103622044471123109 },
86+
.vel = scale(vec3{ 1.66007664274403694e-03, 7.69901118419740425e-03, -6.90460016972063023e-05 }, year),
87+
.mass = 9.54791938424326609e-04 * solar_mass,
88+
},
89+
// Saturn
90+
Body{
91+
.pos = .{ 8.34336671824457987, 4.12479856412430479, -0.403523417114321381 },
92+
.vel = scale(vec3{ -2.76742510726862411e-03, 4.99852801234917238e-03, 2.30417297573763929e-05 }, year),
93+
.mass = 2.85885980666130812e-04 * solar_mass,
94+
},
95+
// Uranus
96+
Body{
97+
.pos = .{ 12.8943695621391310, -15.1111514016986312, -0.223307578892655734 },
98+
.vel = scale(vec3{ 2.96460137564761618e-03, 2.37847173959480950e-03, -2.96589568540237556e-05 }, year),
99+
.mass = 4.36624404335156298e-05 * solar_mass,
100+
},
101+
// Neptune
102+
Body{
103+
.pos = .{ 15.3796971148509165, -25.9193146099879641, 0.179258772950371181 },
104+
.vel = scale(vec3{ 2.68067772490389322e-03, 1.62824170038242295e-03, -9.51592254519715870e-05 }, year),
105+
.mass = 5.15138902046611451e-05 * solar_mass,
106+
},
107+
};
108+
109+
pub fn main() !void {
110+
const steps = try getSteps();
111+
112+
offsetMomentum(&solar_bodies);
113+
const initial_energy = energy(&solar_bodies);
114+
for (0..steps) |_| advance(solar_bodies.len, &solar_bodies, 0.01);
115+
const final_energy = energy(&solar_bodies);
116+
117+
const stdout = std.io.getStdOut().writer();
118+
try stdout.print("{d:.9}\n{d:.9}\n", .{ initial_energy, final_energy });
119+
}
120+
121+
fn getSteps() !usize {
122+
var arg_it = std.process.args();
123+
_ = arg_it.skip();
124+
const arg = arg_it.next() orelse return 1000;
125+
return try std.fmt.parseInt(usize, arg, 10);
126+
}

bench/bench_zig.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ problems:
1313
source:
1414
- 1.zig
1515
- 2.zig
16+
- 3.zig
1617
- name: spectral-norm
1718
source:
1819
- 1.zig

0 commit comments

Comments
 (0)