Skip to main content

mecab_ko_dict/matrix/
dense.rs

1//! 밀집 연접 비용 행렬 (Dense Connection Cost Matrix)
2
3use std::io::{self, BufRead, BufReader};
4#[cfg(feature = "zstd")]
5use std::io::{Read, Write as IoWrite};
6use std::path::Path;
7
8use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
9
10use crate::error::{DictError, Result};
11
12use super::{
13    parse_matrix_header, Matrix, INVALID_CONNECTION_COST, MATRIX_HEADER_SIZE, MKM3_HEADER_SIZE,
14    MKM3_MAGIC,
15};
16
17/// 밀집 연접 비용 행렬 (Dense Matrix)
18///
19/// 모든 연접 비용을 메모리에 저장하는 구현입니다.
20/// 희소 행렬이 아닌 경우에 적합합니다.
21#[derive(Debug, Clone)]
22pub struct DenseMatrix {
23    /// 좌문맥 크기
24    pub(super) lsize: usize,
25    /// 우문맥 크기
26    pub(super) rsize: usize,
27    /// 비용 배열 (row-major: costs[`right_id` + lsize * `left_id`])
28    pub(super) costs: Vec<i16>,
29}
30
31impl DenseMatrix {
32    /// 새로운 밀집 행렬 생성 (모든 값을 기본값으로 초기화)
33    ///
34    /// # Arguments
35    ///
36    /// * `lsize` - 좌문맥 크기
37    /// * `rsize` - 우문맥 크기
38    /// * `default_cost` - 기본 비용 값
39    #[must_use]
40    pub fn new(lsize: usize, rsize: usize, default_cost: i16) -> Self {
41        let costs = vec![default_cost; lsize * rsize];
42        Self {
43            lsize,
44            rsize,
45            costs,
46        }
47    }
48
49    /// 기존 비용 벡터로 밀집 행렬 생성
50    ///
51    /// # Arguments
52    ///
53    /// * `lsize` - 좌문맥 크기
54    /// * `rsize` - 우문맥 크기
55    /// * `costs` - 비용 배열
56    ///
57    /// # Returns
58    ///
59    /// 성공 시 `DenseMatrix`, 크기 불일치 시 에러
60    ///
61    /// # Errors
62    ///
63    /// 비용 배열의 길이가 `lsize * rsize`와 일치하지 않으면 에러를 반환합니다.
64    pub fn from_vec(lsize: usize, rsize: usize, costs: Vec<i16>) -> Result<Self> {
65        let expected_size = lsize * rsize;
66        if costs.len() != expected_size {
67            return Err(DictError::Format(format!(
68                "Matrix size mismatch: expected {} entries, got {}",
69                expected_size,
70                costs.len()
71            )));
72        }
73        Ok(Self {
74            lsize,
75            rsize,
76            costs,
77        })
78    }
79
80    /// 비용 설정
81    ///
82    /// # Arguments
83    ///
84    /// * `right_id` - 우문맥 ID
85    /// * `left_id` - 좌문맥 ID
86    /// * `cost` - 비용 값
87    pub fn set(&mut self, right_id: u16, left_id: u16, cost: i16) {
88        let index = right_id as usize + self.lsize * left_id as usize;
89        if index < self.costs.len() {
90            self.costs[index] = cost;
91        }
92    }
93
94    /// 텍스트 파일(matrix.def)에서 로드
95    ///
96    /// # Format
97    ///
98    /// ```text
99    /// <lsize> <rsize>
100    /// <right_id> <left_id> <cost>
101    /// ...
102    /// ```
103    ///
104    /// # Arguments
105    ///
106    /// * `path` - matrix.def 파일 경로
107    ///
108    /// # Errors
109    ///
110    /// 파일을 읽을 수 없거나 형식이 잘못된 경우 에러를 반환합니다.
111    pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
112        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
113        let reader = BufReader::new(file);
114        Self::from_def_reader(reader)
115    }
116
117    /// 텍스트 리더에서 로드
118    ///
119    /// # Errors
120    ///
121    /// 리더에서 데이터를 읽을 수 없거나 형식이 잘못된 경우 에러를 반환합니다.
122    pub fn from_def_reader<R: BufRead>(mut reader: R) -> Result<Self> {
123        // 첫 줄: 크기 정보
124        let mut first_line = String::new();
125        reader.read_line(&mut first_line).map_err(DictError::Io)?;
126
127        let sizes: Vec<usize> = first_line
128            .split_whitespace()
129            .filter_map(|s| s.parse().ok())
130            .collect();
131
132        if sizes.len() != 2 {
133            return Err(DictError::Format(
134                "Invalid matrix header: expected 'lsize rsize'".to_string(),
135            ));
136        }
137
138        let lsize = sizes[0];
139        let rsize = sizes[1];
140
141        // 기본값으로 초기화 (i16::MAX는 연결 불가능을 의미)
142        let mut matrix = Self::new(lsize, rsize, i16::MAX);
143
144        // 나머지 줄: 연접 비용
145        for line in reader.lines() {
146            let line = line.map_err(DictError::Io)?;
147            let line = line.trim();
148
149            if line.is_empty() || line.starts_with('#') {
150                continue;
151            }
152
153            let parts: Vec<&str> = line.split_whitespace().collect();
154            if parts.len() != 3 {
155                continue;
156            }
157
158            let right_id: u16 = parts[0]
159                .parse()
160                .map_err(|_| DictError::Format(format!("Invalid right_id: {}", parts[0])))?;
161            let left_id: u16 = parts[1]
162                .parse()
163                .map_err(|_| DictError::Format(format!("Invalid left_id: {}", parts[1])))?;
164            let cost: i16 = parts[2]
165                .parse()
166                .map_err(|_| DictError::Format(format!("Invalid cost: {}", parts[2])))?;
167
168            matrix.set(right_id, left_id, cost);
169        }
170
171        Ok(matrix)
172    }
173
174    /// 바이너리 파일(matrix.bin)에서 로드
175    ///
176    /// # Format
177    ///
178    /// ```text
179    /// [2 bytes] lsize (little-endian u16)
180    /// [2 bytes] rsize (little-endian u16)
181    /// [lsize * rsize * 2 bytes] costs (little-endian i16 array)
182    /// ```
183    ///
184    /// # Errors
185    ///
186    /// 파일을 읽을 수 없거나 형식이 잘못된 경우 에러를 반환합니다.
187    pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
188        let data = std::fs::read(path.as_ref()).map_err(DictError::Io)?;
189        Self::from_bin_bytes(&data)
190    }
191
192    /// 바이너리 바이트에서 로드
193    ///
194    /// # Errors
195    ///
196    /// 데이터가 유효한 바이너리 형식이 아닌 경우 에러를 반환합니다.
197    pub fn from_bin_bytes(data: &[u8]) -> Result<Self> {
198        let header = parse_matrix_header(data)?;
199
200        let expected_size = header.lsize * header.rsize * 2;
201        let data_size = data.len() - header.header_size;
202
203        if data_size != expected_size {
204            return Err(DictError::Format(format!(
205                "Matrix data size mismatch: expected {expected_size} bytes, got {data_size}"
206            )));
207        }
208
209        let mut cursor = io::Cursor::new(data);
210        cursor.set_position(header.header_size as u64);
211
212        let mut costs = Vec::with_capacity(header.lsize * header.rsize);
213        for _ in 0..(header.lsize * header.rsize) {
214            costs.push(cursor.read_i16::<LittleEndian>().map_err(DictError::Io)?);
215        }
216
217        Ok(Self {
218            lsize: header.lsize,
219            rsize: header.rsize,
220            costs,
221        })
222    }
223
224    /// v3 포맷(MKM3)으로 직렬화
225    #[must_use]
226    pub fn to_bin_bytes_v3(&self) -> Vec<u8> {
227        let mut buf = Vec::with_capacity(MKM3_HEADER_SIZE + self.costs.len() * 2);
228
229        buf.extend_from_slice(MKM3_MAGIC);
230        buf.push(1);
231        buf.push(0);
232        buf.write_u16::<LittleEndian>(0).ok();
233        #[allow(clippy::cast_possible_truncation)]
234        buf.write_u32::<LittleEndian>(self.lsize as u32).ok();
235        #[allow(clippy::cast_possible_truncation)]
236        buf.write_u32::<LittleEndian>(self.rsize as u32).ok();
237
238        for &cost in &self.costs {
239            buf.write_i16::<LittleEndian>(cost).ok();
240        }
241
242        buf
243    }
244
245    /// 압축된 바이너리 파일(matrix.bin.zst)에서 로드
246    ///
247    /// # Errors
248    ///
249    /// 파일을 읽거나 압축 해제할 수 없는 경우 에러를 반환합니다.
250    #[cfg(feature = "zstd")]
251    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
252        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
253        let decoder = zstd::Decoder::new(file).map_err(DictError::Io)?;
254        let mut data = Vec::new();
255        BufReader::new(decoder)
256            .read_to_end(&mut data)
257            .map_err(DictError::Io)?;
258        Self::from_bin_bytes(&data)
259    }
260
261    /// 압축된 바이너리 파일에서 로드 (zstd feature 비활성화 시)
262    ///
263    /// # Errors
264    ///
265    /// zstd feature가 비활성화된 경우 항상 에러를 반환합니다.
266    #[cfg(not(feature = "zstd"))]
267    pub fn from_compressed_file<P: AsRef<Path>>(_path: P) -> Result<Self> {
268        Err(DictError::Format(
269            "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
270                .to_string(),
271        ))
272    }
273
274    /// 바이너리 형식으로 저장
275    #[must_use]
276    pub fn to_bin_bytes(&self) -> Vec<u8> {
277        let mut buf = Vec::with_capacity(MATRIX_HEADER_SIZE + self.costs.len() * 2);
278
279        // 헤더
280        #[allow(clippy::cast_possible_truncation)]
281        buf.write_u16::<LittleEndian>(self.lsize as u16).ok();
282        #[allow(clippy::cast_possible_truncation)]
283        buf.write_u16::<LittleEndian>(self.rsize as u16).ok();
284
285        // 데이터
286        for &cost in &self.costs {
287            buf.write_i16::<LittleEndian>(cost).ok();
288        }
289
290        buf
291    }
292
293    /// 바이너리 파일로 저장
294    ///
295    /// # Errors
296    ///
297    /// 파일을 쓸 수 없는 경우 에러를 반환합니다.
298    pub fn to_bin_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
299        let data = self.to_bin_bytes();
300        std::fs::write(path.as_ref(), data).map_err(DictError::Io)
301    }
302
303    /// 압축된 바이너리 파일로 저장
304    ///
305    /// # Errors
306    ///
307    /// 파일을 쓰거나 압축할 수 없는 경우 에러를 반환합니다.
308    #[cfg(feature = "zstd")]
309    pub fn to_compressed_file<P: AsRef<Path>>(&self, path: P, level: i32) -> Result<()> {
310        let data = self.to_bin_bytes();
311        let file = std::fs::File::create(path.as_ref()).map_err(DictError::Io)?;
312        let mut encoder = zstd::Encoder::new(file, level).map_err(DictError::Io)?;
313        encoder.write_all(&data).map_err(DictError::Io)?;
314        encoder.finish().map_err(DictError::Io)?;
315        Ok(())
316    }
317
318    /// 압축된 바이너리 파일로 저장 (zstd feature 비활성화 시)
319    ///
320    /// # Errors
321    ///
322    /// zstd feature가 비활성화된 경우 항상 에러를 반환합니다.
323    #[cfg(not(feature = "zstd"))]
324    pub fn to_compressed_file<P: AsRef<Path>>(&self, _path: P, _level: i32) -> Result<()> {
325        Err(DictError::Format(
326            "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
327                .to_string(),
328        ))
329    }
330
331    /// 원본 비용 배열 참조
332    #[must_use]
333    pub fn costs(&self) -> &[i16] {
334        &self.costs
335    }
336
337    /// 메모리 사용량 (바이트)
338    #[must_use]
339    pub fn memory_size(&self) -> usize {
340        std::mem::size_of::<Self>() + self.costs.len() * std::mem::size_of::<i16>()
341    }
342}
343
344impl Matrix for DenseMatrix {
345    #[inline(always)]
346    fn get(&self, right_id: u16, left_id: u16) -> i32 {
347        let index = right_id as usize + self.lsize * left_id as usize;
348        if index < self.costs.len() {
349            i32::from(self.costs[index])
350        } else {
351            INVALID_CONNECTION_COST
352        }
353    }
354
355    fn left_size(&self) -> usize {
356        self.lsize
357    }
358
359    fn right_size(&self) -> usize {
360        self.rsize
361    }
362}