mecab_ko_dict/matrix/
dense.rs1use std::io::{self, BufRead, BufReader};
4#[cfg(feature = "zstd")]
5use std::io::{Read, Write as IoWrite};
6use std::path::Path;
7
8use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
9
10use crate::error::{DictError, Result};
11
12use super::{
13 parse_matrix_header, Matrix, INVALID_CONNECTION_COST, MATRIX_HEADER_SIZE, MKM3_HEADER_SIZE,
14 MKM3_MAGIC,
15};
16
17#[derive(Debug, Clone)]
22pub struct DenseMatrix {
23 pub(super) lsize: usize,
25 pub(super) rsize: usize,
27 pub(super) costs: Vec<i16>,
29}
30
31impl DenseMatrix {
32 #[must_use]
40 pub fn new(lsize: usize, rsize: usize, default_cost: i16) -> Self {
41 let costs = vec![default_cost; lsize * rsize];
42 Self {
43 lsize,
44 rsize,
45 costs,
46 }
47 }
48
49 pub fn from_vec(lsize: usize, rsize: usize, costs: Vec<i16>) -> Result<Self> {
65 let expected_size = lsize * rsize;
66 if costs.len() != expected_size {
67 return Err(DictError::Format(format!(
68 "Matrix size mismatch: expected {} entries, got {}",
69 expected_size,
70 costs.len()
71 )));
72 }
73 Ok(Self {
74 lsize,
75 rsize,
76 costs,
77 })
78 }
79
80 pub fn set(&mut self, right_id: u16, left_id: u16, cost: i16) {
88 let index = right_id as usize + self.lsize * left_id as usize;
89 if index < self.costs.len() {
90 self.costs[index] = cost;
91 }
92 }
93
94 pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
112 let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
113 let reader = BufReader::new(file);
114 Self::from_def_reader(reader)
115 }
116
117 pub fn from_def_reader<R: BufRead>(mut reader: R) -> Result<Self> {
123 let mut first_line = String::new();
125 reader.read_line(&mut first_line).map_err(DictError::Io)?;
126
127 let sizes: Vec<usize> = first_line
128 .split_whitespace()
129 .filter_map(|s| s.parse().ok())
130 .collect();
131
132 if sizes.len() != 2 {
133 return Err(DictError::Format(
134 "Invalid matrix header: expected 'lsize rsize'".to_string(),
135 ));
136 }
137
138 let lsize = sizes[0];
139 let rsize = sizes[1];
140
141 let mut matrix = Self::new(lsize, rsize, i16::MAX);
143
144 for line in reader.lines() {
146 let line = line.map_err(DictError::Io)?;
147 let line = line.trim();
148
149 if line.is_empty() || line.starts_with('#') {
150 continue;
151 }
152
153 let parts: Vec<&str> = line.split_whitespace().collect();
154 if parts.len() != 3 {
155 continue;
156 }
157
158 let right_id: u16 = parts[0]
159 .parse()
160 .map_err(|_| DictError::Format(format!("Invalid right_id: {}", parts[0])))?;
161 let left_id: u16 = parts[1]
162 .parse()
163 .map_err(|_| DictError::Format(format!("Invalid left_id: {}", parts[1])))?;
164 let cost: i16 = parts[2]
165 .parse()
166 .map_err(|_| DictError::Format(format!("Invalid cost: {}", parts[2])))?;
167
168 matrix.set(right_id, left_id, cost);
169 }
170
171 Ok(matrix)
172 }
173
174 pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
188 let data = std::fs::read(path.as_ref()).map_err(DictError::Io)?;
189 Self::from_bin_bytes(&data)
190 }
191
192 pub fn from_bin_bytes(data: &[u8]) -> Result<Self> {
198 let header = parse_matrix_header(data)?;
199
200 let expected_size = header.lsize * header.rsize * 2;
201 let data_size = data.len() - header.header_size;
202
203 if data_size != expected_size {
204 return Err(DictError::Format(format!(
205 "Matrix data size mismatch: expected {expected_size} bytes, got {data_size}"
206 )));
207 }
208
209 let mut cursor = io::Cursor::new(data);
210 cursor.set_position(header.header_size as u64);
211
212 let mut costs = Vec::with_capacity(header.lsize * header.rsize);
213 for _ in 0..(header.lsize * header.rsize) {
214 costs.push(cursor.read_i16::<LittleEndian>().map_err(DictError::Io)?);
215 }
216
217 Ok(Self {
218 lsize: header.lsize,
219 rsize: header.rsize,
220 costs,
221 })
222 }
223
224 #[must_use]
226 pub fn to_bin_bytes_v3(&self) -> Vec<u8> {
227 let mut buf = Vec::with_capacity(MKM3_HEADER_SIZE + self.costs.len() * 2);
228
229 buf.extend_from_slice(MKM3_MAGIC);
230 buf.push(1);
231 buf.push(0);
232 buf.write_u16::<LittleEndian>(0).ok();
233 #[allow(clippy::cast_possible_truncation)]
234 buf.write_u32::<LittleEndian>(self.lsize as u32).ok();
235 #[allow(clippy::cast_possible_truncation)]
236 buf.write_u32::<LittleEndian>(self.rsize as u32).ok();
237
238 for &cost in &self.costs {
239 buf.write_i16::<LittleEndian>(cost).ok();
240 }
241
242 buf
243 }
244
/// Loads a zstd-compressed binary matrix from `path`.
///
/// The file is decompressed fully into memory and the result is decoded with
/// `from_bin_bytes`.
///
/// # Errors
///
/// Returns `DictError::Io` when opening, decoder setup, or decompression
/// fails, and whatever `from_bin_bytes` reports while decoding.
#[cfg(feature = "zstd")]
pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
    let source = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
    let mut stream = BufReader::new(zstd::Decoder::new(source).map_err(DictError::Io)?);

    let mut raw = Vec::new();
    stream.read_to_end(&mut raw).map_err(DictError::Io)?;

    Self::from_bin_bytes(&raw)
}
260
261 #[cfg(not(feature = "zstd"))]
267 pub fn from_compressed_file<P: AsRef<Path>>(_path: P) -> Result<Self> {
268 Err(DictError::Format(
269 "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
270 .to_string(),
271 ))
272 }
273
274 #[must_use]
276 pub fn to_bin_bytes(&self) -> Vec<u8> {
277 let mut buf = Vec::with_capacity(MATRIX_HEADER_SIZE + self.costs.len() * 2);
278
279 #[allow(clippy::cast_possible_truncation)]
281 buf.write_u16::<LittleEndian>(self.lsize as u16).ok();
282 #[allow(clippy::cast_possible_truncation)]
283 buf.write_u16::<LittleEndian>(self.rsize as u16).ok();
284
285 for &cost in &self.costs {
287 buf.write_i16::<LittleEndian>(cost).ok();
288 }
289
290 buf
291 }
292
293 pub fn to_bin_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
299 let data = self.to_bin_bytes();
300 std::fs::write(path.as_ref(), data).map_err(DictError::Io)
301 }
302
/// Writes the output of `to_bin_bytes` to `path` as a zstd stream.
///
/// `level` is passed unchanged to `zstd::Encoder::new`.
///
/// # Errors
///
/// Returns `DictError::Io` when creating the file, setting up the encoder,
/// writing, or finishing the stream fails.
#[cfg(feature = "zstd")]
pub fn to_compressed_file<P: AsRef<Path>>(&self, path: P, level: i32) -> Result<()> {
    let payload = self.to_bin_bytes();
    let sink = std::fs::File::create(path.as_ref()).map_err(DictError::Io)?;

    let mut encoder = zstd::Encoder::new(sink, level).map_err(DictError::Io)?;
    encoder.write_all(&payload).map_err(DictError::Io)?;

    // `finish` hands back the underlying file; only success or failure
    // matters here.
    encoder.finish().map(|_| ()).map_err(DictError::Io)
}
317
318 #[cfg(not(feature = "zstd"))]
324 pub fn to_compressed_file<P: AsRef<Path>>(&self, _path: P, _level: i32) -> Result<()> {
325 Err(DictError::Format(
326 "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
327 .to_string(),
328 ))
329 }
330
331 #[must_use]
333 pub fn costs(&self) -> &[i16] {
334 &self.costs
335 }
336
337 #[must_use]
339 pub fn memory_size(&self) -> usize {
340 std::mem::size_of::<Self>() + self.costs.len() * std::mem::size_of::<i16>()
341 }
342}
343
344impl Matrix for DenseMatrix {
345 #[inline(always)]
346 fn get(&self, right_id: u16, left_id: u16) -> i32 {
347 let index = right_id as usize + self.lsize * left_id as usize;
348 if index < self.costs.len() {
349 i32::from(self.costs[index])
350 } else {
351 INVALID_CONNECTION_COST
352 }
353 }
354
355 fn left_size(&self) -> usize {
356 self.lsize
357 }
358
359 fn right_size(&self) -> usize {
360 self.rsize
361 }
362}