1use std::io::{self, Read, Write};
2use std::mem::size_of;
3
4#[derive(Debug, Default)]
5pub struct BitWriter {
6 buffer: Vec<u8>,
7 scratch: u64,
8 scratch_bits: usize,
9}
10
11impl BitWriter {
12 pub fn with_capacity(capacity: usize) -> Self {
13 Self {
14 buffer: Vec::with_capacity(capacity),
15 scratch: 0,
16 scratch_bits: 0,
17 }
18 }
19
20 pub fn consume(mut self) -> Result<Vec<u8>, io::Error> {
21 self.flush_bits()?;
22 Ok(self.buffer)
23 }
24
25 pub fn write_bits(&mut self, value: u32, bits: usize) -> Result<(), io::Error> {
26 assert!(bits <= 32);
27
28 self.scratch |= (value as u64) << self.scratch_bits;
29 self.scratch_bits += bits;
30
31 if self.scratch_bits >= 32 {
32 let bytes = (self.scratch as u32).to_le_bytes();
33 self.buffer.write_all(&bytes)?;
34 self.scratch >>= 32;
35 self.scratch_bits -= 32;
36 }
37
38 Ok(())
39 }
40
41 pub fn align(&mut self) -> Result<(), io::Error> {
42 let remainder_bits = self.scratch_bits % 8;
43 if remainder_bits != 0 {
44 self.write_bits(0, 8 - remainder_bits)?;
45 assert!(self.scratch_bits % 8 == 0);
46 }
47 Ok(())
48 }
49
50 pub fn flush_bits(&mut self) -> Result<(), io::Error> {
51 if self.scratch_bits != 0 {
52 let bytes = (self.scratch as u32).to_le_bytes();
53 self.buffer.write_all(&bytes)?;
54 self.scratch = 0;
55 self.scratch_bits = 0;
56 }
57 Ok(())
58 }
59
60 pub fn bits_written(&self) -> usize {
61 self.buffer.len() * 8 + self.scratch_bits
62 }
63
64 fn align_bits(&self) -> usize {
65 (8 - (self.scratch_bits % 8)) % 8
66 }
67
68 pub fn write_bool(&mut self, value: bool) -> Result<(), io::Error> {
69 self.write_bits(value as u32, 1)
70 }
71
72 pub fn write_u8(&mut self, byte: u8) -> Result<(), io::Error> {
73 self.write_bits(byte as u32, 8)
74 }
75
76 pub fn write_u16(&mut self, value: u16) -> Result<(), io::Error> {
77 self.write_bits(value as u32, 16)
78 }
79
80 pub fn write_u32(&mut self, value: u32) -> Result<(), io::Error> {
81 self.write_bits(value, 32)
82 }
83
84 pub fn write_u64(&mut self, value: u64) -> Result<(), io::Error> {
85 let low_bits = value as u32;
86 let high_bits = (value >> 32) as u32;
87 self.write_bits(low_bits, 32)?;
88 self.write_bits(high_bits, 32)
89 }
90
91 pub fn write_i16(&mut self, value: i16) -> Result<(), io::Error> {
92 self.write_bits(value as u32, 16)
93 }
94
95 pub fn write_i32(&mut self, value: i32) -> Result<(), io::Error> {
96 self.write_bits(value as u32, 32)
97 }
98
99 pub fn write_i64(&mut self, value: i64) -> Result<(), io::Error> {
100 self.write_u64(value as u64)
101 }
102
103 pub fn write_varint_u16(&mut self, value: u16) -> Result<(), io::Error> {
104 self.write_varint_u64(value as u64)
105 }
106
107 pub fn write_varint_u32(&mut self, value: u32) -> Result<(), io::Error> {
108 self.write_varint_u64(value as u64)
109 }
110
111 pub fn write_varint_u64(&mut self, mut value: u64) -> Result<(), io::Error> {
115 for _ in 0..8 {
116 let mut t = value as u8;
117 t &= 0b011111111u8;
119 value >>= 7;
120
121 let more_to_write = value != 0;
123 if more_to_write {
124 t |= 0b10000000u8;
125 }
126
127 self.write_u8(t)?;
128
129 if !more_to_write {
130 return Ok(());
131 }
132 }
133
134 self.write_u8(value as u8)
137 }
138
139 pub fn write_varint_i16(&mut self, value: i16) -> Result<(), io::Error> {
140 let value = zig_zag_encode(value as i64);
141 self.write_varint_u64(value)
142 }
143
144 pub fn write_varint_i32(&mut self, value: i32) -> Result<(), io::Error> {
145 let value = zig_zag_encode(value as i64);
146 self.write_varint_u64(value)
147 }
148
149 pub fn write_varint_i64(&mut self, value: i64) -> Result<(), io::Error> {
150 let value = zig_zag_encode(value);
151 self.write_varint_u64(value)
152 }
153
154 pub fn write_f32(&mut self, value: f32) -> Result<(), io::Error> {
155 self.write_u32(value.to_bits())
156 }
157
158 pub fn write_f64(&mut self, value: f64) -> Result<(), io::Error> {
159 self.write_u64(value.to_bits())
160 }
161}
162
163impl Write for BitWriter {
164 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
165 self.align()?;
167
168 let mut head_bytes = ((32 - self.scratch_bits) / 8) % 4;
170
171 if head_bytes > buf.len() {
172 head_bytes = buf.len();
174 }
175
176 for &value in buf.iter().take(head_bytes) {
177 self.write_bits(value as u32, 8)?;
178 }
179
180 if head_bytes == buf.len() {
181 return Ok(buf.len());
182 }
183
184 self.flush_bits()?;
186 assert_eq!(self.align_bits(), 0);
187
188 const U32_SIZE: usize = size_of::<u32>();
190 let num_words = (buf.len() - head_bytes) / U32_SIZE;
191 if num_words > 0 {
192 self.buffer
193 .extend_from_slice(&buf[head_bytes..head_bytes + num_words * U32_SIZE]);
194 }
195
196 let tail_start = head_bytes + num_words * U32_SIZE;
198 let tail_bytes = buf.len() - tail_start;
199 assert!(tail_bytes < 4);
200
201 for i in 0..tail_bytes {
202 self.write_bits(buf[tail_start + i] as u32, 8)?;
203 }
204
205 Ok(buf.len())
206 }
207
208 fn flush(&mut self) -> io::Result<()> {
209 self.flush_bits()
210 }
211}
212
213#[derive(Debug)]
214pub struct BitReader<'a> {
215 buffer: &'a [u8],
216 scratch: u64,
217 scratch_bits: usize,
218 bits_read: usize,
219}
220
221impl<'a> BitReader<'a> {
222 pub fn new(buffer: &'a [u8]) -> Result<Self, io::Error> {
223 if buffer.len() % 4 != 0 {
224 return Err(io::Error::new(
225 io::ErrorKind::InvalidInput,
226 "BitReader buffer must have the length as a multiple of 4",
227 ));
228 }
229 Ok(Self {
230 buffer,
231 scratch: 0,
232 scratch_bits: 0,
233 bits_read: 0,
234 })
235 }
236
237 pub fn read_bits(&mut self, bits: usize) -> Result<u32, io::Error> {
238 assert!(bits <= 32);
239
240 if self.scratch_bits < bits {
241 let mut word = [0u8; 4];
242 self.buffer.read_exact(&mut word)?;
243 let word = u32::from_le_bytes(word);
244 self.scratch |= (word as u64) << self.scratch_bits;
245 self.scratch_bits += 32;
246 }
247
248 assert!(self.scratch_bits >= bits);
249
250 let output = (self.scratch & ((1u64 << bits) - 1)) as u32;
251 self.scratch >>= bits;
252 self.scratch_bits -= bits;
253 self.bits_read += bits;
254
255 Ok(output)
256 }
257
258 pub fn align(&mut self) -> Result<(), io::Error> {
259 let remainder_bits = self.bits_read % 8;
260 if remainder_bits != 0 {
261 let value = self.read_bits(8 - remainder_bits)?;
262 assert_eq!(self.bits_read % 8, 0);
263 if value != 0 {
266 return Err(io::Error::new(
267 io::ErrorKind::InvalidData,
268 "Invalid padding, alignment bits must all be 0",
269 ));
270 }
271 }
272
273 Ok(())
274 }
275
276 pub fn read_bool(&mut self) -> Result<bool, io::Error> {
277 Ok(self.read_bits(1)? == 1)
278 }
279
280 pub fn read_u8(&mut self) -> Result<u8, io::Error> {
281 Ok(self.read_bits(8)? as u8)
282 }
283
284 pub fn read_u16(&mut self) -> Result<u16, io::Error> {
285 Ok(self.read_bits(16)? as u16)
286 }
287
288 pub fn read_u32(&mut self) -> Result<u32, io::Error> {
289 self.read_bits(32)
290 }
291
292 pub fn read_u64(&mut self) -> Result<u64, io::Error> {
293 let low_bits = self.read_bits(32)?;
294 let high_bits = self.read_bits(32)?;
295
296 let value = low_bits as u64 | ((high_bits as u64) << 32);
297 Ok(value)
298 }
299
300 pub fn read_i16(&mut self) -> Result<i16, io::Error> {
301 Ok(self.read_bits(16)? as i16)
302 }
303
304 pub fn read_i32(&mut self) -> Result<i32, io::Error> {
305 Ok(self.read_bits(32)? as i32)
306 }
307
308 pub fn read_i64(&mut self) -> Result<i64, io::Error> {
309 Ok(self.read_u64()? as i64)
310 }
311
312 pub fn read_varint_u16(&mut self) -> Result<u16, io::Error> {
313 let value = self.read_varint_u64()?;
314 Ok(value as u16)
315 }
316
317 pub fn read_varint_u32(&mut self) -> Result<u32, io::Error> {
318 let value = self.read_varint_u64()?;
319 Ok(value as u32)
320 }
321
322 pub fn read_varint_u64(&mut self) -> Result<u64, io::Error> {
323 let mut result: u64 = 0;
324 for i in 0..8 {
325 let byte = self.read_u8()?;
326 let stop_reading = (byte & 0b10000000u8) == 0;
328
329 let value = (byte & 0b01111111u8) as u64;
331 result |= value << (i * 7);
332
333 if stop_reading {
334 return Ok(result);
335 }
336 }
337
338 let value = self.read_u8()? as u64;
341 result |= value << 56;
342
343 Ok(result)
344 }
345
346 pub fn read_varint_i16(&mut self) -> Result<i16, io::Error> {
347 let value = self.read_varint_u64()?;
348 let value = zig_zag_decode(value);
349 Ok(value as i16)
350 }
351
352 pub fn read_varint_i32(&mut self) -> Result<i32, io::Error> {
353 let value = self.read_varint_u64()?;
354 let value = zig_zag_decode(value);
355 Ok(value as i32)
356 }
357
358 pub fn read_varint_i64(&mut self) -> Result<i64, io::Error> {
359 let value = self.read_varint_u64()?;
360 let value = zig_zag_decode(value);
361 Ok(value)
362 }
363
364 pub fn read_f32(&mut self) -> Result<f32, io::Error> {
365 let value = self.read_u32()?;
366 Ok(f32::from_bits(value))
367 }
368
369 pub fn read_f64(&mut self) -> Result<f64, io::Error> {
370 let value = self.read_u64()?;
371 Ok(f64::from_bits(value))
372 }
373}
374
375impl<'a> Read for BitReader<'a> {
376 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
377 self.align()?;
379
380 let mut head_bytes = (self.scratch_bits / 8) % 4;
382
383 if head_bytes > buf.len() {
385 head_bytes = buf.len();
386 }
387
388 for value in buf.iter_mut().take(head_bytes) {
389 *value = self.read_bits(8)? as u8;
390 }
391
392 if head_bytes == buf.len() {
393 return Ok(buf.len());
394 }
395
396 const U32_SIZE: usize = size_of::<u32>();
398 let num_words = (buf.len() - head_bytes) / U32_SIZE;
399 if num_words > 0 {
400 self.buffer
401 .read_exact(&mut buf[head_bytes..head_bytes + (num_words * U32_SIZE)])?;
402 self.bits_read += num_words * 32;
403 }
404
405 let tail_start = head_bytes + num_words * U32_SIZE;
407 let tail_bytes = buf.len() - tail_start;
408 assert!(tail_bytes < 4);
409
410 for i in 0..tail_bytes {
411 buf[tail_start + i] = self.read_bits(8)? as u8;
412 }
413
414 Ok(buf.len())
415 }
416}
417
/// Maps a signed value to an unsigned one so that values of small magnitude
/// get small codes: 0, -1, 1, -2, 2, ... become 0, 1, 2, 3, 4, ...
#[inline(always)]
fn zig_zag_encode(value: i64) -> u64 {
    // `value >> 63` is an arithmetic shift: all ones for negatives, zero
    // otherwise. XOR-ing it in flips the doubled value for negatives.
    ((value << 1) ^ (value >> 63)) as u64
}
429
/// Inverse of the zig-zag mapping: even codes decode to non-negative values,
/// odd codes to negative ones (0, 1, 2, 3, 4, ... become 0, -1, 1, -2, 2, ...).
#[inline(always)]
fn zig_zag_decode(value: u64) -> i64 {
    let magnitude = (value >> 1) as i64;
    // 0 for even codes, -1 (all ones) for odd codes; XOR with all ones is a
    // bitwise NOT.
    let sign_mask = -((value & 1) as i64);
    magnitude ^ sign_mask
}
441
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Payload round-tripped through bincode in the byte-stream tests.
    #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
    struct TestMessage {
        value: u64,
        array: [u16; 12],
        string: String,
    }

    /// Builds the fixture shared by the bincode tests.
    fn sample_message() -> TestMessage {
        TestMessage {
            value: 999999999999,
            array: [500; 12],
            string: "just a test string".to_owned(),
        }
    }

    /// Exercises every writer method once and reads everything back.
    #[test]
    fn usage() {
        #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
        struct SimpleStruct {
            value: u64,
            string: String,
            array: [u16; 12],
        }

        let mut writer = BitWriter::default();

        // Single bits and bit groups.
        writer.write_bool(true).unwrap();
        let two_bit_value: u32 = 3;
        writer.write_bits(two_bit_value, 2).unwrap();

        // Fixed-width integers.
        writer.write_u8(0).unwrap();
        writer.write_u16(1).unwrap();
        writer.write_u32(2).unwrap();
        writer.write_u64(3).unwrap();
        writer.write_i64(-1).unwrap();
        writer.write_i32(-2).unwrap();
        writer.write_i64(-3).unwrap();

        // Variable-length integers.
        writer.write_varint_u16(1).unwrap();
        writer.write_varint_u32(2).unwrap();
        writer.write_varint_u64(3).unwrap();
        writer.write_varint_i16(-1).unwrap();
        writer.write_varint_i32(-2).unwrap();
        writer.write_varint_i64(-3).unwrap();

        // Floats.
        writer.write_f32(1.0).unwrap();
        writer.write_f64(2.0).unwrap();

        // Raw bytes through `io::Write`.
        let raw = vec![7u8; 20];
        writer.write_all(&raw).unwrap();

        // A serde payload through `io::Write`.
        let message = SimpleStruct {
            value: 999999999999,
            string: "some text to serialize".to_owned(),
            array: [500; 12],
        };
        bincode::serialize_into(&mut writer, &message).unwrap();

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        assert!(reader.read_bool().unwrap());
        assert_eq!(reader.read_bits(2).unwrap(), 3);

        assert_eq!(reader.read_u8().unwrap(), 0);
        assert_eq!(reader.read_u16().unwrap(), 1);
        assert_eq!(reader.read_u32().unwrap(), 2);
        assert_eq!(reader.read_u64().unwrap(), 3);
        assert_eq!(reader.read_i64().unwrap(), -1);
        assert_eq!(reader.read_i32().unwrap(), -2);
        assert_eq!(reader.read_i64().unwrap(), -3);

        assert_eq!(reader.read_varint_u16().unwrap(), 1);
        assert_eq!(reader.read_varint_u32().unwrap(), 2);
        assert_eq!(reader.read_varint_u64().unwrap(), 3);
        assert_eq!(reader.read_varint_i16().unwrap(), -1);
        assert_eq!(reader.read_varint_i32().unwrap(), -2);
        assert_eq!(reader.read_varint_i64().unwrap(), -3);

        assert_eq!(reader.read_f32().unwrap(), 1.0);
        assert_eq!(reader.read_f64().unwrap(), 2.0);

        let mut raw_back = vec![0u8; raw.len()];
        reader.read_exact(&mut raw_back).unwrap();
        assert_eq!(raw, raw_back);

        let message_back: SimpleStruct = bincode::deserialize_from(&mut reader).unwrap();
        assert_eq!(message, message_back);
    }

    /// Bytes written after an unaligned bit prefix come back intact, and bit
    /// writes can resume afterwards.
    #[test]
    fn bit_writer_reader_test() {
        let mut writer = BitWriter::default();

        writer.write_bits(3, 2).unwrap();
        writer.write_bits(5, 5).unwrap();
        let payload = vec![0, 1, 2, 3, 4, 5, 6, 7];
        writer.write_all(&payload).unwrap();

        writer.write_bits(7, 12).unwrap();
        writer.write_bits(1, 1).unwrap();

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        assert_eq!(reader.read_bits(2).unwrap(), 3);
        assert_eq!(reader.read_bits(5).unwrap(), 5);
        let mut payload_back = vec![0u8; payload.len()];
        reader.read_exact(&mut payload_back).unwrap();
        assert_eq!(payload_back, payload);
        assert_eq!(reader.read_bits(12).unwrap(), 7);
        assert_eq!(reader.read_bits(1).unwrap(), 1);
    }

    /// Bytes written at the very start of the stream come back intact.
    #[test]
    fn bit_read_write_aligned() {
        let mut writer = BitWriter::default();

        let payload = vec![0, 1, 2, 3, 4, 5, 6, 7];
        writer.write_all(&payload).unwrap();

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        let mut payload_back = vec![0u8; payload.len()];
        reader.read_exact(&mut payload_back).unwrap();
        assert_eq!(payload_back, payload);
    }

    /// bincode round trip with nothing else in the stream.
    #[test]
    fn bincode_aligned() {
        let mut writer = BitWriter::default();
        let message = sample_message();

        bincode::serialize_into(&mut writer, &message).unwrap();

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        let message_back: TestMessage = bincode::deserialize_from(&mut reader).unwrap();

        assert_eq!(message, message_back);
    }

    /// bincode round trip surrounded by loose bits on both sides.
    #[test]
    fn bincode_not_aligned() {
        let mut writer = BitWriter::default();
        let message = sample_message();

        writer.write_bits(3, 5).unwrap();
        bincode::serialize_into(&mut writer, &message).unwrap();
        writer.write_bits(1, 2).unwrap();

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        assert_eq!(reader.read_bits(5).unwrap(), 3);
        let message_back: TestMessage = bincode::deserialize_from(&mut reader).unwrap();
        assert_eq!(reader.read_bits(2).unwrap(), 1);

        assert_eq!(message, message_back);
    }

    /// Varints of every width starting at bit 0, including the encoded sizes
    /// of the shortest and longest forms.
    #[test]
    fn varint_aligned() {
        let mut writer = BitWriter::default();

        writer.write_varint_u64(5).unwrap();
        assert_eq!(writer.bits_written(), 8);

        let big_unsigned = 0xffa0000000000000u64;
        writer.write_varint_u64(big_unsigned).unwrap();
        assert_eq!(writer.bits_written(), 8 + (9 * 8));

        writer.write_varint_u32(320000).unwrap();
        writer.write_varint_u16(16000).unwrap();

        let big_negative = -0xffa000000000000i64;
        writer.write_varint_i64(big_negative).unwrap();
        writer.write_varint_i32(-320000).unwrap();
        writer.write_varint_i16(-16000).unwrap();

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        assert_eq!(reader.read_varint_u64().unwrap(), 5);
        assert_eq!(reader.read_varint_u64().unwrap(), big_unsigned);
        assert_eq!(reader.read_varint_u32().unwrap(), 320000);
        assert_eq!(reader.read_varint_u16().unwrap(), 16000);

        assert_eq!(reader.read_varint_i64().unwrap(), big_negative);
        assert_eq!(reader.read_varint_i32().unwrap(), -320000);
        assert_eq!(reader.read_varint_i16().unwrap(), -16000);
    }

    /// The same varints, shifted by a 5-bit prefix.
    #[test]
    fn varint_not_aligned() {
        let mut writer = BitWriter::default();

        writer.write_bits(3, 5).unwrap();

        writer.write_varint_u64(5).unwrap();

        let big_unsigned = 0xffa0000000000000u64;
        writer.write_varint_u64(big_unsigned).unwrap();

        writer.write_varint_u32(320000).unwrap();
        writer.write_varint_u16(16000).unwrap();

        let big_negative = -0xffa000000000000i64;
        writer.write_varint_i64(big_negative).unwrap();
        writer.write_varint_i32(-320000).unwrap();
        writer.write_varint_i16(-16000).unwrap();

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        assert_eq!(reader.read_bits(5).unwrap(), 3);

        assert_eq!(reader.read_varint_u64().unwrap(), 5);
        assert_eq!(reader.read_varint_u64().unwrap(), big_unsigned);
        assert_eq!(reader.read_varint_u32().unwrap(), 320000);
        assert_eq!(reader.read_varint_u16().unwrap(), 16000);

        assert_eq!(reader.read_varint_i64().unwrap(), big_negative);
        assert_eq!(reader.read_varint_i32().unwrap(), -320000);
        assert_eq!(reader.read_varint_i16().unwrap(), -16000);
    }

    /// Booleans occupy one bit each and keep their order.
    #[test]
    fn bool() {
        let pattern = [true, false, true, true, false];

        let mut writer = BitWriter::default();
        for &flag in pattern.iter() {
            writer.write_bool(flag).unwrap();
        }

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        for &expected in pattern.iter() {
            assert_eq!(reader.read_bool().unwrap(), expected);
        }
    }

    /// Floats survive a round trip bit-for-bit, even when unaligned.
    #[test]
    fn float() {
        let mut writer = BitWriter::default();
        writer.write_bool(true).unwrap();

        writer.write_f32(1234.5678).unwrap();
        writer.write_f64(12345.6789).unwrap();

        let encoded = writer.consume().unwrap();
        let mut reader = BitReader::new(&encoded).unwrap();

        assert!(reader.read_bool().unwrap());
        assert_eq!(reader.read_f32().unwrap(), 1234.5678);
        assert_eq!(reader.read_f64().unwrap(), 12345.6789);
    }
}