Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
safe_uint.cpp
Go to the documentation of this file.
1#include "safe_uint.hpp"
2#include "../bool/bool.hpp"
3#include "../circuit_builders/circuit_builders.hpp"
6
7namespace bb::stdlib {
8
// safe_uint_t addition: `safe_uint_t operator+(const safe_uint_t& other) const`.
// NOTE(review): this listing omits original line 11, which carries the function signature.
// The bound of the sum is the sum of the operands' bounds (current_max + other.current_max). The result is built with
// IS_UNSAFE, so presumably no fresh range constraint is emitted here — confirm IS_UNSAFE's meaning in safe_uint.hpp.
9template <typename Builder>
10
12{
13 return safe_uint_t((value + other.value), current_max + other.current_max, IS_UNSAFE);
14}
15
16template <typename Builder> safe_uint_t<Builder> safe_uint_t<Builder>::operator*(const safe_uint_t& other) const
17{
18
19 uint512_t new_max = uint512_t(current_max) * uint512_t(other.current_max);
20 BB_ASSERT_EQ(new_max.hi, 0U);
21 return safe_uint_t((value * other.value), new_max.lo, IS_UNSAFE);
22}
23
// Subtraction with a caller-chosen bound: range-constrains this - other to difference_bit_size bits.
//   other               - the subtrahend.
//   difference_bit_size - bit bound for the result; must not exceed MAX_BIT_NUM.
//   description         - appended to the range-constraint label "subtract: ".
// At least one operand must be non-constant. Calls throw_or_abort when the operand bounds would allow an underflow
// to slip past the range check (see below).
// NOTE(review): original line 35 (the line naming the function and its first parameter `other`) is missing from this
// listing.
34template <typename Builder>
36 const size_t difference_bit_size,
37 std::string const& description) const
38{
39 BB_ASSERT_LTE(difference_bit_size, MAX_BIT_NUM);
40 BB_ASSERT(!(this->value.is_constant() && other.value.is_constant()));
41
42 field_ct difference_val = this->value - other.value;
43 // Creates the range constraint that difference_val is in [0, (1<<difference_bit_size) - 1].
44 safe_uint_t<Builder> difference(difference_val, difference_bit_size, format("subtract: ", description));
45 // Underflow (this < other) yields the field element (a - b) + modulus, which the range constraint misses whenever
46 // it is <= difference.current_max. Its smallest possible value is modulus - other.current_max (a = 0,
47 // b = other.current_max), so a miss is possible only if difference.current_max + other.current_max >= modulus.
48 // That case is rejected here (reading MAX_VALUE as modulus - 1 — confirm in safe_uint.hpp).
49 if (difference.current_max + other.current_max > MAX_VALUE)
50 throw_or_abort("maximum value exceeded in safe_uint subtract");
51 return difference;
52}
53
68template <typename Builder> safe_uint_t<Builder> safe_uint_t<Builder>::operator-(const safe_uint_t& other) const
69{
70 // If both are constants and the operation is an underflow, throw an error since circuit itself underflows
71 BB_ASSERT(!(this->value.is_constant() && other.value.is_constant() &&
72 static_cast<uint256_t>(value.get_value()) < static_cast<uint256_t>(other.value.get_value())));
73
74 field_ct difference_val = this->value - other.value;
75
76 // safe_uint_t constructor creates a range constraint which checks that `difference_val` is within [0,
77 // current_max].
78 safe_uint_t<Builder> difference(difference_val, (size_t)(current_max.get_msb() + 1), "- operator");
79
80 // Call the two operands a and b. If this operations is underflow and the range constraint fails to catch it,
81 // this means that (a-b) + modulus is IN the range [0, a.current_max].
82 // This is equivalent to the condition that (a - b) + modulus <= a.current_max.
83 // IF b.current_max >= modulus - a.current_max, then it is possible for this condition to be true
84 // because we can let a be 0, and b be b.current_max -> (0 - b.current_max) + modulus <= a.current_max is true.
85 // IF b.current_max < modulus - a.current_max, it is impossible for underflow to happen, no matter how you set a and
86 // b. Therefore, we check that b.current_max >= modulus - a.current_max, which is equivalent to
87 // difference.current_max + other.current_max > MAX_VALUE Note that we will throw an error sometimes even if a-b is
88 // not an underflow but we cannot distinguish it from a case that underflows, so we must throw an error.
89 if (difference.current_max + other.current_max > MAX_VALUE)
90 throw_or_abort("maximum value exceeded in safe_uint minus operator");
91 return difference;
92}
93
// Division with caller-chosen bit bounds on the quotient and remainder.
// The quotient and remainder are computed natively via get_quotient (the declaration's default is integer / and %).
// Both are introduced as witnesses, range-constrained to quotient_bit_size and remainder_bit_size bits (each at most
// MAX_BIT_NUM). The circuit then enforces
//   this == quotient * other + remainder   and   other - (remainder + 1) is range-constrained (i.e. remainder < other).
// Returns the quotient; quotient and remainder carry the merged origin tag of both operands.
// `this` must not be a constant. `description` is appended to the range-constraint labels.
// NOTE(review): original line 106 (the line naming the function) is missing from this listing.
105template <typename Builder>
107 const safe_uint_t& other,
108 const size_t quotient_bit_size,
109 const size_t remainder_bit_size,
110 std::string const& description,
111 const std::function<std::pair<uint256_t, uint256_t>(uint256_t, uint256_t)>& get_quotient) const
112{
113 BB_ASSERT_EQ(this->value.is_constant(), false);
114 BB_ASSERT_LTE(quotient_bit_size, MAX_BIT_NUM);
115 BB_ASSERT_LTE(remainder_bit_size, MAX_BIT_NUM);
116 uint256_t val = this->value.get_value();
117 auto [quotient_val, remainder_val] = get_quotient(val, (uint256_t)other.value.get_value());
118 field_ct quotient_field(witness_t(value.context, quotient_val));
119 field_ct remainder_field(witness_t(value.context, remainder_val));
120 safe_uint_t<Builder> quotient(quotient_field, quotient_bit_size, format("divide method quotient: ", description));
121 safe_uint_t<Builder> remainder(
122 remainder_field, remainder_bit_size, format("divide method remainder: ", description));
123
124 const auto merged_tag = OriginTag(get_origin_tag(), other.get_origin_tag());
125 quotient.set_origin_tag(merged_tag);
126 remainder.set_origin_tag(merged_tag);
127
128 // This line implicitly checks we are not overflowing (operator* asserts the bound of quotient * other fits in 256 bits)
129 safe_uint_t int_val = quotient * other + remainder;
130
131 // We constrain divisor - remainder - 1 to be non-negative to ensure that remainder < divisor.
132 // Define remainder_plus_one to avoid multiple subtractions
133 const safe_uint_t<Builder> remainder_plus_one = remainder + 1;
134 // Subtraction of safe_uint_t's imposes the desired range constraint; the result itself is intentionally discarded
135 other - remainder_plus_one;
136
137 this->assert_equal(int_val, "divide method quotient and/or remainder incorrect");
138
139 return quotient;
140}
141
149template <typename Builder> safe_uint_t<Builder> safe_uint_t<Builder>::operator/(const safe_uint_t& other) const
150{
151 BB_ASSERT_EQ(this->value.is_constant(), false);
152
153 uint256_t val = this->value.get_value();
154 auto [quotient_val, remainder_val] = val.divmod((uint256_t)other.value.get_value());
155 field_ct quotient_field(witness_t(value.context, quotient_val));
156 field_ct remainder_field(witness_t(value.context, remainder_val));
157 safe_uint_t<Builder> quotient(quotient_field, (size_t)(current_max.get_msb() + 1), format("/ operator quotient"));
158 safe_uint_t<Builder> remainder(
159 remainder_field, (size_t)(other.current_max.get_msb() + 1), format("/ operator remainder"));
160
161 const auto merged_tag = OriginTag(get_origin_tag(), other.get_origin_tag());
162 quotient.set_origin_tag(merged_tag);
163 remainder.set_origin_tag(merged_tag);
164
165 // This line implicitly checks we are not overflowing
166 safe_uint_t int_val = quotient * other + remainder;
167
168 // We constrain divisor - remainder - 1 to be non-negative to ensure that remainder < divisor.
169 // // define remainder_plus_one to avoid multiple subtractions
170 const safe_uint_t<Builder> remainder_plus_one = remainder + 1;
171 // // subtraction of safe_uint_t's imposes the desired range constraint
172 other - remainder_plus_one;
173
174 this->assert_equal(int_val, "/ operator quotient and/or remainder incorrect");
175
176 return quotient;
177}
178
179template <typename Builder> safe_uint_t<Builder> safe_uint_t<Builder>::normalize() const
180{
181 auto norm_value = value.normalize();
182 return safe_uint_t(norm_value, current_max, IS_UNSAFE);
183}
184
185template <typename Builder> void safe_uint_t<Builder>::assert_is_zero(std::string const& msg) const
186{
187 value.assert_is_zero(msg);
188}
189
190template <typename Builder> void safe_uint_t<Builder>::assert_is_not_zero(std::string const& msg) const
191{
192 value.assert_is_not_zero(msg);
193}
194
195template <typename Builder> bool_t<Builder> safe_uint_t<Builder>::is_zero() const
196{
197 return value.is_zero();
198}
199
200template <typename Builder> bb::fr safe_uint_t<Builder>::get_value() const
201{
202 return value.get_value();
203}
204
205template <typename Builder> bool_t<Builder> safe_uint_t<Builder>::operator==(const safe_uint_t& other) const
206{
207 return value == other.value;
208}
209
210template <typename Builder> bool_t<Builder> safe_uint_t<Builder>::operator!=(const safe_uint_t& other) const
211{
212 return !operator==(other);
213}
// Splits this value into three pieces around the inclusive bit window [lsb, msb]:
//   lo    = bits [0, lsb)
//   slice = bits [lsb, msb]
//   hi    = bits above msb
// and returns them as { lo, slice, hi }. The decomposition is enforced with
//   this == hi * 2^(msb + 1) + lo + slice * 2^lsb.
// Constant inputs produce constant pieces. Otherwise each piece is a witness range-constrained to lsb bits (lo),
// msb + 1 - lsb bits (slice) and grumpkin::MAX_NO_WRAP_INTEGER_BIT_LENGTH - msb bits (hi).
// Every returned element carries this value's origin tag. Requires msb >= lsb.
// NOTE(review): original lines 218 and 225 are missing from this listing. The comment at line 223 appears to refer to
// a check on line 225 that is not visible here.
214template <typename Builder>
215std::array<safe_uint_t<Builder>, 3> safe_uint_t<Builder>::slice(const uint8_t msb, const uint8_t lsb) const
216{
217 BB_ASSERT_GTE(msb, lsb);
219 const safe_uint_t lhs = *this;
220 Builder* ctx = lhs.get_context();
221
// Native value. NOTE: this local shadows the member `value`; the member is reached explicitly as this->value below.
222 const uint256_t value = uint256_t(get_value());
223 // This should be caught by the proof itself, but the circuit creator would have no way of knowing where the issue
224 // is
226 const auto msb_plus_one = uint32_t(msb) + 1;
// NOTE(review): after the shift by msb + 1, at most 255 - msb significant bits remain, so this (256 - msb)-bit mask
// does not appear to remove anything; it looks redundant but harmless.
227 const auto hi_mask = ((uint256_t(1) << (256 - uint32_t(msb))) - 1);
228 const auto hi = (value >> msb_plus_one) & hi_mask;
229
230 const auto lo_mask = (uint256_t(1) << lsb) - 1;
231 const auto lo = value & lo_mask;
232
// The window [lsb, msb] is msb - lsb + 1 bits wide.
233 const auto slice_mask = ((uint256_t(1) << (uint32_t(msb - lsb) + 1)) - 1);
234 const auto slice = (value >> lsb) & slice_mask;
235 safe_uint_t lo_wit, slice_wit, hi_wit;
236 if (this->value.is_constant()) {
237 hi_wit = safe_uint_t(hi);
238 lo_wit = safe_uint_t(lo);
239 slice_wit = safe_uint_t(slice);
240
241 } else {
242 hi_wit = safe_uint_t(witness_t(ctx, hi), grumpkin::MAX_NO_WRAP_INTEGER_BIT_LENGTH - uint32_t(msb), "hi_wit");
243 lo_wit = safe_uint_t(witness_t(ctx, lo), lsb, "lo_wit");
244 slice_wit = safe_uint_t(witness_t(ctx, slice), msb_plus_one - lsb, "slice_wit");
245 }
// Recomposition constraint: this == hi * 2^(msb+1) + lo + slice * 2^lsb.
246 assert_equal(((hi_wit * safe_uint_t(uint256_t(1) << msb_plus_one)) + lo_wit +
247 (slice_wit * safe_uint_t(uint256_t(1) << lsb))));
248
249 std::array<safe_uint_t, 3> result = { lo_wit, slice_wit, hi_wit };
250 OriginTag tag = get_origin_tag();
251 for (auto& element : result) {
252 element.set_origin_tag(tag);
253 }
254 return result;
255}
256
259
260} // namespace bb::stdlib
#define BB_ASSERT(expression,...)
Definition assert.hpp:80
#define BB_ASSERT_GTE(left, right,...)
Definition assert.hpp:138
#define BB_ASSERT_EQ(actual, expected,...)
Definition assert.hpp:93
#define BB_ASSERT_LTE(left, right,...)
Definition assert.hpp:168
#define BB_ASSERT_LT(left, right,...)
Definition assert.hpp:153
constexpr std::pair< uint256_t, uint256_t > divmod(const uint256_t &b) const
constexpr uint64_t get_msb() const
Implements boolean logic in-circuit.
Definition bool.hpp:59
bb::fr get_value() const
Given a := *this, compute its value given by a.v * a.mul + a.add.
Definition field.cpp:828
bool is_constant() const
Definition field.hpp:429
void assert_is_zero(std::string const &msg="safe_uint_t::assert_is_zero") const
safe_uint_t subtract(const safe_uint_t &other, const size_t difference_bit_size, std::string const &description="") const
Subtraction when you have a pre-determined bound on the difference size.
Definition safe_uint.cpp:35
safe_uint_t operator/(const safe_uint_t &other) const
Potentially less efficient than divide function - bounds remainder and quotient by max of this.
bool_ct is_zero() const
OriginTag get_origin_tag() const
safe_uint_t normalize() const
std::array< safe_uint_t< Builder >, 3 > slice(const uint8_t msb, const uint8_t lsb) const
bool_ct operator==(const safe_uint_t &other) const
safe_uint_t operator-(const safe_uint_t &other) const
Subtraction on two safe_uint_t objects.
Definition safe_uint.cpp:68
void set_origin_tag(OriginTag tag) const
Builder * get_context() const
safe_uint_t operator*(const safe_uint_t &other) const
Definition safe_uint.cpp:16
safe_uint_t operator+(const safe_uint_t &other) const
Definition safe_uint.cpp:11
bool_ct operator!=(const safe_uint_t &other) const
safe_uint_t divide(const safe_uint_t &other, const size_t quotient_bit_size, const size_t remainder_bit_size, std::string const &description="", const std::function< std::pair< uint256_t, uint256_t >(uint256_t, uint256_t)> &get_quotient=[](uint256_t val, uint256_t divisor) { return std::make_pair((uint256_t)(val/(uint256_t) divisor),(uint256_t)(val %(uint256_t) divisor));}) const
division when you have a pre-determined bound on the sizes of the quotient and remainder
bb::fr get_value() const
void assert_is_not_zero(std::string const &msg="safe_uint_t::assert_is_not_zero") const
std::string format(Args... args)
Definition log.hpp:24
constexpr size_t MAX_NO_WRAP_INTEGER_BIT_LENGTH
Definition grumpkin.hpp:15
uintx< uint256_t > uint512_t
Definition uintx.hpp:307
std::conditional_t< IsGoblinBigGroup< C, Fq, Fr, G >, element_goblin::goblin_element< C, goblin_field< C >, Fr, G >, element_default::element< C, Fq, Fr, G > > element
element wraps either element_default::element or element_goblin::goblin_element depending on parametr...
Definition biggroup.hpp:995
C slice(C const &container, size_t start)
Definition container.hpp:9
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
BB_INLINE constexpr bool is_zero() const noexcept
void throw_or_abort(std::string const &err)