Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
sha256.cpp
Go to the documentation of this file.
2
3#include <algorithm>
4#include <array>
5#include <cstdint>
6#include <memory>
7#include <stdexcept>
8
10
11namespace bb::avm2::simulation {
12
13namespace {
14
15// SHA-256 round constants K[0..63] (FIPS 180-4, section 4.2.2): the first 32 bits of the fractional parts of the
// cube roots of the first 64 primes. Entry i is added into round i of the compression loop in this file.
// Values copied from barretenberg/cpp/src/barretenberg/crypto/sha256/sha256.cpp.
16constexpr std::array<uint32_t, 64> round_constants{
17 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
18 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
19 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
20 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
21 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
22 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
23 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
24 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
25};
26
27} // namespace
28
29// Don't worry about any weird edge cases since we have fixed non-zero shifts
30MemoryValue Sha256::ror(const MemoryValue& x, uint8_t shift)
31{
32 auto val = x.as<uint32_t>();
33 // In a rotation, we decompose into a lhs and rhs (or hi and lo) part.
34 uint32_t lo = val & ((static_cast<uint32_t>(1) << shift) - 1);
35 uint32_t hi = val >> shift;
36 uint32_t result = lo << (32U - (shift & 31U)) | hi;
37
38 // Do this outside of an assert, in case this gets built without assert
39 bool lo_in_range = gt.gt(static_cast<uint32_t>(1) << shift, lo); // Ensure the lower bits are in range
40 BB_ASSERT(lo_in_range, "Low Value in ROR out of range");
41 return MemoryValue::from<uint32_t>(result);
42}
43
44// Don't need to worry about edge cases with shifts since we know we only shift by 3 and 10 for sha256
45MemoryValue Sha256::shr(const MemoryValue& x, uint8_t shift)
46{
47 uint32_t input = x.as<uint32_t>();
48 // Get the lower shift bits
49 uint32_t lo = input & ((static_cast<uint32_t>(1) << shift) - 1);
50 uint32_t hi = input >> shift;
51
52 // Do this outside of an assert, in case this gets built without assert
53 bool lo_in_range = gt.gt(static_cast<uint32_t>(1) << shift, lo); // Ensure the lower bits are in range
54 BB_ASSERT(lo_in_range, "Low Value in SHR out of range");
55
56 return MemoryValue::from<uint32_t>(hi);
57}
58
59// Sums every element of `values` and returns the total modulo 2^32 as a U32 MemoryValue.
// Assumes every element has tag U32. The addition is done in a 64-bit accumulator: the low 32 bits are the
// result and the high 32 bits (the carry) are discarded. Every 32-bit addition in the compression below
// (message schedule, round updates, final state addition) goes through this helper.
61{
62 uint64_t sum = 0;
63 for (const auto& value : values) {
64 // This is safe, since we've already checked that the values are of tag U32
65 sum += value.as<uint32_t>();
66 }
// Split the 64-bit total into the result word (lo) and the carry (hi).
// The implicit narrowing of `hi` to 32 bits is lossless whenever fewer than 2^32 values are summed.
// Callers in this file pass at most 7 values.
67 uint32_t lo = static_cast<uint32_t>(sum);
68 uint32_t hi = sum >> 32;
69
70 // Do these outside of an assert, in case this gets built without assert
// Both comparisons always hold in native arithmetic, because lo and hi are uint32_t.
// They are still routed through the `gt` gadget, presumably so the range checks are recorded for the
// circuit (TODO confirm).
71 bool lo_in_range =
72 gt.gt(static_cast<uint64_t>(1) << 32, static_cast<uint64_t>(lo)); // Ensure the lower bits are in range
73 bool hi_in_range =
74 gt.gt(static_cast<uint64_t>(1) << 32, static_cast<uint64_t>(hi)); // Ensure the upper bits are in range
75 BB_ASSERT(lo_in_range && hi_in_range, "Sum in MODULO_SUM out of range");
76 return MemoryValue::from<uint32_t>(lo);
77}
78
80 MemoryAddress state_addr,
81 MemoryAddress input_addr,
82 MemoryAddress output_addr)
83{
// Simulates one SHA-256 compression of a 16-word (512-bit) message block:
// - reads the 8-word hash state from memory[state_addr .. state_addr + 7];
// - reads the 16-word input block from memory[input_addr .. input_addr + 15];
// - runs the 64 SHA-256 rounds using the bitwise / gt gadgets and the ror / shr / modulo_sum helpers;
// - writes the 8-word result to memory[output_addr .. output_addr + 7].
// A Sha256CompressionEvent is emitted both on success and when a Sha256CompressionException is raised
// (address out of range, or a state/input word whose tag is not U32).
// In the failure case:
// - the event's output is zeroed;
// - its state/input hold whatever was read before the failure;
// - the exception is rethrown.
84 uint32_t execution_clk = execution_id_manager.get_execution_id();
85 uint16_t space_id = memory.get_space_id();
86
87 // Default values are FF(0) as that is what the circuit would expect
89 state.fill(MemoryValue::from<FF>(0));
90
92 input.reserve(16);
93
94 // Check that the highest address of each of the state, input, and output regions is within the valid range.
95 // (1) Read the 8 element hash state from { state_addr, state_addr + 1, ..., state_addr + 7 }
96 // (2) Read the 16 element input from { input_addr, input_addr + 1, ..., input_addr + 15 }
97 // (3) Write the 8 element output to { output_addr, output_addr + 1, ..., output_addr + 7 }
// The addresses are widened to 64 bits so that `addr + 15` cannot wrap around for addresses near the top of
// the 32-bit MemoryAddress range.
98 bool state_addr_out_of_range = gt.gt(static_cast<uint64_t>(state_addr) + 7, AVM_HIGHEST_MEM_ADDRESS);
99 bool input_addr_out_of_range = gt.gt(static_cast<uint64_t>(input_addr) + 15, AVM_HIGHEST_MEM_ADDRESS);
100 bool output_addr_out_of_range = gt.gt(static_cast<uint64_t>(output_addr) + 7, AVM_HIGHEST_MEM_ADDRESS);
101
102 try {
103 if (state_addr_out_of_range || input_addr_out_of_range || output_addr_out_of_range) {
104 throw Sha256CompressionException("Memory address out of range for sha256 compression.");
105 }
106
107 // Read the hash state from memory. The state needs to be loaded atomically from memory (i.e. all 8 elements are
108 // read regardless of errors)
109 for (uint32_t i = 0; i < 8; ++i) {
110 state[i] = memory.get(state_addr + i);
111 }
112
113 // If any of the state values are not of tag U32, we throw an error.
114 if (std::ranges::any_of(state, [](const MemoryValue& val) { return val.get_tag() != MemoryTag::U32; })) {
115 throw Sha256CompressionException("Invalid tag for sha256 state values.");
116 }
117
118 // Load 16 elements representing the hash input from memory.
119 // Since the circuit loads this per row, we throw on the first error we find.
// On that early throw, `input` keeps only the words read so far (including the offending one).
120 for (uint32_t i = 0; i < 16; ++i) {
121 input.emplace_back(memory.get(input_addr + i));
122 if (input[i].get_tag() != MemoryTag::U32) {
123 throw Sha256CompressionException("Invalid tag for sha256 input values.");
124 }
125 }
126
127 // Perform sha256 compression. Taken from `vm2/simulation/lib/sha256_compression.cpp` but using
128 // the bitwise operations and MemoryValues
130
131 // Fill first 16 words with the inputs
132 for (size_t i = 0; i < 16; ++i) {
133 w[i] = input[i];
134 }
135
136 // Extend the input data into the remaining 48 words (message schedule):
// w[i] = w[i-16] + s0(w[i-15]) + w[i-7] + s1(w[i-2])  (mod 2^32)
137 for (size_t i = 16; i < 64; ++i) {
138 MemoryValue s0 = bitwise.xor_op(bitwise.xor_op(ror(w[i - 15], 7), ror(w[i - 15], 18)), shr(w[i - 15], 3));
139 MemoryValue s1 = bitwise.xor_op(bitwise.xor_op(ror(w[i - 2], 17), ror(w[i - 2], 19)), shr(w[i - 2], 10));
140 // Could be explicit with an std::initializer_list<uint32_t> here, the array overload is more readable imo.
141 // std::spans are annoying to construct from literals
142 // (https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2447r2.html)
143 w[i] = modulo_sum({ { w[i - 16], w[i - 7], s0, s1 } });
144 }
145
146 // Initialize round variables with previous block output
147 MemoryValue a = state[0];
148 MemoryValue b = state[1];
149 MemoryValue c = state[2];
150 MemoryValue d = state[3];
151 MemoryValue e = state[4];
152 MemoryValue f = state[5];
153 MemoryValue g = state[6];
154 MemoryValue h = state[7];
155
156 // Apply SHA-256 compression function to the message schedule
157 for (size_t i = 0; i < 64; ++i) {
158 MemoryValue S1 = bitwise.xor_op(bitwise.xor_op(ror(e, 6U), ror(e, 11U)), ror(e, 25U));
159 MemoryValue ch = bitwise.xor_op(bitwise.and_op(e, f), bitwise.and_op(~e, g));
160 MemoryValue S0 = bitwise.xor_op(bitwise.xor_op(ror(a, 2U), ror(a, 13U)), ror(a, 22U));
161 MemoryValue maj =
162 bitwise.xor_op(bitwise.xor_op(bitwise.and_op(a, b), bitwise.and_op(a, c)), bitwise.and_op(b, c));
163
164 auto prev_h = h; // Need to store the previous h value before updating it so we can use it in the modulo sum
165 h = g;
166 g = f;
167 f = e;
168 // e = d + temp1, where temp1 = h + S1 + ch + K[i] + w[i] (all mod 2^32)
169 e = modulo_sum({ { d, prev_h, S1, ch, MemoryValue::from<uint32_t>(round_constants[i]), w[i] } });
170 d = c;
171 c = b;
172 b = a;
173 // a = temp1 + temp2, where temp2 = S0 + maj (all mod 2^32)
174 a = modulo_sum({ { prev_h, S1, ch, MemoryValue::from<uint32_t>(round_constants[i]), w[i], S0, maj } });
175 }
176
177 // Add the working variables into the previous state (mod 2^32) to form the 8 output words
179 modulo_sum({ { a, state[0] } }), modulo_sum({ { b, state[1] } }), modulo_sum({ { c, state[2] } }),
180 modulo_sum({ { d, state[3] } }), modulo_sum({ { e, state[4] } }), modulo_sum({ { f, state[5] } }),
181 modulo_sum({ { g, state[6] } }), modulo_sum({ { h, state[7] } }),
182 };
183
184 // Write the output back to memory.
185 for (uint32_t i = 0; i < 8; ++i) {
186 memory.set(output_addr + i, output[i]);
187 }
188
189 events.emit({ .execution_clk = execution_clk,
190 .space_id = space_id,
191 .state_addr = state_addr,
192 .input_addr = input_addr,
193 .output_addr = output_addr,
194 .state = state,
195 .input = input,
196 .output = output });
197 } catch (const Sha256CompressionException& e) {
198 // On failure we still emit an event. It carries a zeroed output plus whatever state/input words were read
// before the failure. The event fields do not include the exception message.
200 output.fill(MemoryValue::from<FF>(0)); // Default output in case of error
201 events.emit({ .execution_clk = execution_clk,
202 .space_id = space_id,
203 .state_addr = state_addr,
204 .input_addr = input_addr,
205 .output_addr = output_addr,
206 .state = state,
207 .input = input,
208 .output = output });
209
210 // Rethrow the exception after emitting the event
211 throw;
212 }
213}
214
215} // namespace bb::avm2::simulation
#define BB_ASSERT(expression,...)
Definition assert.hpp:80
#define AVM_HIGHEST_MEM_ADDRESS
ValueTag get_tag() const
virtual uint32_t get_execution_id() const =0
MemoryValue modulo_sum(std::span< const MemoryValue > values)
Definition sha256.cpp:60
EventEmitterInterface< Sha256CompressionEvent > & events
Definition sha256.hpp:43
void compression(MemoryInterface &memory, MemoryAddress state_addr, MemoryAddress input_addr, MemoryAddress output_addr) override
Definition sha256.cpp:79
MemoryValue shr(const MemoryValue &x, uint8_t shift)
Definition sha256.cpp:45
ExecutionIdGetterInterface & execution_id_manager
Definition sha256.hpp:40
MemoryValue ror(const MemoryValue &x, uint8_t shift)
Definition sha256.cpp:30
FF a
FF b
constexpr uint32_t round_constants[64]
uint32_t MemoryAddress
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13