Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
ecc_msm_relation_impl.hpp
Go to the documentation of this file.
// === AUDIT STATUS ===
// internal: { status: not started, auditors: [], date: YYYY-MM-DD }
// external_1: { status: not started, auditors: [], date: YYYY-MM-DD }
// external_2: { status: not started, auditors: [], date: YYYY-MM-DD }
// =====================

7#pragma once
10#include "ecc_msm_relation.hpp"
11
12namespace bb {
13
46template <typename FF>
47template <typename ContainerOverSubrelations, typename AllEntities, typename Parameters>
48void ECCVMMSMRelationImpl<FF>::accumulate(ContainerOverSubrelations& accumulator,
49 const AllEntities& in,
50 const Parameters& /*unused*/,
51 const FF& scaling_factor)
52{
53 using Accumulator = typename std::tuple_element_t<0, ContainerOverSubrelations>;
54 using View = typename Accumulator::View;
55
56 const auto& x1 = View(in.msm_x1);
57 const auto& y1 = View(in.msm_y1);
58 const auto& x2 = View(in.msm_x2);
59 const auto& y2 = View(in.msm_y2);
60 const auto& x3 = View(in.msm_x3);
61 const auto& y3 = View(in.msm_y3);
62 const auto& x4 = View(in.msm_x4);
63 const auto& y4 = View(in.msm_y4);
64 const auto& collision_inverse1 = View(in.msm_collision_x1);
65 const auto& collision_inverse2 = View(in.msm_collision_x2);
66 const auto& collision_inverse3 = View(in.msm_collision_x3);
67 const auto& collision_inverse4 = View(in.msm_collision_x4);
68 const auto& lambda1 = View(in.msm_lambda1);
69 const auto& lambda2 = View(in.msm_lambda2);
70 const auto& lambda3 = View(in.msm_lambda3);
71 const auto& lambda4 = View(in.msm_lambda4);
72 const auto& lagrange_first = View(in.lagrange_first);
73 const auto& add1 = View(in.msm_add1);
74 const auto& add1_shift = View(in.msm_add1_shift);
75 const auto& add2 = View(in.msm_add2);
76 const auto& add3 = View(in.msm_add3);
77 const auto& add4 = View(in.msm_add4);
78 const auto& acc_x = View(in.msm_accumulator_x);
79 const auto& acc_y = View(in.msm_accumulator_y);
80 const auto& acc_x_shift = View(in.msm_accumulator_x_shift);
81 const auto& acc_y_shift = View(in.msm_accumulator_y_shift);
82 const auto& slice1 = View(in.msm_slice1);
83 const auto& slice2 = View(in.msm_slice2);
84 const auto& slice3 = View(in.msm_slice3);
85 const auto& slice4 = View(in.msm_slice4);
86 const auto& msm_transition = View(in.msm_transition);
87 const auto& msm_transition_shift = View(in.msm_transition_shift);
88 const auto& round = View(in.msm_round);
89 const auto& round_shift = View(in.msm_round_shift);
90 const auto& q_add = View(in.msm_add); // is 1 iff we are at an ADD row in Straus algorithm
91 const auto& q_add_shift = View(in.msm_add_shift);
92 const auto& q_skew = View(in.msm_skew);
93 const auto& q_skew_shift = View(in.msm_skew_shift);
94 const auto& q_double = View(in.msm_double); // is 1 iff we are at an DOUBLE row in Straus algorithm
95 const auto& q_double_shift = View(in.msm_double_shift);
96 const auto& msm_size = View(in.msm_size_of_msm);
97 const auto& pc = View(in.msm_pc); // pc stands for `point-counter`.
98 const auto& pc_shift = View(in.msm_pc_shift);
99 const auto& count = View(in.msm_count);
100 const auto& count_shift = View(in.msm_count_shift);
101 auto is_not_first_row = (-lagrange_first + 1);
102
267 auto add = [&](auto& xb,
268 auto& yb,
269 auto& xa,
270 auto& ya,
271 auto& lambda,
272 auto& selector,
273 auto& relation,
274 auto& collision_relation) {
275 // computation of lambda is valid: if q == 1, then L == (yb - ya) / (xb - xa)
276 // if q == 0, then L == 0. combining these into a single constraint yields:
277 // q * (L * (xb - xa - 1) - (yb - ya)) + L = 0
278 relation += selector * (lambda * (xb - xa - 1) - (yb - ya)) + lambda;
279 collision_relation += selector * (xb - xa);
280 // x_out = L.L + (-xb - xa) * q + (1 - q) xa
281 // deg L = 1, deg q = 1, min(deg(xa), deg(xb))≥ 1.
282 // hence deg(x_out) = 1 + max(deg(xa, xb))
283 auto x_out = lambda.sqr() + (-xb - xa - xa) * selector + xa;
284
285 // y_out = L . (xa - x_out) - ya * q + (1 - q) ya
286 // hence deg(y_out) = max(1 + deg(x_out), 1 + deg(ya))
287 auto y_out = lambda * (xa - x_out) + (-ya - ya) * selector + ya;
288 return std::array<Accumulator, 2>{ x_out, y_out };
289 };
290
305 auto first_add = [&](auto& xb,
306 auto& yb,
307 auto& xa,
308 auto& ya,
309 auto& lambda,
310 auto& selector,
311 auto& relation,
312 auto& collision_relation) {
313 // N.B. this is brittle - should be curve agnostic but we don't propagate the curve parameter into relations!
314 constexpr auto offset_generator = get_precomputed_generators<g1, "ECCVM_OFFSET_GENERATOR", 1>()[0];
315 constexpr uint256_t oxu = offset_generator.x;
316 constexpr uint256_t oyu = offset_generator.y;
317 const Accumulator xo(oxu);
318 const Accumulator yo(oyu);
319 // set (x, y) to be either accumulator if `selector == 0` or OFFSET if `selector == 1`.
320 auto x = xo * selector + xb * (-selector + 1);
321 auto y = yo * selector + yb * (-selector + 1);
322 relation += lambda * (x - xa) - (y - ya); // degree 3
323 collision_relation += (xa - x);
324 auto x_out = lambda * lambda + (-x - xa);
325 auto y_out = lambda * (xa - x_out) - ya;
326 return std::array<Accumulator, 2>{ x_out, y_out };
327 };
328
329 // ADD operations (if row represents ADD round, not SKEW or DOUBLE)
330 Accumulator add_relation(0); // validates the correctness of all elliptic curve additions.
331 Accumulator x1_collision_relation(0);
332 Accumulator x2_collision_relation(0);
333 Accumulator x3_collision_relation(0);
334 Accumulator x4_collision_relation(0);
335 // If `msm_transition == 1`, we have started a new MSM. We need to treat the current value of [Acc] as the point at
336 // infinity!
337 auto [x_t1, y_t1] =
338 first_add(acc_x, acc_y, x1, y1, lambda1, msm_transition, add_relation, x1_collision_relation); // [deg 2, deg 3]
339 auto [x_t2, y_t2] = add(x2, y2, x_t1, y_t1, lambda2, add2, add_relation, x2_collision_relation); // [deg 3, deg 4]
340 auto [x_t3, y_t3] = add(x3, y3, x_t2, y_t2, lambda3, add3, add_relation, x3_collision_relation); // [deg 4, deg 5]
341 auto [x_t4, y_t4] = add(x4, y4, x_t3, y_t3, lambda4, add4, add_relation, x4_collision_relation); // [deg 5, deg 6]
342
343 // Validate accumulator output matches ADD output if q_add = 1
344 std::get<0>(accumulator) += q_add * (acc_x_shift - x_t4) * scaling_factor;
345 std::get<1>(accumulator) += q_add * (acc_y_shift - y_t4) * scaling_factor;
346 std::get<2>(accumulator) += q_add * add_relation * scaling_factor;
347
355 auto dbl = [&](auto& x, auto& y, auto& lambda, auto& relation) {
356 auto two_x = x + x;
357 relation += lambda * (y + y) - (two_x + x) * x;
358 auto x_out = lambda.sqr() - two_x;
359 auto y_out = lambda * (x - x_out) - y;
360 return std::array<Accumulator, 2>{ x_out, y_out };
361 };
362
378 Accumulator double_relation(0);
379 auto [x_d1, y_d1] = dbl(acc_x, acc_y, lambda1, double_relation);
380 auto [x_d2, y_d2] = dbl(x_d1, y_d1, lambda2, double_relation);
381 auto [x_d3, y_d3] = dbl(x_d2, y_d2, lambda3, double_relation);
382 auto [x_d4, y_d4] = dbl(x_d3, y_d3, lambda4, double_relation);
383 std::get<10>(accumulator) += q_double * (acc_x_shift - x_d4) * scaling_factor;
384 std::get<11>(accumulator) += q_double * (acc_y_shift - y_d4) * scaling_factor;
385 std::get<12>(accumulator) += q_double * double_relation * scaling_factor;
386
400 Accumulator skew_relation(0);
401 static FF inverse_seven = FF(7).invert();
402 auto skew1_select = slice1 * inverse_seven;
403 auto skew2_select = slice2 * inverse_seven;
404 auto skew3_select = slice3 * inverse_seven;
405 auto skew4_select = slice4 * inverse_seven;
406 Accumulator x1_skew_collision_relation(0);
407 Accumulator x2_skew_collision_relation(0);
408 Accumulator x3_skew_collision_relation(0);
409 Accumulator x4_skew_collision_relation(0);
410 // add skew points iff row is a SKEW row AND slice = 7 (point_table[7] maps to -[P])
411 // N.B. while it would be nice to have one `add` relation for both ADD and SKEW rounds,
412 // this would increase degree of sumcheck identity vs evaluating them separately.
413 // This is because, for add rounds, the result of adding [P1], [Acc] is [P1 + Acc] or [P1]
414 // but for skew rounds, the result of adding [P1], [Acc] is [P1 + Acc] or [Acc]
415 auto [x_s1, y_s1] = add(x1, y1, acc_x, acc_y, lambda1, skew1_select, skew_relation, x1_skew_collision_relation);
416 auto [x_s2, y_s2] = add(x2, y2, x_s1, y_s1, lambda2, skew2_select, skew_relation, x2_skew_collision_relation);
417 auto [x_s3, y_s3] = add(x3, y3, x_s2, y_s2, lambda3, skew3_select, skew_relation, x3_skew_collision_relation);
418 auto [x_s4, y_s4] = add(x4, y4, x_s3, y_s3, lambda4, skew4_select, skew_relation, x4_skew_collision_relation);
419
420 // Validate accumulator output matches SKEW output if q_skew = 1
421 std::get<3>(accumulator) += q_skew * (acc_x_shift - x_s4) * scaling_factor;
422 std::get<4>(accumulator) += q_skew * (acc_y_shift - y_s4) * scaling_factor;
423 std::get<5>(accumulator) += q_skew * skew_relation * scaling_factor;
424
425 // Check x-coordinates do not collide if row is an ADD row or a SKEW row
426 // if either q_add or q_skew = 1, an inverse should exist for each computed relation
427 // Step 1: construct boolean selectors that describe whether we added a point at the current row
428 const auto add_first_point = add1 * q_add + q_skew * skew1_select;
429 const auto add_second_point = add2 * q_add + q_skew * skew2_select;
430 const auto add_third_point = add3 * q_add + q_skew * skew3_select;
431 const auto add_fourth_point = add4 * q_add + q_skew * skew4_select;
432 // Step 2: construct the difference a.k.a. delta between x-coordinates for each point add (depending on if row is
433 // ADD or SKEW)
434 const auto x1_delta = x1_skew_collision_relation * q_skew + x1_collision_relation * q_add;
435 const auto x2_delta = x2_skew_collision_relation * q_skew + x2_collision_relation * q_add;
436 const auto x3_delta = x3_skew_collision_relation * q_skew + x3_collision_relation * q_add;
437 const auto x4_delta = x4_skew_collision_relation * q_skew + x4_collision_relation * q_add;
438 // Step 3: x_delta * inverse - 1 = 0 if we performed a point addition (else x_delta * inverse = 0)
439 std::get<6>(accumulator) += (x1_delta * collision_inverse1 - add_first_point) * scaling_factor;
440 std::get<7>(accumulator) += (x2_delta * collision_inverse2 - add_second_point) * scaling_factor;
441 std::get<8>(accumulator) += (x3_delta * collision_inverse3 - add_third_point) * scaling_factor;
442 std::get<9>(accumulator) += (x4_delta * collision_inverse4 - add_fourth_point) * scaling_factor;
443
444 // When add_i = 0, force slice_i to ALSO be 0
445 std::get<13>(accumulator) += (-add1 + 1) * slice1 * scaling_factor;
446 std::get<14>(accumulator) += (-add2 + 1) * slice2 * scaling_factor;
447 std::get<15>(accumulator) += (-add3 + 1) * slice3 * scaling_factor;
448 std::get<16>(accumulator) += (-add4 + 1) * slice4 * scaling_factor;
449
450 // SELECTORS ARE MUTUALLY EXCLUSIVE
451 // at most one of q_skew, q_double, q_add can be nonzero.
452 // note that as we can expect our table to be zero padded, we _do not_ insist that q_add + q_double + q_skew == 1.
453 std::get<17>(accumulator) += (q_add * q_double + q_add * q_skew + q_double * q_skew) * scaling_factor;
454
455 // Validate that if q_add = 1 or q_skew = 1, add1 also is 1
456 // NOTE(#2222): could just get rid of add1 as a column, as it is a linear combination.
457 std::get<32>(accumulator) += (add1 - q_add - q_skew) * scaling_factor;
458
459 // ROUND TRANSITION LOGIC
460 // `round_transition` describes whether we are transitioning between "rounds" of the MSM according to the Straus
461 // algorithm. In particular, the `round` corresponds to the wNAF digit we are currently processing.
462
463 const auto round_delta = round_shift - round;
464 // If `msm_transition == 0` (next row) then `round_delta` is boolean; the round is internal to a given MSM and
465 // represents the wNAF digit currently being processed. `round_delta == 0` means that the current and next steps of
466 // the Straus algorithm are processing the same wNAF digit place.
467
468 // `round_transition == 0` if `round_delta == 0` or the next row is an MSM transition.
469 // if `round_transition != 1`, then `round_transition == round_delta == 1` by the following constraint.
470 // in particular, `round_transition` is boolean. (`round_delta` is not boolean precisely one step before an MSM
471 // transition, but that does not concern us here.)
472 const auto round_transition = round_delta * (-msm_transition_shift + 1);
473 std::get<18>(accumulator) += round_transition * (round_delta - 1) * scaling_factor;
474
475 // If `round_transition == 1`, then `round_delta == 1` and `msm_transition_shift == 0`. Therefore, we wish to
476 // constrain next row in the VM to either be a double (if `round != 31`) or skew (if `round == 31`). In either case,
477 // the point is that we have finished processing a wNAF digit place and need to either perform the doublings to move
478 // on to the next place _or_ we are at the last place and need to perform the skew computation to finish. These are
479 // equationally represented as:
480 // round_transition * skew_shift * (round - 31) = 0 (if round tx and skew, then round == 31);
481 // round_transition * (skew_shift + double_shift - 1) = 0 (if round tx, then skew XOR double = 1).
482 // (-round_delta + 1) * q_double_shift = 1 (if q_double_shift == 1, then round_transition = 1)
483 // together, these have the following implications: if round tx and round != 31, then double_shift = 1.
484 // conversely, if round tx and double_shift == 0, then `q_skew_shift == 1` (which then forces `round == 31`).
485 // similarly, if q_double_shift == 1, then round_transition == 0,
486 // the fact that a round_transition occurs at the first time skew_shift == 1 follows from the fact that skew == 1
487 // implies round == 32 and the above three relations, together with the _definition_ of round_transition.
488 std::get<19>(accumulator) += round_transition * q_skew_shift * (round - 31) * scaling_factor;
489 std::get<20>(accumulator) += round_transition * (q_skew_shift + q_double_shift - 1) * scaling_factor;
490 std::get<35>(accumulator) += (-round_delta + 1) * q_double_shift * scaling_factor;
491 // if the next is neither double nor skew, and we are not at an msm_transition, then round_delta = 0 and the next
492 // "row" of our VM is processing the same wNAF digit place.
493 std::get<21>(accumulator) += round_transition * (-q_double_shift + 1) * (-q_skew_shift + 1) * scaling_factor;
494
495 // CONSTRAINING Q_DOUBLE AND Q_SKEW
496 // NOTE: we have already constrained q_add, q_skew, and q_double to be mutually exclusive.
497
498 // if double, next add = 1. As q_double, q_add, and q_skew are mutually exclusive, this suffices to force
499 // q_double_shift == q_skew_shift == 0.
500 std::get<22>(accumulator) += q_double * (-q_add_shift + 1) * scaling_factor;
501 // if the current row has q_skew == 1 and the next row is _not_ an MSM transition, then q_skew_shift = 1.
502 // this forces q_skew to precisely correspond to the rows where `round == 32`. Indeed, note that the first q_skew
503 // bit is set correctly:
504 // round == 31, round_transition == 1 ==> q_skew_shift == 1. (if, to the contrary, q_double_shift == 1, then
505 // the q_add_shift_shift == 1, but we assume that we have correctly constrained the q_adds via the multiset
506 // argument. this means that q_double_shift == 0, which forces q_skew_shift == 1 because round_transition
507 // == 1.)
508 // this means that the first row with `round == 32` has q_skew == 1. then all subsequent q_skew entries must be 1,
509 // _until_ we start our new MSM.
510 std::get<33>(accumulator) += (-msm_transition_shift + 1) * q_skew * (-q_skew_shift + 1) * scaling_factor;
511 // if q_skew == 1, then round == 32. This is almost certainly redundant but psychologically useful to "constrain
512 // both ends".
513 std::get<34>(accumulator) += q_skew * (-round + 32) * scaling_factor;
514
515 // UPDATING THE COUNT
516
517 // if we are changing the `round` (i.e., starting to process a new wNAF digit or at an msm transition), the
518 // count_shift must be 0.
519 std::get<23>(accumulator) += round_delta * count_shift * scaling_factor;
520 // if msm_transition = 0 and round_transition = 0, then the next "row" of the VM is processing the same wNAF digit.
521 // this means that the count must increase: count_shift = count + add1 + add2 + add3 + add4
522 std::get<24>(accumulator) += (-msm_transition_shift + 1) * (-round_delta + 1) *
523 (count_shift - count - add1 - add2 - add3 - add4) * scaling_factor;
524
525 // at least one of the following must be true:
526 // the next step is an MSM transition;
527 // the next count is zero (meaning we are starting the processing of a new wNAF digit)
528 // the next step is processing the same wNAF digit (i.e., round_delta == 0)
529 // (note that at the start of a new MSM, the count is also zero, so the above are not mutually exclusive.)
530 std::get<25>(accumulator) +=
531 is_not_first_row * (-msm_transition_shift + 1) * round_delta * count_shift * scaling_factor;
532
533 // if msm_transition = 1, then round = 0.
534 std::get<26>(accumulator) += msm_transition * round * scaling_factor;
535
536 // if msm_transition_shift = 1, pc = pc_shift + msm_size
537 // NB: `ecc_set_relation` ensures `msm_size` maps to `transcript.msm_count` for the current value of `pc`
538 std::get<27>(accumulator) += is_not_first_row * msm_transition_shift * (msm_size + pc_shift - pc) * scaling_factor;
539
540 // Addition continuity checks
541 // We want to RULE OUT the following scenarios:
542 // Case 1: add2 = 1, add1 = 0
543 // Case 2: add3 = 1, add2 = 0
544 // Case 3: add4 = 1, add3 = 0
545 // These checks ensure that the current row does not skip points (for both ADD and SKEW ops)
546 // This is part of a wider set of checks we use to ensure that all point data is used in the assigned
547 // multiscalar multiplication operation (and not in a different MSM operation).
548 std::get<28>(accumulator) += add2 * (-add1 + 1) * scaling_factor;
549 std::get<29>(accumulator) += add3 * (-add2 + 1) * scaling_factor;
550 std::get<30>(accumulator) += add4 * (-add3 + 1) * scaling_factor;
551
552 // Final continuity check.
553 // If an addition spans two rows, we need to make sure that the following scenario is RULED OUT:
554 // add4 = 0 on the CURRENT row, add1 = 1 on the NEXT row
555 // We must apply the above for the two cases:
556 // Case 1: q_add = 1 on the CURRENT row, q_add = 1 on the NEXT row
557 // Case 2: q_skew = 1 on the CURRENT row, q_skew = 1 on the NEXT row
558 // (i.e. if q_skew = 1, q_add_shift = 1 this implies an MSM transition so we skip this continuity check)
559 std::get<31>(accumulator) +=
560 (q_add * q_add_shift + q_skew * q_skew_shift) * (-add4 + 1) * add1_shift * scaling_factor;
561
562 // remaining checks (done in ecc_set_relation.hpp, ecc_lookup_relation.hpp)
563 // when transition occurs, perform set membership lookup on (accumulator / pc / msm_size)
564 // perform set membership lookups on add_i * (pc / round / slice_i)
565 // perform lookups on (pc / slice_i / x / y)
566
567 // We look up wnaf slices by mapping round + pc -> slice
568 // We use an exact set membership check to validate that
569 // wnafs written in wnaf_relation == wnafs read in msm relation
570 // We use `add1/add2/add3/add4` to flag whether we are performing a wnaf read op
571 // We can set these to be Prover-defined as the set membership check implicitly ensures that the correct reads
572 // have occurred.
573}
574
575} // namespace bb
static void accumulate(ContainerOverSubrelations &accumulator, const AllEntities &in, const Parameters &, const FF &scaling_factor)
MSM relations that evaluate the Straus multiscalar multiplication algorithm.
Entry point for Barretenberg command-line interface.
group< fq, fr, Bn254G1Params > g1
Definition g1.hpp:33
typename Flavor::FF FF
constexpr std::span< const typename Group::affine_element > get_precomputed_generators()
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
constexpr field invert() const noexcept