Jaxite is a fully homomorphic encryption backend targeting TPUs and GPUs, written in JAX.
It implements the CGGI cryptosystem with some optimizations, and is a supported backend for Google's FHE compiler.
Homomorphic Encryption (HE) is a powerful cryptographic technique that allows computations to be performed directly on encrypted data, without ever needing to decrypt it. This means sensitive information can be processed by untrusted services (like cloud providers) while maintaining data privacy and security. Jaxite provides an efficient implementation of these complex HE schemes within the JAX framework.
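To make "computing on encrypted data" concrete, here is a minimal toy sketch of an LWE-style bit encryption in plain Python. It is not Jaxite's implementation and its parameters are far too small to be secure; the names `Q`, `N`, `keygen`, `encrypt`, `decrypt`, and `add` are illustrative assumptions. It only shows that adding two ciphertexts yields a ciphertext of the XOR of the two bits, without the party doing the addition ever seeing the key.

```python
import random

Q = 1 << 16  # ciphertext modulus (toy size, an assumption for illustration)
N = 32       # secret key dimension (toy size, not secure)


def keygen(rng):
    """Sample a binary secret key."""
    return [rng.randrange(2) for _ in range(N)]


def encrypt(bit, key, rng):
    """Encrypt one bit as (a, b) with b = <a, key> + noise + bit * Q/2 mod Q."""
    a = [rng.randrange(Q) for _ in range(N)]
    noise = rng.randint(-4, 4)
    b = (sum(ai * si for ai, si in zip(a, key)) + noise + bit * (Q // 2)) % Q
    return a, b


def decrypt(ct, key):
    """Remove the mask and round: values near Q/2 decode to 1, values near 0 to 0."""
    a, b = ct
    phase = (b - sum(ai * si for ai, si in zip(a, key))) % Q
    return 1 if Q // 4 <= phase < 3 * Q // 4 else 0


def add(ct1, ct2):
    """Homomorphic addition: component-wise sum mod Q, no key required."""
    a = [(u + v) % Q for u, v in zip(ct1[0], ct2[0])]
    return a, (ct1[1] + ct2[1]) % Q


# Seeded only so the example is reproducible; never do this for real keys.
rng = random.Random(0)
key = keygen(rng)
for m1 in (0, 1):
    for m2 in (0, 1):
        ct_sum = add(encrypt(m1, key, rng), encrypt(m2, key, rng))
        assert decrypt(ct_sum, key) == m1 ^ m2
```

Real schemes such as CGGI add bootstrapping to keep the noise from growing as gates are chained, which this sketch omits entirely.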
Jaxite's primary goals include:
- Efficient Implementation: Providing high-performance implementations of Homomorphic Encryption algorithms.
- Hardware Acceleration: Leveraging JAX to enable seamless execution and acceleration on modern hardware like TPUs and GPUs.
- FHE Compiler Backend: Serving as a robust backend for Google's Fully Homomorphic Encryption compiler, facilitating advanced research and practical applications of FHE.
- Accessibility: Making advanced FHE capabilities more accessible to researchers and developers within the JAX ecosystem.
Install Jaxite
pip install jaxite
The following program shows how to use the jaxite_bool boolean gate API.
from jaxite.jaxite_bool import jaxite_bool
bool_params = jaxite_bool.bool_params
# Note: In real applications, a cryptographically secure seed needs to be
# used.
lwe_rng = bool_params.get_lwe_rng_for_128_bit_security(seed=1)
rlwe_rng = bool_params.get_rlwe_rng_for_128_bit_security(seed=1)
params = bool_params.get_params_for_128_bit_security()
cks = jaxite_bool.ClientKeySet(
    params,
    lwe_rng=lwe_rng,
    rlwe_rng=rlwe_rng,
)
sks = jaxite_bool.ServerKeySet(
    cks,
    params,
    lwe_rng=lwe_rng,
    rlwe_rng=rlwe_rng,
    bootstrap_callback=None,
)
ct_true = jaxite_bool.encrypt(True, cks, lwe_rng)
ct_false = jaxite_bool.encrypt(False, cks, lwe_rng)
not_false = jaxite_bool.not_(ct_false, params)
or_false = jaxite_bool.or_(not_false, ct_false, sks, params)
and_true = jaxite_bool.and_(or_false, ct_true, sks, params)
xor_true = jaxite_bool.xor_(and_true, ct_true, sks, params)
actual = jaxite_bool.decrypt(xor_true, cks)
expected = (((not False) or False) and True) != True
assert actual == expected
Jaxite also supports higher-bit-width gates, which are called look-up tables (LUTs). For example, this function is an 8-bit adder consisting entirely of lut3 gates.
from typing import Dict, List
from jaxite.jaxite_lib import types
def add_i8(
    x: List[types.LweCiphertext],
    y: List[types.LweCiphertext],
    sks: jaxite_bool.ServerKeySet,
    params: jaxite_bool.Parameters) -> List[types.LweCiphertext]:
    temp_nodes: Dict[int, types.LweCiphertext] = {}
    false = jaxite_bool.constant(False, params)
    out = [None] * 8
    # Carry chain: each gate depends on the previous one.
    temp_nodes[0] = jaxite_bool.lut3(x[0], y[0], false, 8, sks, params)
    temp_nodes[1] = jaxite_bool.lut3(temp_nodes[0], x[1], y[1], 23, sks, params)
    temp_nodes[2] = jaxite_bool.lut3(temp_nodes[1], x[2], y[2], 43, sks, params)
    temp_nodes[3] = jaxite_bool.lut3(temp_nodes[2], x[3], y[3], 43, sks, params)
    temp_nodes[4] = jaxite_bool.lut3(temp_nodes[3], x[4], y[4], 43, sks, params)
    temp_nodes[5] = jaxite_bool.lut3(temp_nodes[4], x[5], y[5], 43, sks, params)
    temp_nodes[6] = jaxite_bool.lut3(temp_nodes[5], x[6], y[6], 43, sks, params)
    # Output bits: each gate depends only on the inputs and one carry node.
    out[0] = jaxite_bool.lut3(x[0], y[0], false, 6, sks, params)
    out[1] = jaxite_bool.lut3(temp_nodes[0], x[1], y[1], 150, sks, params)
    out[2] = jaxite_bool.lut3(temp_nodes[1], x[2], y[2], 105, sks, params)
    out[3] = jaxite_bool.lut3(temp_nodes[2], x[3], y[3], 105, sks, params)
    out[4] = jaxite_bool.lut3(temp_nodes[3], x[4], y[4], 105, sks, params)
    out[5] = jaxite_bool.lut3(temp_nodes[4], x[5], y[5], 105, sks, params)
    out[6] = jaxite_bool.lut3(temp_nodes[5], x[6], y[6], 105, sks, params)
    out[7] = jaxite_bool.lut3(temp_nodes[6], x[7], y[7], 105, sks, params)
    return out
On a platform with parallelism, Jaxite uses JAX's pmap API to allow parallel evaluation of gates that have no interdependencies. For example, the last eight gate operations of the i8 adder above could be rewritten as:
inputs = [
    (x[0], y[0], false, 6),  # out[0]
    (temp_nodes[0], x[1], y[1], 150),  # out[1]
    (temp_nodes[1], x[2], y[2], 105),  # out[2]
    (temp_nodes[2], x[3], y[3], 105),  # out[3]
    (temp_nodes[3], x[4], y[4], 105),  # out[4]
    (temp_nodes[4], x[5], y[5], 105),  # out[5]
    (temp_nodes[5], x[6], y[6], 105),  # out[6]
    (temp_nodes[6], x[7], y[7], 105),  # out[7]
]
outputs = jaxite_bool.pmap_lut3(inputs, sks, params)
return outputs
These circuits were generated with the jaxite support in Google's Fully Homomorphic Encryption Transpiler project; see transpiler/jaxite in that project for instructions.
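To see why the truth-table constants in the add_i8 adder (8, 6, 150, 23, 43, 105) produce addition, the following plaintext sketch mirrors the same gate structure on unencrypted bits. It does not use Jaxite at all. The helper names (`lut3_clear`, `add_i8_clear`, `to_bits`, `from_bits`) are hypothetical, and the bit ordering (the first LUT input selects the lowest bit of the table index) is an assumption inferred from the adder's constants rather than taken from Jaxite's documentation.

```python
from typing import List


def lut3_clear(a: int, b: int, c: int, truth_table: int) -> int:
    """Plaintext model of a 3-input LUT; `a` is the low bit of the index (assumed)."""
    index = a | (b << 1) | (c << 2)
    return (truth_table >> index) & 1


def add_i8_clear(x: List[int], y: List[int]) -> List[int]:
    """Same gate structure as add_i8, on plain bits (index 0 is the least significant)."""
    # Carry chain: inherently sequential, each node feeds the next.
    carry = [lut3_clear(x[0], y[0], 0, 8)]               # 8 = AND of x[0], y[0]
    carry.append(lut3_clear(carry[0], x[1], y[1], 23))   # 23 = inverted majority
    for i in range(2, 7):
        # 43 takes an inverted carry in and produces an inverted carry out.
        carry.append(lut3_clear(carry[i - 1], x[i], y[i], 43))
    # Sum bits do not depend on each other, so they form a single batch,
    # analogous to the list passed to pmap_lut3.
    batch = [(x[0], y[0], 0, 6), (carry[0], x[1], y[1], 150)]  # 6 = XOR, 150 = 3-input XOR
    batch += [(carry[i - 1], x[i], y[i], 105) for i in range(2, 8)]  # 105 = 3-input XNOR
    return [lut3_clear(*gate) for gate in batch]


def to_bits(value: int) -> List[int]:
    return [(value >> i) & 1 for i in range(8)]


def from_bits(bits: List[int]) -> int:
    return sum(bit << i for i, bit in enumerate(bits))


assert from_bits(add_i8_clear(to_bits(100), to_bits(55))) == 155
assert from_bits(add_i8_clear(to_bits(200), to_bits(100))) == 44  # 300 mod 256
```

Under this reading, the carry nodes from bit 1 onward hold the inverted carry, which is why the later sum gates use 105 (XNOR) instead of 150 (XOR).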
See CONTRIBUTING.md for details on how to contribute. The following people are contributors to the Jaxite project (sorted by last name):
- Asra Ali
- Eric Astor
- Bryant Gipson
- Shruthi Gorantala
- Miguel Guevara
- Jeremy Kun
- William Lam
- Rafael Misoczki
- Rob Springer
- Jonathan Takeshita
- Jianming Tong
- Cameron Tew
- Cathie Yun
Apache 2.0; see LICENSE for details.
This project is not an official Google project. It is not supported by Google and Google specifically disclaims all warranties as to its quality, merchantability, or fitness for a particular purpose.