Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
batched_affine_addition.cpp
Go to the documentation of this file.
5#include <algorithm>
6#include <execution>
7#include <set>
8
9namespace bb {
10
11template <typename Curve>
13 const std::span<G1>& points, const std::vector<size_t>& sequence_counts)
14{
15 BB_BENCH_NAME("BatchedAffineAddition::add_in_place");
16 // Instantiate scratch space for point addition denominators and their calculation
17 std::vector<Fq> scratch_space_vector(points.size());
18 std::span<Fq> scratch_space(scratch_space_vector);
19
20 // Divide the work into groups of addition sequences to be reduced by each thread
21 auto [addition_sequences_, sequence_tags] = construct_thread_data(points, sequence_counts, scratch_space);
22 auto& addition_sequences = addition_sequences_;
23
24 const size_t num_threads = addition_sequences.size();
25 parallel_for(num_threads, [&](size_t thread_idx) { batched_affine_add_in_place(addition_sequences[thread_idx]); });
26
27 // Construct a vector of the reduced points, accounting for sequences that may have been split across threads
28 std::vector<G1> reduced_points;
29 size_t prev_tag = std::numeric_limits<size_t>::max();
30 for (auto [sequences, tags] : zip_view(addition_sequences, sequence_tags)) {
31 // Extract the first num-sequence-counts many points from each add sequence
32 for (size_t i = 0; i < sequences.sequence_counts.size(); ++i) {
33 if (tags[i] == prev_tag) {
34 reduced_points.back() = reduced_points.back() + sequences.points[i];
35 } else {
36 reduced_points.emplace_back(sequences.points[i]);
37 }
38 prev_tag = tags[i];
39 }
40 }
41
42 return reduced_points;
43}
44
45template <typename Curve>
47 const std::span<G1>& points, const std::vector<size_t>& sequence_counts, const std::span<Fq>& scratch_space)
48{
49 // Compute the endpoints of the sequences within the points array from the sequence counts
50 std::vector<size_t> sequence_endpoints;
51 size_t total_count = 0;
52 for (const auto& count : sequence_counts) {
53 total_count += count;
54 sequence_endpoints.emplace_back(total_count);
55 }
56
57 if (points.size() != total_count) {
58 throw_or_abort("Number of input points does not match sequence counts!");
59 }
60
61 // Determine the optimal number of threads for parallelization
62 const size_t MIN_POINTS_PER_THREAD = 1 << 14; // heuristic; anecdotally optimal for practical cases
63 const size_t total_num_points = points.size();
64 const size_t optimal_threads = total_num_points / MIN_POINTS_PER_THREAD;
65 const size_t num_threads = std::max(1UL, std::min(get_num_cpus(), optimal_threads));
66 // Distribute the work as evenly as possible across threads
67 const size_t base_thread_size = total_num_points / num_threads;
68 const size_t leftover_size = total_num_points % num_threads;
69 std::vector<size_t> thread_sizes(num_threads, base_thread_size);
70 for (size_t i = 0; i < leftover_size; ++i) {
71 thread_sizes[i]++;
72 }
73
74 // Construct the point spans for each thread according to the distribution determined above
75 std::vector<std::span<G1>> thread_points;
76 std::vector<std::span<Fq>> thread_scratch_space;
77 std::vector<size_t> thread_endpoints;
78 size_t point_index = 0;
79 for (auto size : thread_sizes) {
80 thread_points.push_back(points.subspan(point_index, size));
81 thread_scratch_space.push_back(scratch_space.subspan(point_index, size));
82 point_index += size;
83 thread_endpoints.emplace_back(point_index);
84 }
85
86 // Construct the union of the thread and sequence endpoints by combining, sorting, then removing duplicates. This is
87 // used to break the points into sequences for each thread while tracking tags so that sequences split across one of
88 // more threads can be properly reconstructed.
89 std::vector<size_t> all_endpoints;
90 all_endpoints.reserve(thread_endpoints.size() + sequence_endpoints.size());
91 all_endpoints.insert(all_endpoints.end(), thread_endpoints.begin(), thread_endpoints.end());
92 all_endpoints.insert(all_endpoints.end(), sequence_endpoints.begin(), sequence_endpoints.end());
93 std::sort(all_endpoints.begin(), all_endpoints.end());
94 auto last = std::unique(all_endpoints.begin(), all_endpoints.end());
95 all_endpoints.erase(last, all_endpoints.end());
96
97 // Construct sequence counts and tags for each thread using the set of all thread and sequence endpoints
98 size_t prev_endpoint = 0;
99 size_t thread_idx = 0;
100 size_t sequence_idx = 0;
101 std::vector<std::vector<size_t>> thread_sequence_counts(num_threads);
102 std::vector<std::vector<size_t>> thread_sequence_tags(num_threads);
103 for (auto& endpoint : all_endpoints) {
104 size_t chunk_size = endpoint - prev_endpoint;
105 thread_sequence_counts[thread_idx].emplace_back(chunk_size);
106 thread_sequence_tags[thread_idx].emplace_back(sequence_idx);
107 if (endpoint == thread_endpoints[thread_idx]) {
108 thread_idx++;
109 }
110 if (endpoint == sequence_endpoints[sequence_idx]) {
111 sequence_idx++;
112 }
113 prev_endpoint = endpoint;
114 }
115
116 if (thread_sequence_counts.size() != thread_points.size()) {
117 throw_or_abort("Mismatch in sequence count construction!");
118 }
119
120 // Construct the addition sequences for each thread
121 std::vector<AdditionSequences> addition_sequences;
122 for (size_t i = 0; i < num_threads; ++i) {
123 addition_sequences.push_back(
124 AdditionSequences{ thread_sequence_counts[i], thread_points[i], thread_scratch_space[i] });
125 }
126
127 return { addition_sequences, thread_sequence_tags };
128}
129
130template <typename Curve>
132 Curve>::batch_compute_point_addition_slope_inverses(const AdditionSequences& add_sequences)
133{
134 auto points = add_sequences.points;
135 auto sequence_counts = add_sequences.sequence_counts;
136
137 // Count the total number of point pairs to be added across all addition sequences
138 size_t total_num_pairs{ 0 };
139 for (auto& count : sequence_counts) {
140 total_num_pairs += count >> 1;
141 }
142
143 // Define scratch space for batched inverse computations and eventual storage of denominators
144 BB_ASSERT_GTE(add_sequences.scratch_space.size(), 2 * total_num_pairs);
145 std::span<Fq> denominators = add_sequences.scratch_space.subspan(0, total_num_pairs);
146 std::span<Fq> differences = add_sequences.scratch_space.subspan(total_num_pairs, 2 * total_num_pairs);
147
148 // Compute and store successive products of differences (x_2 - x_1)
149 Fq accumulator = 1;
150 size_t point_idx = 0;
151 size_t pair_idx = 0;
152 for (auto& count : sequence_counts) {
153 const auto num_pairs = count >> 1;
154 for (size_t j = 0; j < num_pairs; ++j) {
155 BB_ASSERT_LT(pair_idx, total_num_pairs);
156 const auto& x1 = points[point_idx++].x;
157 const auto& x2 = points[point_idx++].x;
158
159 // It is assumed that the input points are random and thus w/h/p do not share an x-coordinate
160 BB_ASSERT(x1 != x2);
161
162 auto diff = x2 - x1;
163 differences[pair_idx] = diff;
164
165 // Store and update the running product of differences at each stage
166 denominators[pair_idx++] = accumulator;
167 accumulator *= diff;
168 }
169 // If number of points in the sequence is odd, we skip the last one since it has no pair
170 point_idx += (count & 0x01ULL);
171 }
172
173 // Invert the full product of differences
174 Fq inverse = accumulator.invert();
175
176 // Compute the individual point-pair addition denominators 1/(x2 - x1)
177 for (size_t i = 0; i < total_num_pairs; ++i) {
178 size_t idx = total_num_pairs - 1 - i;
179 denominators[idx] *= inverse;
180 inverse *= differences[idx];
181 }
182
183 return denominators;
184}
185
186template <typename Curve>
188{
189 const size_t num_points = add_sequences.points.size();
190 if (num_points == 0 || num_points == 1) { // nothing to do
191 return;
192 }
193
194 // Batch compute terms of the form 1/(x2 -x1) for each pair to be added in this round
195 std::span<Fq> denominators = batch_compute_point_addition_slope_inverses(add_sequences);
196
197 auto points = add_sequences.points;
198 auto sequence_counts = add_sequences.sequence_counts;
199
200 // Compute pairwise in-place additions for all sequences with more than 1 point
201 size_t point_idx = 0; // index for points to be summed
202 size_t result_point_idx = 0; // index for result points
203 size_t pair_idx = 0; // index into array of denominators for each pair
204 bool more_additions = false;
205 for (auto& count : sequence_counts) {
206 const auto num_pairs = count >> 1;
207 const bool overflow = static_cast<bool>(count & 0x01ULL);
208 // Compute the sum of all pairs in the sequence and store the result in the same points array
209 for (size_t j = 0; j < num_pairs; ++j) {
210 const auto& point_1 = points[point_idx++]; // first summand
211 const auto& point_2 = points[point_idx++]; // second summand
212 const auto& denominator = denominators[pair_idx++]; // denominator needed in add formula
213 auto& result = points[result_point_idx++]; // target for addition result
214
215 result = affine_add_with_denominator(point_1, point_2, denominator);
216 }
217 // If the sequence had an odd number of points, simply carry the unpaired point over to the next round
218 if (overflow) {
219 points[result_point_idx++] = points[point_idx++];
220 }
221
222 // Update the sequence counts in place for the next round
223 const uint32_t updated_sequence_count = static_cast<uint32_t>(num_pairs) + static_cast<uint32_t>(overflow);
224 count = updated_sequence_count;
225
226 // More additions are required if any sequence has not yet been reduced to a single point
227 more_additions = more_additions || updated_sequence_count > 1;
228 }
229
230 // Recursively perform pairwise additions until all sequences have been reduced to a single point
231 if (more_additions) {
232 const size_t updated_point_count = result_point_idx;
233 std::span<G1> updated_points(&points[0], updated_point_count);
234 return batched_affine_add_in_place(
235 AdditionSequences{ sequence_counts, updated_points, add_sequences.scratch_space });
236 }
237}
238
241} // namespace bb
#define BB_ASSERT(expression,...)
Definition assert.hpp:80
#define BB_ASSERT_GTE(left, right,...)
Definition assert.hpp:138
#define BB_ASSERT_LT(left, right,...)
Definition assert.hpp:153
#define BB_BENCH_NAME(name)
Definition bb_bench.hpp:219
Class for handling fast batched affine addition of large sets of EC points.
static std::vector< G1 > add_in_place(const std::span< G1 > &points, const std::vector< size_t > &sequence_counts)
Given a set of points and sequence counts, perform addition to reduce each sequence to a single point.
static void batched_affine_add_in_place(AdditionSequences add_sequences)
Internal method for in-place summation of a single set of addition sequences.
static ThreadData construct_thread_data(const std::span< G1 > &points, const std::vector< size_t > &sequence_counts, const std::span< Fq > &scratch_space)
Construct the set of AdditionSequences to be handled by each thread.
Entry point for Barretenberg command-line interface.
Definition api.hpp:5
size_t get_num_cpus()
Definition thread.cpp:33
void parallel_for(size_t num_iterations, const std::function< void(size_t)> &func)
Definition thread.cpp:111
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
void throw_or_abort(std::string const &err)