1use std::io::{self};
31use std::path::Path;
32
33use byteorder::{LittleEndian, ReadBytesExt};
34
35use crate::error::{DictError, Result};
36
37mod dense;
38mod mmap;
39mod sparse;
40
41pub use dense::DenseMatrix;
42pub use mmap::MmapMatrix;
43pub use sparse::SparseMatrix;
44
45#[cfg(feature = "simd")]
47pub mod simd;
48
49#[cfg(feature = "simd")]
50pub use simd::SimdMatrix;
51
52pub(super) const MATRIX_HEADER_SIZE: usize = 4;
53
54pub(super) const MKM3_MAGIC: &[u8; 4] = b"MKM3";
55pub(super) const MKM3_HEADER_SIZE: usize = 16;
56
57pub(super) struct MatrixHeader {
59 pub(super) lsize: usize,
61 pub(super) rsize: usize,
63 pub(super) header_size: usize,
65}
66
67pub(super) fn parse_matrix_header(data: &[u8]) -> Result<MatrixHeader> {
79 let is_v3 = data.len() >= 4 && &data[..4] == MKM3_MAGIC;
80 let header_size = if is_v3 {
81 MKM3_HEADER_SIZE
82 } else {
83 MATRIX_HEADER_SIZE
84 };
85
86 if data.len() < header_size {
87 return Err(DictError::Format(
88 "Matrix binary too short for header".to_string(),
89 ));
90 }
91
92 let mut cursor = io::Cursor::new(data);
93
94 let (lsize, rsize) = if is_v3 {
95 cursor.set_position(4);
96 let _version = cursor.read_u8().map_err(DictError::Io)?;
97 let _flags = cursor.read_u8().map_err(DictError::Io)?;
98 let _reserved = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)?;
99 let l = cursor.read_u32::<LittleEndian>().map_err(DictError::Io)? as usize;
100 let r = cursor.read_u32::<LittleEndian>().map_err(DictError::Io)? as usize;
101 (l, r)
102 } else {
103 let l = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
104 let r = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
105 (l, r)
106 };
107
108 Ok(MatrixHeader {
109 lsize,
110 rsize,
111 header_size,
112 })
113}
114
/// Sentinel cost for an id pair a matrix cannot answer. For example,
/// `DenseMatrix::get` returns it for ids outside the matrix dimensions.
pub const INVALID_CONNECTION_COST: i32 = i32::MAX;

/// Read-only view of a connection-cost matrix addressed by context ids.
pub trait Matrix {
    /// Returns the connection cost stored for (`right_id`, `left_id`).
    fn get(&self, right_id: u16, left_id: u16) -> i32;

    /// Size of the matrix's left dimension.
    fn left_size(&self) -> usize;

    /// Size of the matrix's right dimension.
    fn right_size(&self) -> usize;

