Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
sha256.cpp
Go to the documentation of this file.
2
3#include <algorithm>
4#include <array>
5#include <cstdint>
6#include <memory>
7#include <stdexcept>
8
10
11namespace bb::avm2::simulation {
12
namespace {

// SHA-256 round constants, indexed by round number (0..63) in the compression loop.
// constants come from barretenberg/cpp/src/barretenberg/crypto/sha256/sha256.cpp
constexpr std::array<uint32_t, 64> round_constants{
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
};

} // namespace
28
29// Don't worry about any weird edge cases since we have fixed non-zero shifts
30MemoryValue Sha256::ror(const MemoryValue& x, uint8_t shift)
31{
32 auto val = x.as<uint32_t>();
33 // In a rotation, we decompose into a lhs and rhs (or hi and lo) part.
34 uint32_t lo = val & ((static_cast<uint32_t>(1) << shift) - 1);
35 uint32_t hi = val >> shift;
36 uint32_t result = lo << (32U - (shift & 31U)) | hi;
37
38 // Do this outside of an assert, in case this gets built without assert
39 bool lo_in_range = gt.gt(static_cast<uint32_t>(1) << shift, lo); // Ensure the lower bits are in range
40 (void)lo_in_range; // To please GCC.
41 assert(lo_in_range && "Low Value in ROR out of range");
42 return MemoryValue::from<uint32_t>(result);
43}
44
45// Don't need to worry about edge cases with shifts since we know we only shift by 3 and 10 for sha256
46MemoryValue Sha256::shr(const MemoryValue& x, uint8_t shift)
47{
48 uint32_t input = x.as<uint32_t>();
49 // Get the lower shift bits
50 uint32_t lo = input & ((static_cast<uint32_t>(1) << shift) - 1);
51 uint32_t hi = input >> shift;
52
53 // Do this outside of an assert, in case this gets built without assert
54 bool lo_in_range = gt.gt(static_cast<uint32_t>(1) << shift, lo); // Ensure the lower bits are in range
55 (void)lo_in_range; // To please GCC.
56 assert(lo_in_range && "Low Value in SHR out of range");
57
58 return MemoryValue::from<uint32_t>(hi);
59}
60
61// This function is used to sum the values in the vector and return the result modulo 2^32.
63{
64 uint64_t sum = 0;
65 for (const auto& value : values) {
66 // This is safe, since we've already checked that the values are of tag U32
67 sum += value.as<uint32_t>();
68 }
69 uint32_t lo = static_cast<uint32_t>(sum);
70 uint32_t hi = sum >> 32;
71
72 // Do these outside of an assert, in case this gets built without assert
73 bool lo_in_range =
74 gt.gt(static_cast<uint64_t>(1) << 32, static_cast<uint64_t>(lo)); // Ensure the lower bits are in range
75 bool hi_in_range =
76 gt.gt(static_cast<uint64_t>(1) << 32, static_cast<uint64_t>(hi)); // Ensure the upper bits are in range
77 (void)lo_in_range; // To please GCC.
78 (void)hi_in_range; // To please GCC.
79 assert(lo_in_range && hi_in_range && "Sum in MODULO_SUM out of range");
80 return MemoryValue::from<uint32_t>(lo);
81}
82
84 MemoryAddress state_addr,
85 MemoryAddress input_addr,
86 MemoryAddress output_addr)
87{
88 uint32_t execution_clk = execution_id_manager.get_execution_id();
89 uint16_t space_id = memory.get_space_id();
90
91 // Default values are FF(0) as that is what the circuit would expect
93 state.fill(MemoryValue::from<FF>(0));
94
96 input.reserve(16);
97
98 // Check that the maximum addresss for the state, input, and output addresses are within the valid range.
99 // (1) Read the 8 element hash state from { state_addr, state_addr + 1, ..., state_addr + 7 }
100 // (2) Read the 16 element input from { input_addr, input_addr + 1, ..., input_addr + 15 }
101 // (3) Write the 8 element output to { output_addr, output_addr + 1, ..., output_addr + 7 }
102 bool state_addr_out_of_range = gt.gt(static_cast<uint64_t>(state_addr) + 7, AVM_HIGHEST_MEM_ADDRESS);
103 bool input_addr_out_of_range = gt.gt(static_cast<uint64_t>(input_addr) + 15, AVM_HIGHEST_MEM_ADDRESS);
104 bool output_addr_out_of_range = gt.gt(static_cast<uint64_t>(output_addr) + 7, AVM_HIGHEST_MEM_ADDRESS);
105
106 try {
107 if (state_addr_out_of_range || input_addr_out_of_range || output_addr_out_of_range) {
108 throw std::runtime_error("Memory address out of range for sha256 compression.");
109 }
110
111 // Read the hash state from memory. The state needs to be loaded atomically from memory (i.e. all 8 elements are
112 // read regardless of errors)
113 for (uint32_t i = 0; i < 8; ++i) {
114 state[i] = memory.get(state_addr + i);
115 }
116
117 // If any of the state values are not of tag U32, we throw an error.
118 if (std::ranges::any_of(state, [](const MemoryValue& val) { return val.get_tag() != MemoryTag::U32; })) {
119 throw std::runtime_error("Invalid tag for sha256 state values.");
120 }
121
122 // Load 16 elements representing the hash input from memory.
123 // Since the circuit loads this per row, we throw on the first error we find.
124 for (uint32_t i = 0; i < 16; ++i) {
125 input.emplace_back(memory.get(input_addr + i));
126 if (input[i].get_tag() != MemoryTag::U32) {
127 throw std::runtime_error("Invalid tag for sha256 input values.");
128 }
129 }
130
131 // Perform sha256 compression. Taken from `vm2/simulation/lib/sha256_compression.cpp` but using
132 // the bitwise operations and MemoryValues
134
135 // Fill first 16 words with the inputs
136 for (size_t i = 0; i < 16; ++i) {
137 w[i] = input[i];
138 }
139
140 // Extend the input data into the remaining 48 words
141 for (size_t i = 16; i < 64; ++i) {
142 MemoryValue s0 = bitwise.xor_op(bitwise.xor_op(ror(w[i - 15], 7), ror(w[i - 15], 18)), shr(w[i - 15], 3));
143 MemoryValue s1 = bitwise.xor_op(bitwise.xor_op(ror(w[i - 2], 17), ror(w[i - 2], 19)), shr(w[i - 2], 10));
144 // Could be explicit with an std::initializer_list<uint32_t> here, the array overload is more readable imo.
145 // std::spans are annoying to construct from literals
146 // (https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2447r2.html)
147 w[i] = modulo_sum({ { w[i - 16], w[i - 7], s0, s1 } });
148 }
149
150 // Initialize round variables with previous block output
151 MemoryValue a = state[0];
152 MemoryValue b = state[1];
153 MemoryValue c = state[2];
154 MemoryValue d = state[3];
155 MemoryValue e = state[4];
156 MemoryValue f = state[5];
157 MemoryValue g = state[6];
158 MemoryValue h = state[7];
159
160 // Apply SHA-256 compression function to the message schedule
161 for (size_t i = 0; i < 64; ++i) {
162 MemoryValue S1 = bitwise.xor_op(bitwise.xor_op(ror(e, 6U), ror(e, 11U)), ror(e, 25U));
163 MemoryValue ch = bitwise.xor_op(bitwise.and_op(e, f), bitwise.and_op(~e, g));
164 MemoryValue S0 = bitwise.xor_op(bitwise.xor_op(ror(a, 2U), ror(a, 13U)), ror(a, 22U));
165 MemoryValue maj =
166 bitwise.xor_op(bitwise.xor_op(bitwise.and_op(a, b), bitwise.and_op(a, c)), bitwise.and_op(b, c));
167
168 auto prev_h = h; // Need to store the previous h value before updating it so we can use it in the modulo sum
169 h = g;
170 g = f;
171 f = e;
172 // e = d + temp1;
173 e = modulo_sum({ { d, prev_h, S1, ch, MemoryValue::from<uint32_t>(round_constants[i]), w[i] } });
174 d = c;
175 c = b;
176 b = a;
177 // a = temp1 + temp2;
178 a = modulo_sum({ { prev_h, S1, ch, MemoryValue::from<uint32_t>(round_constants[i]), w[i], S0, maj } });
179 }
180
181 // Add into previous block output and return
183 modulo_sum({ { a, state[0] } }), modulo_sum({ { b, state[1] } }), modulo_sum({ { c, state[2] } }),
184 modulo_sum({ { d, state[3] } }), modulo_sum({ { e, state[4] } }), modulo_sum({ { f, state[5] } }),
185 modulo_sum({ { g, state[6] } }), modulo_sum({ { h, state[7] } }),
186 };
187
188 // Write the output back to memory.
189 for (uint32_t i = 0; i < 8; ++i) {
190 memory.set(output_addr + i, output[i]);
191 }
192
193 events.emit({ .execution_clk = execution_clk,
194 .space_id = space_id,
195 .state_addr = state_addr,
196 .input_addr = input_addr,
197 .output_addr = output_addr,
198 .state = state,
199 .input = input,
200 .output = output });
201 } catch (const std::exception& e) {
202 // If any error occurs, we emit an event with the error message.
204 output.fill(MemoryValue::from<FF>(0)); // Default output in case of error
205 events.emit({ .execution_clk = execution_clk,
206 .space_id = space_id,
207 .state_addr = state_addr,
208 .input_addr = input_addr,
209 .output_addr = output_addr,
210 .state = state,
211 .input = input,
212 .output = output });
213 throw; // Re-throw the exception after emitting the event
214 }
215}
216
217} // namespace bb::avm2::simulation
#define AVM_HIGHEST_MEM_ADDRESS
ValueTag get_tag() const
virtual uint32_t get_execution_id() const =0
MemoryValue modulo_sum(std::span< const MemoryValue > values)
Definition sha256.cpp:62
EventEmitterInterface< Sha256CompressionEvent > & events
Definition sha256.hpp:43
void compression(MemoryInterface &memory, MemoryAddress state_addr, MemoryAddress input_addr, MemoryAddress output_addr) override
Definition sha256.cpp:83
MemoryValue shr(const MemoryValue &x, uint8_t shift)
Definition sha256.cpp:46
ExecutionIdGetterInterface & execution_id_manager
Definition sha256.hpp:40
MemoryValue ror(const MemoryValue &x, uint8_t shift)
Definition sha256.cpp:30
FF a
FF b
constexpr uint32_t round_constants[64]
uint32_t MemoryAddress
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13