91 const std::vector<MSM>& msms,
const uint32_t total_number_of_muls,
const size_t num_msm_rows)
123 const size_t num_rows_in_read_counts_table =
124 static_cast<size_t>(total_number_of_muls) *
125 (eccvm::POINT_TABLE_SIZE >> 1);
129 point_table_read_counts[0].reserve(num_rows_in_read_counts_table);
130 point_table_read_counts[1].reserve(num_rows_in_read_counts_table);
131 for (
size_t i = 0; i < num_rows_in_read_counts_table; ++i) {
132 point_table_read_counts[0].emplace_back(0);
133 point_table_read_counts[1].emplace_back(0);
136 const auto update_read_count = [&point_table_read_counts](
const size_t point_idx,
const int slice) {
145 const size_t row_index_offset = point_idx * 8;
146 const bool digit_is_negative =
slice < 0;
147 const auto relative_row_idx =
static_cast<size_t>((
slice + 15) / 2);
148 const size_t column_index = digit_is_negative ? 1 : 0;
150 if (digit_is_negative) {
151 point_table_read_counts[column_index][row_index_offset + relative_row_idx]++;
153 point_table_read_counts[column_index][row_index_offset + 15 - relative_row_idx]++;
158 std::vector<size_t> msm_row_counts;
159 msm_row_counts.reserve(msms.size() + 1);
160 msm_row_counts.push_back(1);
163 std::vector<size_t> pc_values;
164 pc_values.reserve(msms.size() + 1);
165 pc_values.push_back(total_number_of_muls);
166 for (
const auto& msm : msms) {
168 msm_row_counts.push_back(msm_row_counts.back() + num_rows_required);
169 pc_values.push_back(pc_values.back() - msm.size());
182 for (
size_t msm_idx = 0; msm_idx < msms.size(); ++msm_idx) {
184 auto pc =
static_cast<uint32_t
>(pc_values[msm_idx]);
185 const auto& msm = msms[msm_idx];
186 const size_t msm_size = msm.size();
187 const size_t num_rows_per_digit =
190 for (
size_t relative_row_idx = 0; relative_row_idx < num_rows_per_digit; ++relative_row_idx) {
191 const size_t num_points_in_row = (relative_row_idx + 1) *
ADDITIONS_PER_ROW > msm_size
195 for (
size_t relative_point_idx = 0; relative_point_idx <
ADDITIONS_PER_ROW; ++relative_point_idx) {
196 const size_t point_idx =
offset + relative_point_idx;
197 const bool add = num_points_in_row > relative_point_idx;
199 int slice = msm[point_idx].wnaf_digits[digit_idx];
201 update_read_count((total_number_of_muls - pc) + point_idx,
slice);
208 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
214 ++relative_point_idx) {
215 bool add = num_points_in_row > relative_point_idx;
216 const size_t point_idx =
offset + relative_point_idx;
223 int slice = msm[point_idx].wnaf_skew ? -1 : -15;
224 update_read_count((total_number_of_muls - pc) + point_idx,
slice);
242 const size_t num_point_adds_and_doubles =
243 (num_msm_rows - 2) * 4;
250 const size_t num_accumulators = num_msm_rows - 1;
254 static constexpr size_t NUM_POINTS_IN_ADDITION_RELATION = 3;
255 const size_t num_points_to_normalize =
256 (num_point_adds_and_doubles * NUM_POINTS_IN_ADDITION_RELATION) + num_accumulators;
257 std::vector<Element> points_to_normalize(num_points_to_normalize);
259 std::span<Element> p2_trace(&points_to_normalize[num_point_adds_and_doubles], num_point_adds_and_doubles);
260 std::span<Element> p3_trace(&points_to_normalize[num_point_adds_and_doubles * 2], num_point_adds_and_doubles);
264 std::vector<bool> is_double_or_add(num_point_adds_and_doubles);
266 std::span<Element> accumulator_trace(&points_to_normalize[num_point_adds_and_doubles * 3], num_accumulators);
270 accumulator_trace[0] = offset_generator;
275 for (
size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
276 Element accumulator = offset_generator;
277 const auto& msm = msms[msm_idx];
278 size_t msm_row_index = msm_row_counts[msm_idx];
279 const size_t msm_size = msm.size();
280 const size_t num_rows_per_digit =
287 (msm_row_counts[msm_idx] - 1) * 4;
294 const auto pc =
static_cast<uint32_t
>(pc_values[msm_idx]);
296 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
300 auto& row = msm_rows[msm_row_index];
302 row.msm_transition = (digit_idx == 0) && (row_idx == 0);
305 auto& add_state = row.add_state[point_idx];
306 add_state.add = num_points_in_row > point_idx;
307 int slice = add_state.add ? msm[
offset + point_idx].wnaf_digits[digit_idx] : 0;
317 add_state.slice = add_state.add ? (
slice + 15) / 2 : 0;
320 ? msm[
offset + point_idx].precomputed_table[
static_cast<size_t>(add_state.slice)]
325 accumulator = add_state.add ? (accumulator + add_state.point) :
Element(p1);
326 p1_trace[trace_index] = p1;
327 p2_trace[trace_index] = p2;
328 p3_trace[trace_index] = accumulator;
329 is_double_or_add[trace_index] =
false;
333 accumulator_trace[msm_row_index] = accumulator;
335 row.q_double =
false;
337 row.msm_round =
static_cast<uint32_t
>(digit_idx);
338 row.msm_size =
static_cast<uint32_t
>(msm_size);
339 row.msm_count =
static_cast<uint32_t
>(
offset);
349 auto& row = msm_rows[msm_row_index];
350 row.msm_transition =
false;
351 row.msm_round =
static_cast<uint32_t
>(digit_idx + 1);
352 row.msm_size =
static_cast<uint32_t
>(msm_size);
353 row.msm_count =
static_cast<uint32_t
>(0);
358 auto& add_state = row.add_state[point_idx];
359 add_state.add =
false;
361 add_state.point = { 0, 0 };
362 add_state.collision_inverse = 0;
364 p1_trace[trace_index] = accumulator;
365 p2_trace[trace_index] = accumulator;
366 accumulator = accumulator.dbl();
367 p3_trace[trace_index] = accumulator;
368 is_double_or_add[trace_index] =
true;
371 accumulator_trace[msm_row_index] = accumulator;
375 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
376 auto& row = msm_rows[msm_row_index];
382 row.msm_transition =
false;
383 Element acc_expected = accumulator;
385 auto& add_state = row.add_state[point_idx];
386 add_state.add = num_points_in_row > point_idx;
387 add_state.slice = add_state.add ? msm[
offset + point_idx].wnaf_skew ? 7 : 0 : 0;
391 ? msm[
offset + point_idx].precomputed_table[
static_cast<size_t>(add_state.slice)]
396 bool add_predicate = add_state.add ? msm[
offset + point_idx].wnaf_skew :
false;
397 auto p1 = accumulator;
398 accumulator = add_predicate ? accumulator + add_state.point : accumulator;
399 p1_trace[trace_index] = p1;
400 p2_trace[trace_index] = add_state.point;
401 p3_trace[trace_index] = accumulator;
402 is_double_or_add[trace_index] =
false;
406 row.q_double =
false;
408 row.msm_round =
static_cast<uint32_t
>(digit_idx + 1);
409 row.msm_size =
static_cast<uint32_t
>(msm_size);
410 row.msm_count =
static_cast<uint32_t
>(
offset);
412 accumulator_trace[msm_row_index] = accumulator;
421 Element::batch_normalize(&points_to_normalize[start], end - start);
425 std::vector<FF> inverse_trace(num_point_adds_and_doubles);
427 for (
size_t operation_idx = start; operation_idx < end; ++operation_idx) {
428 if (is_double_or_add[operation_idx]) {
429 inverse_trace[operation_idx] = (p1_trace[operation_idx].y + p1_trace[operation_idx].y);
431 inverse_trace[operation_idx] = (p2_trace[operation_idx].x - p1_trace[operation_idx].x);
440 for (
size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
441 const auto& msm = msms[msm_idx];
443 size_t msm_row_index = msm_row_counts[msm_idx];
445 size_t accumulator_index = msm_row_counts[msm_idx] - 1;
446 const size_t msm_size = msm.size();
447 const size_t num_rows_per_digit =
451 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
452 auto& row = msm_rows[msm_row_index];
456 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
457 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
458 row.accumulator_x = normalized_accumulator.x;
459 row.accumulator_y = normalized_accumulator.y;
461 auto& add_state = row.add_state[point_idx];
463 const auto& inverse = inverse_trace[trace_index];
464 const auto& p1 = p1_trace[trace_index];
465 const auto& p2 = p2_trace[trace_index];
466 add_state.collision_inverse = add_state.add ? inverse : 0;
467 add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0;
477 MSMRow& row = msm_rows[msm_row_index];
478 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
479 const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x;
480 const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y;
484 auto& add_state = row.
add_state[point_idx];
485 add_state.collision_inverse = 0;
486 const FF& dx = p1_trace[trace_index].x;
487 const FF& inverse = inverse_trace[trace_index];
488 add_state.lambda = ((dx + dx + dx) * dx) * inverse;
498 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
499 MSMRow& row = msm_rows[msm_row_index];
500 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
501 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
506 auto& add_state = row.
add_state[point_idx];
507 bool add_predicate = add_state.add ? msm[
offset + point_idx].wnaf_skew :
false;
509 const auto& inverse = inverse_trace[trace_index];
510 const auto& p1 = p1_trace[trace_index];
511 const auto& p2 = p2_trace[trace_index];
512 add_state.collision_inverse = add_predicate ? inverse : 0;
513 add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0;
526 Element final_accumulator(accumulator_trace.back());
527 MSMRow& final_row = msm_rows.back();
528 final_row.
pc =
static_cast<uint32_t
>(pc_values.back());
530 final_row.
accumulator_x = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.x;
531 final_row.
accumulator_y = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.y;
534 final_row.
q_add =
false;
542 return { msm_rows, point_table_read_counts };