    /// Total number of cells: `left_size() * right_size()`.
    fn entry_count(&self) -> usize {
        let (left, right) = (self.left_size(), self.right_size());
        left * right
    }
}
146
147pub struct MatrixLoader;
151
152impl MatrixLoader {
153 pub fn load<P: AsRef<Path>>(path: P) -> Result<DenseMatrix> {
164 let path = path.as_ref();
165 let path_str = path.to_string_lossy();
166
167 if path_str.ends_with(".def") {
168 DenseMatrix::from_def_file(path)
169 } else if path_str.ends_with(".zst") || path_str.ends_with(".bin.zst") {
170 DenseMatrix::from_compressed_file(path)
171 } else if path_str.ends_with(".bin") {
172 DenseMatrix::from_bin_file(path)
173 } else {
174 DenseMatrix::from_bin_file(path).or_else(|_| DenseMatrix::from_def_file(path))
176 }
177 }
178
179 pub fn load_mmap<P: AsRef<Path>>(path: P) -> Result<MmapMatrix> {
185 MmapMatrix::from_file(path)
186 }
187}
188
189pub enum ConnectionMatrix {
193 Dense(DenseMatrix),
195 Sparse(SparseMatrix),
197 Mmap(MmapMatrix),
199}
200
201impl ConnectionMatrix {
202 pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
208 Ok(Self::Dense(DenseMatrix::from_def_file(path)?))
209 }
210
211 pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
217 Ok(Self::Dense(DenseMatrix::from_bin_file(path)?))
218 }
219
220 pub fn from_mmap_file<P: AsRef<Path>>(path: P) -> Result<Self> {
226 Ok(Self::Mmap(MmapMatrix::from_file(path)?))
227 }
228
229 pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
235 Ok(Self::Dense(DenseMatrix::from_compressed_file(path)?))
236 }
237
238 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
244 Ok(Self::Dense(MatrixLoader::load(path)?))
245 }
246}
247
248impl Matrix for ConnectionMatrix {
249 #[inline(always)]
250 fn get(&self, right_id: u16, left_id: u16) -> i32 {
251 match self {
252 Self::Dense(m) => m.get(right_id, left_id),
253 Self::Sparse(m) => m.get(right_id, left_id),
254 Self::Mmap(m) => m.get(right_id, left_id),
255 }
256 }
257
258 fn left_size(&self) -> usize {
259 match self {
260 Self::Dense(m) => m.left_size(),
261 Self::Sparse(m) => m.left_size(),
262 Self::Mmap(m) => m.left_size(),
263 }
264 }
265
266 fn right_size(&self) -> usize {
267 match self {
268 Self::Dense(m) => m.right_size(),
269 Self::Sparse(m) => m.right_size(),
270 Self::Mmap(m) => m.right_size(),
271 }
272 }
273}
274
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::cast_lossless)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_matrix_new() {
        let matrix = DenseMatrix::new(10, 10, 0);
        assert_eq!(matrix.left_size(), 10);
        assert_eq!(matrix.right_size(), 10);
        assert_eq!(matrix.entry_count(), 100);
        assert_eq!(matrix.get(0, 0), 0);
    }

    #[test]
    fn test_dense_matrix_set_get() {
        let mut matrix = DenseMatrix::new(10, 10, 0);
        matrix.set(3, 5, 100);
        assert_eq!(matrix.get(3, 5), 100);
        assert_eq!(matrix.get(5, 3), 0);
    }

    #[test]
    fn test_dense_matrix_from_vec() {
        let costs = vec![1, 2, 3, 4, 5, 6];
        let matrix = DenseMatrix::from_vec(2, 3, costs).unwrap();
        assert_eq!(matrix.get(0, 0), 1);
        assert_eq!(matrix.get(1, 0), 2);
        assert_eq!(matrix.get(0, 1), 3);
        assert_eq!(matrix.get(1, 1), 4);
        assert_eq!(matrix.get(0, 2), 5);
        assert_eq!(matrix.get(1, 2), 6);
    }

    #[test]
    fn test_dense_matrix_from_vec_size_mismatch() {
        let costs = vec![1, 2, 3];
        let result = DenseMatrix::from_vec(2, 3, costs);
        assert!(result.is_err());
    }

    #[test]
    fn test_dense_matrix_boundary() {
        let matrix = DenseMatrix::new(10, 10, 0);
        assert_eq!(matrix.get(100, 100), INVALID_CONNECTION_COST);
    }

    #[test]
    fn test_dense_matrix_def_reader() {
        let data = "3 3\n0 0 100\n1 1 200\n2 2 300\n";
        let reader = std::io::Cursor::new(data);
        let matrix = DenseMatrix::from_def_reader(reader).unwrap();

        assert_eq!(matrix.left_size(), 3);
        assert_eq!(matrix.right_size(), 3);
        assert_eq!(matrix.get(0, 0), 100);
        assert_eq!(matrix.get(1, 1), 200);
        assert_eq!(matrix.get(2, 2), 300);
        assert_eq!(matrix.get(0, 1), i16::MAX as i32);
    }

    #[test]
    fn test_dense_matrix_binary_roundtrip() {
        let mut matrix = DenseMatrix::new(5, 5, 0);
        matrix.set(0, 0, 100);
        matrix.set(1, 2, -500);
        matrix.set(4, 4, 32767);

        let bytes = matrix.to_bin_bytes();
        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();

        assert_eq!(loaded.left_size(), 5);
        assert_eq!(loaded.right_size(), 5);
        assert_eq!(loaded.get(0, 0), 100);
        assert_eq!(loaded.get(1, 2), -500);
        assert_eq!(loaded.get(4, 4), 32767);
    }

    #[test]
    fn test_sparse_matrix() {
        let mut sparse = SparseMatrix::new(100, 100, 0);
        sparse.set(10, 20, 500);
        sparse.set(50, 50, -100);

        assert_eq!(sparse.get(10, 20), 500);
        assert_eq!(sparse.get(50, 50), -100);
        assert_eq!(sparse.get(0, 0), 0);
        assert_eq!(sparse.entry_count_stored(), 2);
        assert!(sparse.sparsity() > 0.99);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let mut dense = DenseMatrix::new(10, 10, 0);
        dense.set(3, 3, 100);
        dense.set(5, 7, 200);

        let sparse = SparseMatrix::from_dense(&dense, 0);
        assert_eq!(sparse.entry_count_stored(), 2);
        assert_eq!(sparse.get(3, 3), 100);
        assert_eq!(sparse.get(5, 7), 200);

        let converted = sparse.to_dense();
        assert_eq!(converted.get(3, 3), 100);
        assert_eq!(converted.get(5, 7), 200);
        assert_eq!(converted.get(0, 0), 0);
    }

    #[test]
    fn test_memory_size() {
        let dense = DenseMatrix::new(100, 100, 0);
        let mem_size = dense.memory_size();
        assert!(mem_size >= 20000);

        let sparse = SparseMatrix::new(100, 100, 0);
        let sparse_size = sparse.memory_size();
        assert!(sparse_size < mem_size);
    }

    #[test]
    fn test_connection_matrix_enum() {
        let dense = DenseMatrix::new(5, 5, 100);
        let matrix = ConnectionMatrix::Dense(dense);

        assert_eq!(matrix.left_size(), 5);
        assert_eq!(matrix.right_size(), 5);
        assert_eq!(matrix.get(0, 0), 100);
    }

    #[test]
    fn test_large_matrix() {
        let matrix = DenseMatrix::new(178, 178, 0);
        assert_eq!(matrix.entry_count(), 178 * 178);
        assert_eq!(
            matrix.memory_size(),
            std::mem::size_of::<DenseMatrix>() + 178 * 178 * 2
        );
    }

    #[test]
    fn test_def_with_comments_and_empty_lines() {
        let data = "2 2\n# This is a comment\n\n0 0 10\n0 1 20\n\n1 0 30\n1 1 40\n";
        let reader = std::io::Cursor::new(data);
        let matrix = DenseMatrix::from_def_reader(reader).unwrap();

        assert_eq!(matrix.get(0, 0), 10);
        assert_eq!(matrix.get(0, 1), 20);
        assert_eq!(matrix.get(1, 0), 30);
        assert_eq!(matrix.get(1, 1), 40);
    }

    #[test]
    fn test_v3_header_roundtrip() {
        let mut matrix = DenseMatrix::new(5, 5, 0);
        matrix.set(0, 0, 42);
        matrix.set(2, 3, -999);
        matrix.set(4, 4, 32767);

        let bytes = matrix.to_bin_bytes_v3();
        assert_eq!(&bytes[..4], b"MKM3");
        assert_eq!(bytes[4], 1);
        assert_eq!(bytes[5], 0);

        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();
        assert_eq!(loaded.left_size(), 5);
        assert_eq!(loaded.right_size(), 5);
        assert_eq!(loaded.get(0, 0), 42);
        assert_eq!(loaded.get(2, 3), -999);
        assert_eq!(loaded.get(4, 4), 32767);
    }

    #[test]
    fn test_v2_backward_compat() {
        let mut matrix = DenseMatrix::new(4, 4, 0);
        matrix.set(1, 2, 777);

        let bytes = matrix.to_bin_bytes();
        assert_ne!(&bytes[..4], b"MKM3");

        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();
        assert_eq!(loaded.left_size(), 4);
        assert_eq!(loaded.right_size(), 4);
        assert_eq!(loaded.get(1, 2), 777);
    }

    #[test]
    fn test_v3_large_dimensions() {
        let lsize = (u16::MAX as usize) + 1;
        let rsize = 1;
        let costs = vec![0i16; lsize * rsize];
        let matrix = DenseMatrix::from_vec(lsize, rsize, costs).unwrap();

        let bytes = matrix.to_bin_bytes_v3();
        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();
        assert_eq!(loaded.left_size(), lsize);
        assert_eq!(loaded.right_size(), rsize);
    }

    // Direct coverage of `parse_matrix_header`, independent of any matrix
    // implementation's serializer.

    #[test]
    fn test_parse_header_legacy() {
        // lsize = 2, rsize = 3 as little-endian u16 values.
        let header = parse_matrix_header(&[2, 0, 3, 0]).unwrap();
        assert_eq!(header.lsize, 2);
        assert_eq!(header.rsize, 3);
        assert_eq!(header.header_size, MATRIX_HEADER_SIZE);
    }

    #[test]
    fn test_parse_header_v3() {
        // magic, version 1, flags 0, reserved 0, then u32 sizes beyond u16 range.
        let mut data = MKM3_MAGIC.to_vec();
        data.extend_from_slice(&[1, 0, 0, 0]);
        data.extend_from_slice(&70_000u32.to_le_bytes());
        data.extend_from_slice(&5u32.to_le_bytes());

        let header = parse_matrix_header(&data).unwrap();
        assert_eq!(header.lsize, 70_000);
        assert_eq!(header.rsize, 5);
        assert_eq!(header.header_size, MKM3_HEADER_SIZE);
    }

    #[test]
    fn test_parse_header_too_short() {
        assert!(matches!(
            parse_matrix_header(b""),
            Err(DictError::Format(_))
        ));
        assert!(matches!(
            parse_matrix_header(&[1, 0, 2]),
            Err(DictError::Format(_))
        ));
        // The magic is present, but the 16-byte v3 header is truncated.
        assert!(matches!(
            parse_matrix_header(b"MKM3\x01\x00\x00\x00"),
            Err(DictError::Format(_))
        ));
    }
}