small refactor of model loading
[carveJwlIkooP6JGAAIwe30JlM.git] / network_compression.h
1 #ifndef NETWORK_COMPRESSION_H
2 #define NETWORK_COMPRESSION_H
3
4 #include "vg/vg_platform.h"
5 #include "vg/vg_m.h"
6
7 typedef struct bitpack_ctx bitpack_ctx;
8 struct bitpack_ctx {
9 enum bitpack_mode {
10 k_bitpack_compress,
11 k_bitpack_decompress
12 }
13 mode;
14
15 u8 *buffer;
16 u32 bytes, buffer_len;
17 };
18
19 static void bitpack_bytes( bitpack_ctx *ctx, u32 bytes, void *data ){
20 u8 *ext = data;
21 for( u32 i=0; i<bytes; i++ ){
22 u32 index = ctx->bytes+i;
23 if( ctx->mode == k_bitpack_compress ){
24 if( index < ctx->buffer_len )
25 ctx->buffer[index] = ext[i];
26 }
27 else{
28 if( index < ctx->buffer_len )
29 ext[i] = ctx->buffer[index];
30 else
31 ext[i] = 0x00;
32 }
33 }
34 ctx->bytes += bytes;
35 }
36
37 static u32 bitpack_qf32( bitpack_ctx *ctx, u32 bits,
38 f32 min, f32 max, f32 *v ){
39 u32 mask = (0x1 << bits) - 1;
40
41 if( ctx->mode == k_bitpack_compress ){
42 u32 a = vg_quantf( *v, bits, min, max );
43 bitpack_bytes( ctx, bits/8, &a );
44 return a;
45 }
46 else {
47 u32 a = 0;
48 bitpack_bytes( ctx, bits/8, &a );
49 *v = vg_dequantf( a, bits, min, max );
50 return a;
51 }
52 }
53
54 static void bitpack_qv2f( bitpack_ctx *ctx, u32 bits,
55 f32 min, f32 max, v2f v ){
56 for( u32 i=0; i<2; i ++ )
57 bitpack_qf32( ctx, bits, min, max, v+i );
58 }
59
60 static void bitpack_qv3f( bitpack_ctx *ctx, u32 bits,
61 f32 min, f32 max, v3f v ){
62 for( u32 i=0; i<3; i ++ )
63 bitpack_qf32( ctx, bits, min, max, v+i );
64 }
65
66 static void bitpack_qv4f( bitpack_ctx *ctx, u32 bits,
67 f32 min, f32 max, v4f v ){
68 for( u32 i=0; i<4; i ++ )
69 bitpack_qf32( ctx, bits, min, max, v+i );
70 }
71
72 static void bitpack_qquat( bitpack_ctx *ctx, v4f quat ){
73 const f32 k_domain = 0.70710678118f;
74
75 if( ctx->mode == k_bitpack_compress ){
76 v4f qabs;
77 for( u32 i=0; i<4; i++ )
78 qabs[i] = fabsf(quat[i]);
79
80 u32 lxy = qabs[1]>qabs[0],
81 lzw = (qabs[3]>qabs[2])+2,
82 l = qabs[lzw]>qabs[lxy]? lzw: lxy;
83
84 f32 sign = vg_signf(quat[l]);
85
86 u32 smallest[3];
87 for( u32 i=0, j=0; i<4; i ++ )
88 if( i != l )
89 smallest[j ++] = vg_quantf( quat[i]*sign, 10, -k_domain, k_domain );
90
91 u32 comp = (smallest[0]<<2) | (smallest[1]<<12) | (smallest[2]<<22) | l;
92 bitpack_bytes( ctx, 4, &comp );
93 }
94 else {
95 u32 comp;
96 bitpack_bytes( ctx, 4, &comp );
97
98 u32 smallest[3] = {(comp>>2 )&0x3ff,
99 (comp>>12)&0x3ff,
100 (comp>>22)&0x3ff},
101 l = comp & 0x3;
102
103 f32 m = 1.0f;
104
105 for( u32 i=0, j=0; i<4; i ++ ){
106 if( i != l ){
107 quat[i] = vg_dequantf( smallest[j ++], 10, -k_domain, k_domain );
108 m -= quat[i]*quat[i];
109 }
110 }
111
112 quat[l] = sqrtf(m);
113 q_normalize( quat );
114 }
115 }
116
117 #endif /* NETWORK_COMPRESSION_H */