Skip to main content

mecab_ko_dict/matrix/
mod.rs

1//! # 연접 비용 행렬 (Connection Cost Matrix)
2//!
3//! 형태소 간 연접 비용을 저장하고 조회하는 모듈입니다.
4//!
5//! ## 포맷 지원
6//!
7//! - **텍스트 포맷** (`matrix.def`): `MeCab` 표준 형식
8//! - **바이너리 포맷** (`matrix.bin`): 고정 크기 i16 배열
9//! - **압축 포맷** (`matrix.bin.zst`): Zstd 압축 바이너리
10//!
11//! ## 예제
12//!
13//! ```rust,ignore
14//! use mecab_ko_dict::matrix::ConnectionMatrix;
15//!
16//! // 텍스트 파일에서 로드
17//! let matrix = ConnectionMatrix::from_def_file("matrix.def").unwrap();
18//!
19//! // 연접 비용 조회 (left_id=0, right_id=0)
20//! let cost = matrix.get(0, 0);
21//! ```
22//!
23//! ## 행렬 구조
24//!
25//! 연접 비용 행렬은 `lsize x rsize` 크기의 2차원 배열입니다.
26//! - `lsize`: 좌문맥 ID 개수
27//! - `rsize`: 우문맥 ID 개수
28//! - 접근: `matrix[right_id + lsize * left_id]`
29
30use std::io::{self};
31use std::path::Path;
32
33use byteorder::{LittleEndian, ReadBytesExt};
34
35use crate::error::{DictError, Result};
36
37mod dense;
38mod mmap;
39mod sparse;
40
41pub use dense::DenseMatrix;
42pub use mmap::MmapMatrix;
43pub use sparse::SparseMatrix;
44
45// SIMD 최적화 모듈
46#[cfg(feature = "simd")]
47pub mod simd;
48
49#[cfg(feature = "simd")]
50pub use simd::SimdMatrix;
51
52pub(super) const MATRIX_HEADER_SIZE: usize = 4;
53
54pub(super) const MKM3_MAGIC: &[u8; 4] = b"MKM3";
55pub(super) const MKM3_HEADER_SIZE: usize = 16;
56
57/// 행렬 헤더 정보
58pub(super) struct MatrixHeader {
59    /// 좌문맥 크기
60    pub(super) lsize: usize,
61    /// 우문맥 크기
62    pub(super) rsize: usize,
63    /// 헤더 크기 (v2: 4, v3: 16)
64    pub(super) header_size: usize,
65}
66
67/// 행렬 헤더를 파싱하는 내부 함수
68///
69/// v2/v3 포맷을 자동 감지하고 헤더 정보를 추출합니다.
70///
71/// # Arguments
72///
73/// * `data` - 파싱할 바이트 데이터 (헤더 크기 이상)
74///
75/// # Returns
76///
77/// 성공 시 `MatrixHeader`, 형식 오류 시 에러
78pub(super) fn parse_matrix_header(data: &[u8]) -> Result<MatrixHeader> {
79    let is_v3 = data.len() >= 4 && &data[..4] == MKM3_MAGIC;
80    let header_size = if is_v3 {
81        MKM3_HEADER_SIZE
82    } else {
83        MATRIX_HEADER_SIZE
84    };
85
86    if data.len() < header_size {
87        return Err(DictError::Format(
88            "Matrix binary too short for header".to_string(),
89        ));
90    }
91
92    let mut cursor = io::Cursor::new(data);
93
94    let (lsize, rsize) = if is_v3 {
95        cursor.set_position(4);
96        let _version = cursor.read_u8().map_err(DictError::Io)?;
97        let _flags = cursor.read_u8().map_err(DictError::Io)?;
98        let _reserved = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)?;
99        let l = cursor.read_u32::<LittleEndian>().map_err(DictError::Io)? as usize;
100        let r = cursor.read_u32::<LittleEndian>().map_err(DictError::Io)? as usize;
101        (l, r)
102    } else {
103        let l = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
104        let r = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
105        (l, r)
106    };
107
108    Ok(MatrixHeader {
109        lsize,
110        rsize,
111        header_size,
112    })
113}
114
115/// 기본 비용 (연결 불가능한 경우)
116pub const INVALID_CONNECTION_COST: i32 = i32::MAX;
117
118/// 연접 비용 행렬 인터페이스
119///
120/// 형태소 간 연접 비용을 조회하는 인터페이스입니다.
121/// mecab-ko-core의 `ConnectionCost` trait과 호환됩니다.
122pub trait Matrix {
123    /// 연접 비용 조회
124    ///
125    /// # Arguments
126    ///
127    /// * `right_id` - 이전 노드의 우문맥 ID (right context ID)
128    /// * `left_id` - 현재 노드의 좌문맥 ID (left context ID)
129    ///
130    /// # Returns
131    ///
132    /// 연접 비용 (i32). 연결 불가능한 경우 `INVALID_CONNECTION_COST` 반환
133    fn get(&self, right_id: u16, left_id: u16) -> i32;
134
135    /// 좌문맥 크기
136    fn left_size(&self) -> usize;
137
138    /// 우문맥 크기
139    fn right_size(&self) -> usize;
140
141    /// 전체 엔트리 수
142    fn entry_count(&self) -> usize {
143        self.left_size() * self.right_size()
144    }
145}
146
147/// 연접 비용 행렬 로더
148///
149/// 다양한 포맷에서 연접 비용 행렬을 로드합니다.
150pub struct MatrixLoader;
151
152impl MatrixLoader {
153    /// 자동 포맷 감지 로드
154    ///
155    /// 파일 확장자에 따라 적절한 로더를 선택합니다.
156    /// - `.def`: 텍스트 포맷
157    /// - `.bin`: 바이너리 포맷
158    /// - `.bin.zst`, `.zst`: 압축 바이너리 포맷
159    ///
160    /// # Errors
161    ///
162    /// 파일을 읽거나 파싱할 수 없는 경우 에러를 반환합니다.
163    pub fn load<P: AsRef<Path>>(path: P) -> Result<DenseMatrix> {
164        let path = path.as_ref();
165        let path_str = path.to_string_lossy();
166
167        if path_str.ends_with(".def") {
168            DenseMatrix::from_def_file(path)
169        } else if path_str.ends_with(".zst") || path_str.ends_with(".bin.zst") {
170            DenseMatrix::from_compressed_file(path)
171        } else if path_str.ends_with(".bin") {
172            DenseMatrix::from_bin_file(path)
173        } else {
174            // 기본: 바이너리 시도 후 텍스트 시도
175            DenseMatrix::from_bin_file(path).or_else(|_| DenseMatrix::from_def_file(path))
176        }
177    }
178
179    /// 메모리 맵으로 로드 (바이너리 파일만 지원)
180    ///
181    /// # Errors
182    ///
183    /// 파일을 읽거나 메모리 맵을 생성할 수 없는 경우 에러를 반환합니다.
184    pub fn load_mmap<P: AsRef<Path>>(path: P) -> Result<MmapMatrix> {
185        MmapMatrix::from_file(path)
186    }
187}
188
189/// 연접 비용 행렬을 위한 통합 타입
190///
191/// 다양한 행렬 구현을 하나의 타입으로 사용할 수 있습니다.
192pub enum ConnectionMatrix {
193    /// 밀집 행렬
194    Dense(DenseMatrix),
195    /// 희소 행렬
196    Sparse(SparseMatrix),
197    /// 메모리 맵 행렬
198    Mmap(MmapMatrix),
199}
200
201impl ConnectionMatrix {
202    /// 텍스트 파일에서 로드
203    ///
204    /// # Errors
205    ///
206    /// 파일을 읽거나 파싱할 수 없는 경우 에러를 반환합니다.
207    pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
208        Ok(Self::Dense(DenseMatrix::from_def_file(path)?))
209    }
210
211    /// 바이너리 파일에서 로드
212    ///
213    /// # Errors
214    ///
215    /// 파일을 읽거나 파싱할 수 없는 경우 에러를 반환합니다.
216    pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
217        Ok(Self::Dense(DenseMatrix::from_bin_file(path)?))
218    }
219
220    /// 메모리 맵으로 로드
221    ///
222    /// # Errors
223    ///
224    /// 파일을 읽거나 메모리 맵을 생성할 수 없는 경우 에러를 반환합니다.
225    pub fn from_mmap_file<P: AsRef<Path>>(path: P) -> Result<Self> {
226        Ok(Self::Mmap(MmapMatrix::from_file(path)?))
227    }
228
229    /// 압축된 바이너리 파일에서 로드 (.zst)
230    ///
231    /// # Errors
232    ///
233    /// 파일을 읽거나 압축 해제/파싱할 수 없는 경우 에러를 반환합니다.
234    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
235        Ok(Self::Dense(DenseMatrix::from_compressed_file(path)?))
236    }
237
238    /// 자동 포맷 감지 로드
239    ///
240    /// # Errors
241    ///
242    /// 파일을 읽거나 파싱할 수 없는 경우 에러를 반환합니다.
243    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
244        Ok(Self::Dense(MatrixLoader::load(path)?))
245    }
246}
247
248impl Matrix for ConnectionMatrix {
249    #[inline(always)]
250    fn get(&self, right_id: u16, left_id: u16) -> i32 {
251        match self {
252            Self::Dense(m) => m.get(right_id, left_id),
253            Self::Sparse(m) => m.get(right_id, left_id),
254            Self::Mmap(m) => m.get(right_id, left_id),
255        }
256    }
257
258    fn left_size(&self) -> usize {
259        match self {
260            Self::Dense(m) => m.left_size(),
261            Self::Sparse(m) => m.left_size(),
262            Self::Mmap(m) => m.left_size(),
263        }
264    }
265
266    fn right_size(&self) -> usize {
267        match self {
268            Self::Dense(m) => m.right_size(),
269            Self::Sparse(m) => m.right_size(),
270            Self::Mmap(m) => m.right_size(),
271        }
272    }
273}
274
275#[cfg(test)]
276#[allow(clippy::expect_used, clippy::unwrap_used, clippy::cast_lossless)]
277mod tests {
278    use super::*;
279
280    #[test]
281    fn test_dense_matrix_new() {
282        let matrix = DenseMatrix::new(10, 10, 0);
283        assert_eq!(matrix.left_size(), 10);
284        assert_eq!(matrix.right_size(), 10);
285        assert_eq!(matrix.entry_count(), 100);
286        assert_eq!(matrix.get(0, 0), 0);
287    }
288
289    #[test]
290    fn test_dense_matrix_set_get() {
291        let mut matrix = DenseMatrix::new(10, 10, 0);
292        matrix.set(3, 5, 100);
293        assert_eq!(matrix.get(3, 5), 100);
294        assert_eq!(matrix.get(5, 3), 0);
295    }
296
297    #[test]
298    fn test_dense_matrix_from_vec() {
299        let costs = vec![1, 2, 3, 4, 5, 6];
300        let matrix = DenseMatrix::from_vec(2, 3, costs).unwrap();
301        // costs[right_id + lsize * left_id]
302        // (0,0) = costs[0] = 1
303        // (1,0) = costs[1] = 2
304        // (0,1) = costs[2] = 3
305        // (1,1) = costs[3] = 4
306        // (0,2) = costs[4] = 5
307        // (1,2) = costs[5] = 6
308        assert_eq!(matrix.get(0, 0), 1);
309        assert_eq!(matrix.get(1, 0), 2);
310        assert_eq!(matrix.get(0, 1), 3);
311        assert_eq!(matrix.get(1, 1), 4);
312        assert_eq!(matrix.get(0, 2), 5);
313        assert_eq!(matrix.get(1, 2), 6);
314    }
315
316    #[test]
317    fn test_dense_matrix_from_vec_size_mismatch() {
318        let costs = vec![1, 2, 3];
319        let result = DenseMatrix::from_vec(2, 3, costs);
320        assert!(result.is_err());
321    }
322
323    #[test]
324    fn test_dense_matrix_boundary() {
325        let matrix = DenseMatrix::new(10, 10, 0);
326        // 경계 외 접근
327        assert_eq!(matrix.get(100, 100), INVALID_CONNECTION_COST);
328    }
329
330    #[test]
331    fn test_dense_matrix_def_reader() {
332        let data = "3 3\n0 0 100\n1 1 200\n2 2 300\n";
333        let reader = std::io::Cursor::new(data);
334        let matrix = DenseMatrix::from_def_reader(reader).unwrap();
335
336        assert_eq!(matrix.left_size(), 3);
337        assert_eq!(matrix.right_size(), 3);
338        assert_eq!(matrix.get(0, 0), 100);
339        assert_eq!(matrix.get(1, 1), 200);
340        assert_eq!(matrix.get(2, 2), 300);
341        // 설정되지 않은 값은 i16::MAX
342        assert_eq!(matrix.get(0, 1), i16::MAX as i32);
343    }
344
345    #[test]
346    fn test_dense_matrix_binary_roundtrip() {
347        let mut matrix = DenseMatrix::new(5, 5, 0);
348        matrix.set(0, 0, 100);
349        matrix.set(1, 2, -500);
350        matrix.set(4, 4, 32767);
351
352        let bytes = matrix.to_bin_bytes();
353        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();
354
355        assert_eq!(loaded.left_size(), 5);
356        assert_eq!(loaded.right_size(), 5);
357        assert_eq!(loaded.get(0, 0), 100);
358        assert_eq!(loaded.get(1, 2), -500);
359        assert_eq!(loaded.get(4, 4), 32767);
360    }
361
362    #[test]
363    fn test_sparse_matrix() {
364        let mut sparse = SparseMatrix::new(100, 100, 0);
365        sparse.set(10, 20, 500);
366        sparse.set(50, 50, -100);
367
368        assert_eq!(sparse.get(10, 20), 500);
369        assert_eq!(sparse.get(50, 50), -100);
370        assert_eq!(sparse.get(0, 0), 0); // 기본값
371
372        assert_eq!(sparse.entry_count_stored(), 2);
373        assert!(sparse.sparsity() > 0.99); // 거의 희소
374    }
375
376    #[test]
377    fn test_sparse_dense_conversion() {
378        let mut dense = DenseMatrix::new(10, 10, 0);
379        dense.set(3, 3, 100);
380        dense.set(5, 7, 200);
381
382        let sparse = SparseMatrix::from_dense(&dense, 0);
383        assert_eq!(sparse.entry_count_stored(), 2);
384        assert_eq!(sparse.get(3, 3), 100);
385        assert_eq!(sparse.get(5, 7), 200);
386
387        let converted = sparse.to_dense();
388        assert_eq!(converted.get(3, 3), 100);
389        assert_eq!(converted.get(5, 7), 200);
390        assert_eq!(converted.get(0, 0), 0);
391    }
392
393    #[test]
394    fn test_memory_size() {
395        let dense = DenseMatrix::new(100, 100, 0);
396        let mem_size = dense.memory_size();
397        // 최소 20000 바이트 (100*100*2)
398        assert!(mem_size >= 20000);
399
400        let sparse = SparseMatrix::new(100, 100, 0);
401        let sparse_size = sparse.memory_size();
402        // 희소 행렬은 훨씬 작음
403        assert!(sparse_size < mem_size);
404    }
405
406    #[test]
407    fn test_connection_matrix_enum() {
408        let dense = DenseMatrix::new(5, 5, 100);
409        let matrix = ConnectionMatrix::Dense(dense);
410
411        assert_eq!(matrix.left_size(), 5);
412        assert_eq!(matrix.right_size(), 5);
413        assert_eq!(matrix.get(0, 0), 100);
414    }
415
416    #[test]
417    fn test_large_matrix() {
418        // mecab-ko-dic의 실제 크기 (약 2800 x 2800)
419        let matrix = DenseMatrix::new(178, 178, 0);
420        assert_eq!(matrix.entry_count(), 178 * 178);
421        assert_eq!(
422            matrix.memory_size(),
423            std::mem::size_of::<DenseMatrix>() + 178 * 178 * 2
424        );
425    }
426
427    #[test]
428    fn test_def_with_comments_and_empty_lines() {
429        let data = "2 2\n# This is a comment\n\n0 0 10\n0 1 20\n\n1 0 30\n1 1 40\n";
430        let reader = std::io::Cursor::new(data);
431        let matrix = DenseMatrix::from_def_reader(reader).unwrap();
432
433        assert_eq!(matrix.get(0, 0), 10);
434        assert_eq!(matrix.get(0, 1), 20);
435        assert_eq!(matrix.get(1, 0), 30);
436        assert_eq!(matrix.get(1, 1), 40);
437    }
438
439    #[test]
440    fn test_v3_header_roundtrip() {
441        let mut matrix = DenseMatrix::new(5, 5, 0);
442        matrix.set(0, 0, 42);
443        matrix.set(2, 3, -999);
444        matrix.set(4, 4, 32767);
445
446        let bytes = matrix.to_bin_bytes_v3();
447        assert_eq!(&bytes[..4], b"MKM3");
448        assert_eq!(bytes[4], 1);
449        assert_eq!(bytes[5], 0);
450
451        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();
452        assert_eq!(loaded.left_size(), 5);
453        assert_eq!(loaded.right_size(), 5);
454        assert_eq!(loaded.get(0, 0), 42);
455        assert_eq!(loaded.get(2, 3), -999);
456        assert_eq!(loaded.get(4, 4), 32767);
457    }
458
459    #[test]
460    fn test_v2_backward_compat() {
461        let mut matrix = DenseMatrix::new(4, 4, 0);
462        matrix.set(1, 2, 777);
463
464        let bytes = matrix.to_bin_bytes();
465        assert_ne!(&bytes[..4], b"MKM3");
466
467        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();
468        assert_eq!(loaded.left_size(), 4);
469        assert_eq!(loaded.right_size(), 4);
470        assert_eq!(loaded.get(1, 2), 777);
471    }
472
473    #[test]
474    fn test_v3_large_dimensions() {
475        let lsize = (u16::MAX as usize) + 1;
476        let rsize = 1;
477        let costs = vec![0i16; lsize * rsize];
478        let matrix = DenseMatrix::from_vec(lsize, rsize, costs).unwrap();
479
480        let bytes = matrix.to_bin_bytes_v3();
481        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();
482        assert_eq!(loaded.left_size(), lsize);
483        assert_eq!(loaded.right_size(), rsize);
484    }
485}