Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
msm_builder.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: not started, auditors: [], date: YYYY-MM-DD }
3// external_1: { status: not started, auditors: [], date: YYYY-MM-DD }
4// external_2: { status: not started, auditors: [], date: YYYY-MM-DD }
5// =====================
6
7#pragma once
8
9#include <cstddef>
10
15
16namespace bb {
17
19 public:
22 using Element = typename CycleGroup::element;
23 using AffineElement = typename CycleGroup::affine_element;
25
26 static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW;
27 static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR;
28
// One row of the MSM section of the ECCVM execution trace. Aligned to a 64-byte
// boundary (`alignas(64)`) so each row occupies its own cache line.
29 struct alignas(64) MSMRow {
30 uint32_t pc = 0; // decreasing point-counter, over all half-length (128 bit) scalar muls used to compute
31 // the required MSMs. however, this value is _constant_ on a given MSM and more precisely
32 // refers to the number of half-length scalar muls completed up until we have started
33 // the current MSM.
34 uint32_t msm_size = 0; // the number of points in (a.k.a. the length of) the MSM in whose computation
35 // this VM row participates
36 uint32_t msm_count = 0; // number of multiplications processed so far (not including this row) in current MSM
37 // round (a.k.a. wNAF digit slot). this specifically refers to the number of wNAF-digit
38 // * point scalar products we have looked up and accumulated.
39 uint32_t msm_round = 0; // current "round" of MSM, in {0, ..., 32 = `NUM_WNAF_DIGITS_PER_SCALAR`}. With the
40 // Straus algorithm, we proceed wNAF digit by wNAF digit, from left to right. (final
41 // round deals with the `skew` bit.)
42 bool msm_transition = false; // is 1 if the current row *starts* the processing of a different MSM, else 0.
43 bool q_add = false; // selector: this row performs (up to 4) point additions for the current wNAF digit
44 bool q_double = false; // selector: this row performs the 4 accumulator doublings between digit rounds
45 bool q_skew = false; // selector: this row performs the final skew-correction additions
46
47 // Each row in the MSM portion of the ECCVM can handle (up to) 4 point-additions.
48 // For each row in the VM we represent the point addition data via a size-4 array of
49 // AddState objects.
50 struct AddState {
51 bool add = false; // are we adding a point at this location in the VM?
52 // e.g. if the MSM is of size-2 then the 3rd and 4th AddState objects will have this set
53 // to `false`.
54 int slice = 0; // wNAF slice value. This has values in {0, ..., 15} and corresponds to an odd number in the
55 // range {-15, -13, ..., 15} via the monotonic bijection.
56 AffineElement point{ 0, 0 }; // point being added into the accumulator. (This is of the form nP,
57 // where n is in {-15, -13, ..., 15}.)
58 FF lambda = 0; // when adding `point` into the accumulator via Affine point addition, the value of `lambda`
59 // (i.e., the slope of the line). (we need this as a witness in the circuit.)
60 FF collision_inverse = 0; // `collision_inverse` is used to validate we are not hitting point addition edge
61 // case exceptions, i.e., we want the VM proof to fail if we're doing a point
62 // addition where (x1 == x2). to do this, we simply provide an inverse to x1 - x2.
63 };
64 std::array<AddState, 4> add_state{ AddState{ false, 0, { 0, 0 }, 0, 0 },
65 AddState{ false, 0, { 0, 0 }, 0, 0 },
66 AddState{ false, 0, { 0, 0 }, 0, 0 },
67 AddState{ false, 0, { 0, 0 }, 0, 0 } };
68 // The accumulator here is, in general, the result of four EC additions: A + Q_1 + Q_2 + Q_3 + Q_4.
69 // We do not explicitly store the intermediate values A + Q_1, A + Q_1 + Q_2, and A + Q_1 + Q_2 + Q_3, although
70 // these values are implicitly used in the values of `AddState.lambda` and `AddState.collision_inverse`.
71
72 FF accumulator_x = 0; // `(accumulator_x, accumulator_y)` is the accumulator to which I potentially want to add
73 // the points in `add_state`.
74 FF accumulator_y = 0; // `(accumulator_x, accumulator_y)` is the accumulator to which I potentially want to add
75 // the points in `add_state`.
76 };
77
91 const std::vector<MSM>& msms, const uint32_t total_number_of_muls, const size_t num_msm_rows)
92 {
93 // To perform a scalar multiplication of a point P by a scalar x, we precompute a table of points
94 // -15P, -13P, ..., -3P, -P, P, 3P, ..., 15P
95 // When we perform a scalar multiplication, we decompose x into base-16 wNAF digits then look these precomputed
96 // values up with digit-by-digit. As we are performing lookups with the log-derivative argument, we have to
97 // record read counts. We record read counts in a table with the following structure:
98 // 1st write column = positive wNAF digits
99 // 2nd write column = negative wNAF digits
100 // the row number is a function of pc and wnaf digit:
101 // point_idx = total_number_of_muls - pc
102 // row = point_idx * rows_per_point_table + (some function of the slice value)
103 //
104 // Illustration:
105 // Block Structure:
106 // | 0 | 1 |
107 // | - | - |
108 // 1 | # | # | -1
109 // 3 | # | # | -3
110 // 5 | # | # | -5
111 // 7 | # | # | -7
112 // 9 | # | # | -9
113 // 11 | # | # | -11
114 // 13 | # | # | -13
115 // 15 | # | # | -15
116 //
117 // Table structure:
118 // | Block_{0} | <-- pc = total_number_of_muls
119 // | Block_{1} | <-- pc = total_number_of_muls-(num muls in msm 0)
120 // | ... | ...
121 // | Block_{total_number_of_muls-1} | <-- pc = num muls in last msm
122
123 const size_t num_rows_in_read_counts_table =
124 static_cast<size_t>(total_number_of_muls) *
125 (eccvm::POINT_TABLE_SIZE >> 1); // `POINT_TABLE_SIZE` is 2ʷ, where in our case w = 4. As noted above, with
126 // respect to *read counts*, we are recording lookups of the positive and
127 // negative odd multiples of [P] in two separate columns, each of size 2ʷ⁻¹.
128 std::array<std::vector<size_t>, 2> point_table_read_counts;
129 point_table_read_counts[0].reserve(num_rows_in_read_counts_table);
130 point_table_read_counts[1].reserve(num_rows_in_read_counts_table);
131 for (size_t i = 0; i < num_rows_in_read_counts_table; ++i) {
132 point_table_read_counts[0].emplace_back(0);
133 point_table_read_counts[1].emplace_back(0);
134 }
135
136 const auto update_read_count = [&point_table_read_counts](const size_t point_idx, const int slice) {
145 const size_t row_index_offset = point_idx * 8;
146 const bool digit_is_negative = slice < 0;
147 const auto relative_row_idx = static_cast<size_t>((slice + 15) / 2);
148 const size_t column_index = digit_is_negative ? 1 : 0;
149
150 if (digit_is_negative) {
151 point_table_read_counts[column_index][row_index_offset + relative_row_idx]++;
152 } else {
153 point_table_read_counts[column_index][row_index_offset + 15 - relative_row_idx]++;
154 }
155 };
156
157 // compute which row index each multiscalar multiplication will start at.
158 std::vector<size_t> msm_row_counts;
159 msm_row_counts.reserve(msms.size() + 1);
160 msm_row_counts.push_back(1);
161 // compute the point counter (i.e. the index among all single scalar muls) that each multiscalar
162 // multiplication will start at.
163 std::vector<size_t> pc_values;
164 pc_values.reserve(msms.size() + 1);
165 pc_values.push_back(total_number_of_muls);
166 for (const auto& msm : msms) {
167 const size_t num_rows_required = EccvmRowTracker::num_eccvm_msm_rows(msm.size());
168 msm_row_counts.push_back(msm_row_counts.back() + num_rows_required);
169 pc_values.push_back(pc_values.back() - msm.size());
170 }
171 BB_ASSERT_EQ(pc_values.back(), 0U);
172
173 // compute the MSM rows
174
175 std::vector<MSMRow> msm_rows(num_msm_rows);
176 // start with empty row (shiftable polynomials must have 0 as first coefficient)
177 msm_rows[0] = (MSMRow{});
178 // compute "read counts" so that we can determine the number of times entries in our log-derivative lookup
179 // tables are called.
180 // Note: this part is single-threaded. The amount of compute is low, however, so this is likely not a big
181 // concern.
182 for (size_t msm_idx = 0; msm_idx < msms.size(); ++msm_idx) {
183 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
184 auto pc = static_cast<uint32_t>(pc_values[msm_idx]);
185 const auto& msm = msms[msm_idx];
186 const size_t msm_size = msm.size();
187 const size_t num_rows_per_digit =
188 (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0);
189
190 for (size_t relative_row_idx = 0; relative_row_idx < num_rows_per_digit; ++relative_row_idx) {
191 const size_t num_points_in_row = (relative_row_idx + 1) * ADDITIONS_PER_ROW > msm_size
192 ? (msm_size % ADDITIONS_PER_ROW)
194 const size_t offset = relative_row_idx * ADDITIONS_PER_ROW;
195 for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW; ++relative_point_idx) {
196 const size_t point_idx = offset + relative_point_idx;
197 const bool add = num_points_in_row > relative_point_idx;
198 if (add) {
199 int slice = msm[point_idx].wnaf_digits[digit_idx];
200 // pc starts at total_number_of_muls and decreases non-uniformly to 0
201 update_read_count((total_number_of_muls - pc) + point_idx, slice);
202 }
203 }
204 }
205
206 // update the log-derivative read count for the lookup associated with WNAF skew
207 if (digit_idx == NUM_WNAF_DIGITS_PER_SCALAR - 1) {
208 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
209 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
210 ? (msm_size % ADDITIONS_PER_ROW)
212 const size_t offset = row_idx * ADDITIONS_PER_ROW;
213 for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW;
214 ++relative_point_idx) {
215 bool add = num_points_in_row > relative_point_idx;
216 const size_t point_idx = offset + relative_point_idx;
217 if (add) {
218 // `pc` starts at total_number_of_muls and decreases non-uniformly to 0.
219 // -15 maps to the 1st point in the lookup table (array element 0)
220 // -1 maps to the point in the lookup table that corresponds to the negation of the
221 // original input point (i.e. the point we need to add into the accumulator if wnaf_skew
222 // is positive)
223 int slice = msm[point_idx].wnaf_skew ? -1 : -15;
224 update_read_count((total_number_of_muls - pc) + point_idx, slice);
225 }
226 }
227 }
228 }
229 }
230 }
231
232 // The execution trace data for the MSM columns requires knowledge of intermediate values from *affine* point
233 // addition. The naive solution to compute this data requires 2 field inversions per in-circuit group addition
234 // evaluation. This is bad! To avoid this, we split the witness computation algorithm into 3 steps.
235 // Step 1: compute the execution trace group operations in *projective* coordinates. (these will be stored in
236 // `p1_trace`, `p2_trace`, and `p3_trace`)
237 // Step 2: use batch inversion trick to convert all points into affine coordinates
238 // Step 3: populate the full execution trace, including the intermediate values from affine group
239 // operations
240 // This section sets up the data structures we need to store all intermediate ECC operations in projective form
241
242 const size_t num_point_adds_and_doubles =
243 (num_msm_rows - 2) * 4; // `num_msm_rows - 2` is the actual number of rows in the table required to compute
244 // the MSM; the msm table itself has a dummy row at the beginning and an extra row
245 // with the `x` and `y` coordinates of the accumulator at the end. (In general, the
246 // output of the accumulator from the computation at row `i` is present on row
247 // `i+1`. We multiply by 4 because each "row" of the VM processes 4 point-additions
248 // (and the fact that w = 4 means we must interleave with 4 doublings). This
249 // "corresponds" to the fact that `MSMRow::add_state` has 4 entries.
250 const size_t num_accumulators = num_msm_rows - 1; // for every row after the first row, we have an accumulator.
251 // In what follows, either p1 + p2 = p3, or p1.dbl() = p3
252 // We create 1 vector to store the entire point trace. We split into multiple containers using std::span
253 // (we want 1 vector object to more efficiently batch-normalize points)
254 static constexpr size_t NUM_POINTS_IN_ADDITION_RELATION = 3;
255 const size_t num_points_to_normalize =
256 (num_point_adds_and_doubles * NUM_POINTS_IN_ADDITION_RELATION) + num_accumulators;
257 std::vector<Element> points_to_normalize(num_points_to_normalize);
258 std::span<Element> p1_trace(&points_to_normalize[0], num_point_adds_and_doubles);
259 std::span<Element> p2_trace(&points_to_normalize[num_point_adds_and_doubles], num_point_adds_and_doubles);
260 std::span<Element> p3_trace(&points_to_normalize[num_point_adds_and_doubles * 2], num_point_adds_and_doubles);
261 // `is_double_or_add` records whether an entry in the p1/p2/p3 trace represents a point addition or
262 // doubling. if it is `true`, then we are doubling (i.e., the condition is that `p3 = p1.dbl()`), else we are
263 // adding (i.e., the condition is that `p3 = p1 + p2`).
264 std::vector<bool> is_double_or_add(num_point_adds_and_doubles);
265 // accumulator_trace tracks the value of the ECCVM accumulator for each row
266 std::span<Element> accumulator_trace(&points_to_normalize[num_point_adds_and_doubles * 3], num_accumulators);
267
268 // we start the accumulator at the offset generator point
269 constexpr auto offset_generator = get_precomputed_generators<g1, "ECCVM_OFFSET_GENERATOR", 1>()[0];
270 accumulator_trace[0] = offset_generator;
271
272 // TODO(https://github.com/AztecProtocol/barretenberg/issues/973): Reinstate multitreading?
273 // populate point trace, and the components of the MSM execution trace that do not relate to affine point
274 // operations
275 for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
276 Element accumulator = offset_generator; // for every MSM, we start with the same `offset_generator`
277 const auto& msm = msms[msm_idx]; // which MSM we are processing. This is of type `std::vector<ScalarMul>`.
278 size_t msm_row_index = msm_row_counts[msm_idx]; // the row where the given MSM starts
279 const size_t msm_size = msm.size();
280 const size_t num_rows_per_digit =
281 (msm_size / ADDITIONS_PER_ROW) +
282 (msm_size % ADDITIONS_PER_ROW !=
283 0); // the Straus algorithm proceeds by incrementing through the digit-slots and doing
284 // computations *across* the `ScalarMul`s that make up our MSM. Each digit-slot therefore
285 // contributes the *ceiling* of `msm_size`/`ADDITIONS_PER_ROW`.
286 size_t trace_index =
287 (msm_row_counts[msm_idx] - 1) * 4; // tracks the index in the traces of `p1`, `p2`, `p3`, and
288 // `accumulator_trace` that we are filling out
289
290 // for each digit-slot (`digit_idx`), and then for each row of the VM (which does `ADDITIONS_PER_ROW` point
291 // additions), we either enter in/process (`ADDITIONS_PER_ROW`) `AddState` objects, and then if necessary
292 // (i.e., if not at the last wNAF digit), process the four doublings.
293 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
294 const auto pc = static_cast<uint32_t>(pc_values[msm_idx]); // pc that our msm starts at
295
296 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
297 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
298 ? (msm_size % ADDITIONS_PER_ROW)
300 auto& row = msm_rows[msm_row_index]; // actual `MSMRow` we will fill out in the body of this loop
301 const size_t offset = row_idx * ADDITIONS_PER_ROW;
302 row.msm_transition = (digit_idx == 0) && (row_idx == 0);
303 // each iteration of this loop process/enters in one of the `AddState` objects in `row.add_state`.
304 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
305 auto& add_state = row.add_state[point_idx];
306 add_state.add = num_points_in_row > point_idx;
307 int slice = add_state.add ? msm[offset + point_idx].wnaf_digits[digit_idx] : 0;
308 // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row.
309 // if `row.add_state[point_idx].add = 1`, this indicates that we want to add the
310 // `point_idx`'th point in the MSM columns into the MSM accumulator `add_state.slice` = A
311 // 4-bit WNAF slice of the scalar multiplier associated with the point we are adding (the
312 // specific slice chosen depends on the value of msm_round) (WNAF = our version of
313 // windowed-non-adjacent-form. Value range is `-15, -13,..., 15`)
314 // If `add_state.add = 1`, we want `add_state.slice` to be the *compressed*
315 // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15,
316 // -13, ..., 15 maps to 0, ... , 15)
317 add_state.slice = add_state.add ? (slice + 15) / 2 : 0;
318 add_state.point =
319 add_state.add
320 ? msm[offset + point_idx].precomputed_table[static_cast<size_t>(add_state.slice)]
321 : AffineElement{ 0, 0 };
322
323 Element p1(accumulator);
324 Element p2(add_state.point);
325 accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1);
326 p1_trace[trace_index] = p1;
327 p2_trace[trace_index] = p2;
328 p3_trace[trace_index] = accumulator;
329 is_double_or_add[trace_index] = false;
330 trace_index++;
331 }
332 // Now, `row.add_state` has been fully processed and we fill in the rest of the members of `row`.
333 accumulator_trace[msm_row_index] = accumulator;
334 row.q_add = true;
335 row.q_double = false;
336 row.q_skew = false;
337 row.msm_round = static_cast<uint32_t>(digit_idx);
338 row.msm_size = static_cast<uint32_t>(msm_size);
339 row.msm_count = static_cast<uint32_t>(offset);
340 row.pc = pc;
341 msm_row_index++;
342 }
343 // after processing each digit-slot, we now take care of doubling (as long as we are not at the last
344 // digit). We add an `MSMRow`, `row`, whose four `AddState` objects in `row.add_state`
345 // are null, but we also populate `p1_trace`, `p2_trace`, `p3_trace`, and `is_double_or_add` for four
346 // indices, corresponding to the w=4 doubling operations we need to perform. This embodies the numerical
347 // "coincidence" that `ADDITIONS_PER_ROW == NUM_WNAF_DIGIT_BITS`
348 if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) {
349 auto& row = msm_rows[msm_row_index];
350 row.msm_transition = false;
351 row.msm_round = static_cast<uint32_t>(digit_idx + 1);
352 row.msm_size = static_cast<uint32_t>(msm_size);
353 row.msm_count = static_cast<uint32_t>(0);
354 row.q_add = false;
355 row.q_double = true;
356 row.q_skew = false;
357 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
358 auto& add_state = row.add_state[point_idx];
359 add_state.add = false;
360 add_state.slice = 0;
361 add_state.point = { 0, 0 };
362 add_state.collision_inverse = 0;
363
364 p1_trace[trace_index] = accumulator;
365 p2_trace[trace_index] = accumulator; // dummy
366 accumulator = accumulator.dbl();
367 p3_trace[trace_index] = accumulator;
368 is_double_or_add[trace_index] = true;
369 trace_index++;
370 }
371 accumulator_trace[msm_row_index] = accumulator;
372 msm_row_index++;
373 } else // process `wnaf_skew`, i.e., the skew digit.
374 {
375 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
376 auto& row = msm_rows[msm_row_index];
377
378 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
379 ? msm_size % ADDITIONS_PER_ROW
381 const size_t offset = row_idx * ADDITIONS_PER_ROW;
382 row.msm_transition = false;
383 Element acc_expected = accumulator;
384 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
385 auto& add_state = row.add_state[point_idx];
386 add_state.add = num_points_in_row > point_idx;
387 add_state.slice = add_state.add ? msm[offset + point_idx].wnaf_skew ? 7 : 0 : 0;
388
389 add_state.point =
390 add_state.add
391 ? msm[offset + point_idx].precomputed_table[static_cast<size_t>(add_state.slice)]
393 0, 0
394 }; // if the skew_bit is on, `slice == 7`. Then `precomputed_table[7] == -[P]`, as
395 // required for the skew logic.
396 bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false;
397 auto p1 = accumulator;
398 accumulator = add_predicate ? accumulator + add_state.point : accumulator;
399 p1_trace[trace_index] = p1;
400 p2_trace[trace_index] = add_state.point;
401 p3_trace[trace_index] = accumulator;
402 is_double_or_add[trace_index] = false;
403 trace_index++;
404 }
405 row.q_add = false;
406 row.q_double = false;
407 row.q_skew = true;
408 row.msm_round = static_cast<uint32_t>(digit_idx + 1);
409 row.msm_size = static_cast<uint32_t>(msm_size);
410 row.msm_count = static_cast<uint32_t>(offset);
411 row.pc = pc;
412 accumulator_trace[msm_row_index] = accumulator;
413 msm_row_index++;
414 }
415 }
416 }
417 }
418
419 // Normalize the points in the point trace
420 parallel_for_range(points_to_normalize.size(), [&](size_t start, size_t end) {
421 Element::batch_normalize(&points_to_normalize[start], end - start);
422 });
423
424 // inverse_trace is used to compute the value of the `collision_inverse` column in the ECCVM.
425 std::vector<FF> inverse_trace(num_point_adds_and_doubles);
426 parallel_for_range(num_point_adds_and_doubles, [&](size_t start, size_t end) {
427 for (size_t operation_idx = start; operation_idx < end; ++operation_idx) {
428 if (is_double_or_add[operation_idx]) {
429 inverse_trace[operation_idx] = (p1_trace[operation_idx].y + p1_trace[operation_idx].y);
430 } else {
431 inverse_trace[operation_idx] = (p2_trace[operation_idx].x - p1_trace[operation_idx].x);
432 }
433 }
434 FF::batch_invert(&inverse_trace[start], end - start);
435 });
436
437 // complete the computation of the ECCVM execution trace, by adding the affine intermediate point data
438 // i.e. row.accumulator_x, row.accumulator_y, row.add_state[0...3].collision_inverse,
439 // row.add_state[0...3].lambda
440 for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
441 const auto& msm = msms[msm_idx];
442 size_t trace_index = ((msm_row_counts[msm_idx] - 1) * ADDITIONS_PER_ROW);
443 size_t msm_row_index = msm_row_counts[msm_idx];
444 // 1st MSM row will have accumulator equal to the previous MSM output (or point at infinity for first MSM)
445 size_t accumulator_index = msm_row_counts[msm_idx] - 1;
446 const size_t msm_size = msm.size();
447 const size_t num_rows_per_digit =
448 (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0);
449
450 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
451 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
452 auto& row = msm_rows[msm_row_index];
453 // note that we do not store the "intermediate accumulators" that are implicit *within* a row (i.e.,
454 // within a given `add_state` object). This is the reason why accumulator_index only increments once
455 // per `row_idx`.
456 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
457 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
458 row.accumulator_x = normalized_accumulator.x;
459 row.accumulator_y = normalized_accumulator.y;
460 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
461 auto& add_state = row.add_state[point_idx];
462
463 const auto& inverse = inverse_trace[trace_index];
464 const auto& p1 = p1_trace[trace_index];
465 const auto& p2 = p2_trace[trace_index];
466 add_state.collision_inverse = add_state.add ? inverse : 0;
467 add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0;
468 trace_index++;
469 }
470 accumulator_index++;
471 msm_row_index++;
472 }
473
474 // if digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1 we have to fill out our doubling row (which in fact
475 // amounts to 4 doublings)
476 if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) {
477 MSMRow& row = msm_rows[msm_row_index];
478 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
479 const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x;
480 const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y;
481 row.accumulator_x = acc_x;
482 row.accumulator_y = acc_y;
483 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
484 auto& add_state = row.add_state[point_idx];
485 add_state.collision_inverse = 0; // no notion of "different x values" for a point doubling
486 const FF& dx = p1_trace[trace_index].x;
487 const FF& inverse = inverse_trace[trace_index]; // here, 2y
488 add_state.lambda = ((dx + dx + dx) * dx) * inverse;
489 trace_index++;
490 }
491 accumulator_index++;
492 msm_row_index++;
493 } else // this row corresponds to performing point additions to handle WNAF skew
494 // i.e. iterate over all the points in the MSM - if for a given point, `wnaf_skew == 1`,
495 // subtract the original point from the accumulator. if `digit_idx == NUM_WNAF_DIGITS_PER_SCALAR
496 // - 1` we have finished executing our double-and-add algorithm.
497 {
498 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
499 MSMRow& row = msm_rows[msm_row_index];
500 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
501 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
502 const size_t offset = row_idx * ADDITIONS_PER_ROW;
503 row.accumulator_x = normalized_accumulator.x;
504 row.accumulator_y = normalized_accumulator.y;
505 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
506 auto& add_state = row.add_state[point_idx];
507 bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false;
508
509 const auto& inverse = inverse_trace[trace_index];
510 const auto& p1 = p1_trace[trace_index];
511 const auto& p2 = p2_trace[trace_index];
512 add_state.collision_inverse = add_predicate ? inverse : 0;
513 add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0;
514 trace_index++;
515 }
516 accumulator_index++;
517 msm_row_index++;
518 }
519 }
520 }
521 }
522
523 // populate the final row in the MSM execution trace.
524 // we always require 1 extra row at the end of the trace, because the x and y coordinates of the accumulator for
525 // row `i` are present at row `i+1`
526 Element final_accumulator(accumulator_trace.back());
527 MSMRow& final_row = msm_rows.back();
528 final_row.pc = static_cast<uint32_t>(pc_values.back());
529 final_row.msm_transition = true;
530 final_row.accumulator_x = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.x;
531 final_row.accumulator_y = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.y;
532 final_row.msm_size = 0;
533 final_row.msm_count = 0;
534 final_row.q_add = false;
535 final_row.q_double = false;
536 final_row.q_skew = false;
537 final_row.add_state = { typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
538 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
539 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
540 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } };
541
542 return { msm_rows, point_table_read_counts };
543 }
544};
545} // namespace bb
#define BB_ASSERT_EQ(actual, expected,...)
Definition assert.hpp:88
static constexpr size_t ADDITIONS_PER_ROW
static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR
static std::tuple< std::vector< MSMRow >, std::array< std::vector< size_t >, 2 > > compute_rows(const std::vector< MSM > &msms, const uint32_t total_number_of_muls, const size_t num_msm_rows)
Computes the row values for the Straus MSM columns of the ECCVM.
curve::BN254::Group CycleGroup
typename CycleGroup::affine_element AffineElement
bb::eccvm::MSM< CycleGroup > MSM
typename CycleGroup::element Element
static uint32_t num_eccvm_msm_rows(const size_t msm_size)
Get the number of rows in the 'msm' column section of the ECCVM associated with a single multiscalar ...
typename bb::g1 Group
Definition bn254.hpp:20
ssize_t offset
Definition engine.cpp:36
std::vector< ScalarMul< CycleGroup > > MSM
Entry point for Barretenberg command-line interface.
group< fq, fr, Bn254G1Params > g1
Definition g1.hpp:33
C slice(C const &container, size_t start)
Definition container.hpp:9
constexpr std::span< const typename Group::affine_element > get_precomputed_generators()
void parallel_for_range(size_t num_points, const std::function< void(size_t, size_t)> &func, size_t no_multhreading_if_less_or_equal)
Split a loop into several loops running in parallel.
Definition thread.cpp:141
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::array< AddState, 4 > add_state
static void batch_invert(C &coeffs) noexcept