1use std::borrow::Cow;
49
/// Identifier of a node inside a lattice.
///
/// The lattice uses it as an index into its node storage.
pub type NodeId = u32;

/// Identifier carried by the beginning-of-sentence (BOS) node.
pub const BOS_NODE_ID: NodeId = 0;
/// Sentinel identifier meaning "no node", e.g. a node without a predecessor.
pub const INVALID_NODE_ID: NodeId = NodeId::MAX;

/// Context id stored in both `left_id` and `right_id` of the BOS node.
pub const BOS_CONTEXT_ID: u16 = 0;

/// Context id stored in both `left_id` and `right_id` of the EOS node.
pub const EOS_CONTEXT_ID: u16 = 0;
63
/// Category of a lattice node.
///
/// The default category is [`NodeType::Known`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum NodeType {
    /// Beginning-of-sentence marker node.
    Bos,
    /// End-of-sentence marker node.
    Eos,
    /// Default category; presumably a word found in the main dictionary
    /// (TODO confirm against the dictionary code).
    #[default]
    Known,
    /// Presumably a word produced by unknown-word handling (TODO confirm).
    Unknown,
    /// Presumably a word coming from a user dictionary (TODO confirm).
    User,
}
79
80#[derive(Debug, Clone)]
84pub struct Node {
85 pub id: NodeId,
87
88 pub surface: Cow<'static, str>,
90
91 pub start_pos: usize,
93
94 pub end_pos: usize,
96
97 pub start_byte: usize,
99
100 pub end_byte: usize,
102
103 pub left_id: u16,
105
106 pub right_id: u16,
108
109 pub word_cost: i32,
111
112 pub total_cost: i32,
114
115 pub prev_node_id: NodeId,
117
118 pub node_type: NodeType,
120
121 pub feature: Cow<'static, str>,
123
124 pub has_space_before: bool,
126}
127
128impl Node {
129 #[must_use]
131 pub const fn bos() -> Self {
132 Self {
133 id: BOS_NODE_ID,
134 surface: Cow::Borrowed("BOS"),
135 start_pos: 0,
136 end_pos: 0,
137 start_byte: 0,
138 end_byte: 0,
139 left_id: BOS_CONTEXT_ID,
140 right_id: BOS_CONTEXT_ID,
141 word_cost: 0,
142 total_cost: 0,
143 prev_node_id: INVALID_NODE_ID,
144 node_type: NodeType::Bos,
145 feature: Cow::Borrowed("BOS/EOS,*,*,*,*,*,*,*"),
146 has_space_before: false,
147 }
148 }
149
150 #[must_use]
152 pub const fn eos(id: NodeId, char_len: usize, byte_len: usize) -> Self {
153 Self {
154 id,
155 surface: Cow::Borrowed("EOS"),
156 start_pos: char_len,
157 end_pos: char_len,
158 start_byte: byte_len,
159 end_byte: byte_len,
160 left_id: EOS_CONTEXT_ID,
161 right_id: EOS_CONTEXT_ID,
162 word_cost: 0,
163 total_cost: i32::MAX,
164 prev_node_id: INVALID_NODE_ID,
165 node_type: NodeType::Eos,
166 feature: Cow::Borrowed("BOS/EOS,*,*,*,*,*,*,*"),
167 has_space_before: false,
168 }
169 }
170
171 #[inline]
173 #[must_use]
174 pub fn is_bos(&self) -> bool {
175 self.node_type == NodeType::Bos
176 }
177
178 #[inline]
180 #[must_use]
181 pub fn is_eos(&self) -> bool {
182 self.node_type == NodeType::Eos
183 }
184
185 #[inline]
187 #[must_use]
188 pub const fn char_len(&self) -> usize {
189 self.end_pos - self.start_pos
190 }
191
192 #[inline]
194 #[must_use]
195 pub const fn byte_len(&self) -> usize {
196 self.end_byte - self.start_byte
197 }
198}
199
200#[derive(Debug, Clone)]
202pub struct NodeBuilder {
203 surface: String,
204 start_pos: usize,
205 end_pos: usize,
206 start_byte: usize,
207 end_byte: usize,
208 left_id: u16,
209 right_id: u16,
210 word_cost: i32,
211 node_type: NodeType,
212 feature: String,
213 has_space_before: bool,
214}
215
216impl NodeBuilder {
217 #[must_use]
225 pub fn new(surface: &str, start_pos: usize, end_pos: usize) -> Self {
226 Self {
227 surface: surface.to_string(),
228 start_pos,
229 end_pos,
230 start_byte: 0,
231 end_byte: 0,
232 left_id: 0,
233 right_id: 0,
234 word_cost: 0,
235 node_type: NodeType::Known,
236 feature: String::new(),
237 has_space_before: false,
238 }
239 }
240
241 #[must_use]
243 pub const fn byte_positions(mut self, start: usize, end: usize) -> Self {
244 self.start_byte = start;
245 self.end_byte = end;
246 self
247 }
248
249 #[must_use]
251 pub const fn left_id(mut self, id: u16) -> Self {
252 self.left_id = id;
253 self
254 }
255
256 #[must_use]
258 pub const fn right_id(mut self, id: u16) -> Self {
259 self.right_id = id;
260 self
261 }
262
263 #[must_use]
265 pub const fn word_cost(mut self, cost: i32) -> Self {
266 self.word_cost = cost;
267 self
268 }
269
270 #[must_use]
272 pub const fn node_type(mut self, node_type: NodeType) -> Self {
273 self.node_type = node_type;
274 self
275 }
276
277 #[must_use]
279 pub fn feature(mut self, feature: &str) -> Self {
280 self.feature = feature.to_string();
281 self
282 }
283
284 #[must_use]
286 pub const fn has_space_before(mut self, value: bool) -> Self {
287 self.has_space_before = value;
288 self
289 }
290
291 #[must_use]
293 pub const fn build(self) -> Self {
294 self
295 }
296}
297
/// Lookup table between character positions and byte offsets of a text.
///
/// `char_to_byte[i]` is the byte offset of the `i`-th character; one extra
/// trailing entry holds the total byte length, so the end position of the
/// text can be looked up as well.
#[derive(Debug, Clone)]
pub struct CharPositions {
    char_to_byte: Vec<usize>,
    total_bytes: usize,
}

impl CharPositions {
    /// Builds the table for `text`.
    #[must_use]
    pub fn new(text: &str) -> Self {
        let mut char_to_byte = Vec::with_capacity(text.chars().count() + 1);
        char_to_byte.extend(text.char_indices().map(|(offset, _)| offset));
        // Trailing entry: the offset just past the last character.
        char_to_byte.push(text.len());

        Self {
            char_to_byte,
            total_bytes: text.len(),
        }
    }

    /// Byte offset of the character at `char_pos`.
    ///
    /// Positions beyond the table yield the total byte length.
    #[inline]
    #[must_use]
    pub fn char_to_byte(&self, char_pos: usize) -> usize {
        match self.char_to_byte.get(char_pos) {
            Some(&offset) => offset,
            None => self.total_bytes,
        }
    }

    /// Number of characters in the text.
    #[inline]
    #[must_use]
    pub fn char_count(&self) -> usize {
        // The table always holds one entry more than there are characters.
        self.char_to_byte.len().saturating_sub(1)
    }

    /// Character position whose byte offset is exactly `byte_pos`.
    ///
    /// When `byte_pos` is not the start offset of a character (nor the end
    /// of the text), the character count is returned.
    #[inline]
    #[must_use]
    pub fn byte_to_char(&self, byte_pos: usize) -> usize {
        match self.char_to_byte.binary_search(&byte_pos) {
            Ok(char_pos) => char_pos,
            Err(_) => self.char_count(),
        }
    }

    /// Total byte length of the text.
    #[inline]
    #[must_use]
    pub const fn byte_count(&self) -> usize {
        self.total_bytes
    }
}
368
/// Records which characters of a text were preceded by whitespace.
///
/// Positions are counted over the text with all whitespace removed, so they
/// can be compared directly with positions in the stripped text.
#[derive(Debug, Clone, Default)]
pub struct SpacePositions {
    /// Ascending stripped-text positions that directly follow whitespace.
    positions: Vec<usize>,
}

impl SpacePositions {
    /// Scans `text` and records every non-whitespace character that comes
    /// right after one or more whitespace characters.
    #[must_use]
    pub fn new(text: &str) -> Self {
        let mut positions = Vec::new();
        let mut stripped_index = 0;
        let mut after_space = false;

        for ch in text.chars() {
            if ch.is_whitespace() {
                // Whitespace does not advance the stripped position.
                after_space = true;
                continue;
            }
            if after_space {
                positions.push(stripped_index);
                after_space = false;
            }
            stripped_index += 1;
        }

        Self { positions }
    }

    /// Returns `true` when the character at stripped position `char_pos`
    /// was preceded by whitespace.
    #[inline]
    #[must_use]
    pub fn has_space_before(&self, char_pos: usize) -> bool {
        matches!(self.positions.binary_search(&char_pos), Ok(_))
    }
}
409
410#[derive(Debug)]
419pub struct Lattice {
420 original_text: String,
422
423 text: String,
425
426 char_positions: CharPositions,
428
429 space_positions: SpacePositions,
431
432 stripped_to_original_byte: Vec<usize>,
435
436 nodes: Vec<Node>,
439
440 ends_at: Vec<Vec<NodeId>>,
443
444 starts_at: Vec<Vec<NodeId>>,
447
448 bos_id: NodeId,
450
451 eos_id: NodeId,
453}
454
455impl Lattice {
456 #[must_use]
471 pub fn new(text: &str) -> Self {
472 let original_text = text.to_string();
474 let text_no_space: String = text.chars().filter(|c| !c.is_whitespace()).collect();
475
476 let char_positions = CharPositions::new(&text_no_space);
477 let space_positions = SpacePositions::new(text);
478 let stripped_to_original_byte = Self::build_stripped_to_original(text);
479
480 let char_len = char_positions.char_count();
481 let byte_len = char_positions.byte_count();
482
483 let bos = Node::bos();
485 let bos_id = bos.id;
486
487 let eos_id = 1;
489 let eos = Node::eos(eos_id, char_len, byte_len);
490
491 let nodes = vec![bos, eos];
493
494 let mut ends_at = vec![Vec::new(); char_len + 1];
496 let mut starts_at = vec![Vec::new(); char_len + 1];
497
498 ends_at[0].push(bos_id);
500 starts_at[char_len].push(eos_id);
502
503 Self {
504 original_text,
505 text: text_no_space,
506 char_positions,
507 space_positions,
508 stripped_to_original_byte,
509 nodes,
510 ends_at,
511 starts_at,
512 bos_id,
513 eos_id,
514 }
515 }
516
517 #[inline]
519 #[must_use]
520 pub fn text(&self) -> &str {
521 &self.text
522 }
523
524 #[inline]
526 #[must_use]
527 pub fn original_text(&self) -> &str {
528 &self.original_text
529 }
530
531 #[inline]
534 #[must_use]
535 pub fn original_byte_pos(&self, stripped_char_pos: usize) -> usize {
536 self.stripped_to_original_byte
537 .get(stripped_char_pos)
538 .copied()
539 .unwrap_or(self.original_text.len())
540 }
541
542 fn build_stripped_to_original(original: &str) -> Vec<usize> {
543 let mut map = Vec::new();
544 let mut byte_offset = 0;
545 for c in original.chars() {
546 if !c.is_whitespace() {
547 map.push(byte_offset);
548 }
549 byte_offset += c.len_utf8();
550 }
551 map.push(byte_offset);
552 map
553 }
554
555 #[inline]
557 #[must_use]
558 pub fn char_len(&self) -> usize {
559 self.char_positions.char_count()
560 }
561
562 #[inline]
569 #[must_use]
570 pub fn char_pos_from_start_and_byte_len(&self, start_pos: usize, byte_len: usize) -> usize {
571 let start_byte = self.char_positions.char_to_byte(start_pos);
572 self.char_positions.byte_to_char(start_byte + byte_len)
573 }
574
575 #[inline]
577 #[must_use]
578 pub const fn byte_len(&self) -> usize {
579 self.char_positions.byte_count()
580 }
581
582 #[inline]
584 #[must_use]
585 pub fn node_count(&self) -> usize {
586 self.nodes.len()
587 }
588
589 #[inline]
591 #[must_use]
592 pub fn bos(&self) -> &Node {
593 &self.nodes[self.bos_id as usize]
594 }
595
596 #[inline]
598 #[must_use]
599 pub fn eos(&self) -> &Node {
600 &self.nodes[self.eos_id as usize]
601 }
602
603 #[inline]
605 pub fn eos_mut(&mut self) -> &mut Node {
606 let eos_id = self.eos_id as usize;
607 &mut self.nodes[eos_id]
608 }
609
610 #[inline]
612 #[must_use]
613 pub fn node(&self, id: NodeId) -> Option<&Node> {
614 self.nodes.get(id as usize)
615 }
616
617 #[inline]
619 pub fn node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
620 self.nodes.get_mut(id as usize)
621 }
622
623 #[inline]
625 pub fn nodes(&self) -> impl Iterator<Item = &Node> {
626 self.nodes.iter()
627 }
628
629 #[inline]
631 pub fn nodes_ending_at(&self, pos: usize) -> impl Iterator<Item = &Node> {
632 self.ends_at
633 .get(pos)
634 .map(|ids| ids.iter())
635 .into_iter()
636 .flatten()
637 .filter_map(|&id| self.nodes.get(id as usize))
638 }
639
640 #[inline]
642 pub fn nodes_starting_at(&self, pos: usize) -> impl Iterator<Item = &Node> {
643 self.starts_at
644 .get(pos)
645 .map(|ids| ids.iter())
646 .into_iter()
647 .flatten()
648 .filter_map(|&id| self.nodes.get(id as usize))
649 }
650
651 #[allow(clippy::cast_possible_truncation)]
661 pub fn add_node(&mut self, builder: NodeBuilder) -> NodeId {
662 let id = self.nodes.len() as NodeId;
663
664 let start_byte = self.char_positions.char_to_byte(builder.start_pos);
666 let end_byte = self.char_positions.char_to_byte(builder.end_pos);
667
668 let has_space_before =
670 builder.has_space_before || self.space_positions.has_space_before(builder.start_pos);
671
672 let node = Node {
673 id,
674 surface: Cow::Owned(builder.surface),
675 start_pos: builder.start_pos,
676 end_pos: builder.end_pos,
677 start_byte,
678 end_byte,
679 left_id: builder.left_id,
680 right_id: builder.right_id,
681 word_cost: builder.word_cost,
682 total_cost: i32::MAX, prev_node_id: INVALID_NODE_ID,
684 node_type: builder.node_type,
685 feature: Cow::Owned(builder.feature),
686 has_space_before,
687 };
688
689 if builder.start_pos < self.starts_at.len() {
691 self.starts_at[builder.start_pos].push(id);
692 }
693 if builder.end_pos < self.ends_at.len() {
694 self.ends_at[builder.end_pos].push(id);
695 }
696
697 self.nodes.push(node);
698 id
699 }
700
701 #[must_use]
703 pub fn substring(&self, start: usize, end: usize) -> &str {
704 let start_byte = self.char_positions.char_to_byte(start);
705 let end_byte = self.char_positions.char_to_byte(end);
706 &self.text[start_byte..end_byte]
707 }
708
709 #[inline]
711 #[must_use]
712 pub fn has_space_at(&self, char_pos: usize) -> bool {
713 self.space_positions.has_space_before(char_pos)
714 }
715
716 pub fn clear(&mut self) {
718 self.nodes.truncate(2);
720
721 for v in &mut self.ends_at {
723 v.clear();
724 }
725 for v in &mut self.starts_at {
726 v.clear();
727 }
728
729 if !self.ends_at.is_empty() {
731 self.ends_at[0].push(self.bos_id);
732 }
733 let char_len = self.char_len();
734 if char_len < self.starts_at.len() {
735 self.starts_at[char_len].push(self.eos_id);
736 }
737
738 if let Some(eos) = self.nodes.get_mut(self.eos_id as usize) {
740 eos.total_cost = i32::MAX;
741 eos.prev_node_id = INVALID_NODE_ID;
742 }
743 }
744
745 pub fn reset(&mut self, text: &str) {
747 self.original_text.clear();
750 self.original_text.push_str(text);
751
752 self.text.clear();
753 for c in text.chars().filter(|c| !c.is_whitespace()) {
754 self.text.push(c);
755 }
756
757 self.char_positions = CharPositions::new(&self.text);
758 self.stripped_to_original_byte = Self::build_stripped_to_original(text);
759 self.space_positions = SpacePositions::new(text);
760
761 let char_len = self.char_positions.char_count();
762 let byte_len = self.char_positions.byte_count();
763
764 let new_len = char_len + 1;
768 let old_ends_len = self.ends_at.len();
769 let old_starts_len = self.starts_at.len();
770
771 for v in self.ends_at.iter_mut().take(new_len.min(old_ends_len)) {
773 v.clear();
774 }
775 for v in self.starts_at.iter_mut().take(new_len.min(old_starts_len)) {
776 v.clear();
777 }
778
779 self.ends_at.truncate(new_len);
781 self.starts_at.truncate(new_len);
782 while self.ends_at.len() < new_len {
783 self.ends_at.push(Vec::new());
784 }
785 while self.starts_at.len() < new_len {
786 self.starts_at.push(Vec::new());
787 }
788
789 self.nodes.truncate(2);
791
792 if let Some(eos) = self.nodes.get_mut(self.eos_id as usize) {
794 eos.start_pos = char_len;
795 eos.end_pos = char_len;
796 eos.start_byte = byte_len;
797 eos.end_byte = byte_len;
798 eos.total_cost = i32::MAX;
799 eos.prev_node_id = INVALID_NODE_ID;
800 }
801
802 self.ends_at[0].push(self.bos_id);
804 self.starts_at[char_len].push(self.eos_id);
805 }
806
807 #[must_use]
811 pub fn best_path(&self) -> Vec<&Node> {
812 let mut path = Vec::new();
813 let mut current_id = self.eos_id;
814
815 while current_id != INVALID_NODE_ID {
816 if let Some(node) = self.nodes.get(current_id as usize) {
817 if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
818 path.push(node);
819 }
820 current_id = node.prev_node_id;
821 } else {
822 break;
823 }
824 }
825
826 path.reverse();
827 path
828 }
829
830 #[cfg(test)]
832 #[must_use]
833 #[allow(clippy::format_push_string, clippy::uninlined_format_args)]
834 pub fn visualize(&self) -> String {
835 let mut output = String::new();
836 output.push_str(&format!("Lattice for: \"{}\"\n", self.text));
837 output.push_str(&format!("Nodes: {}\n", self.node_count()));
838
839 for pos in 0..=self.char_len() {
840 let ending: Vec<_> = self.nodes_ending_at(pos).collect();
841 if !ending.is_empty() {
842 output.push_str(&format!("\nPosition {}: ", pos));
843 for node in ending {
844 output.push_str(&format!(
845 "[{}: {} ({}-{})]",
846 node.id, node.surface, node.start_pos, node.end_pos
847 ));
848 }
849 }
850 }
851
852 output
853 }
854}
855
/// Summary counters describing a lattice, as produced by `Lattice::stats`.
#[derive(Debug, Clone, Default)]
pub struct LatticeStats {
    /// Total number of nodes, including BOS and EOS.
    pub total_nodes: usize,
    /// Number of nodes of type `NodeType::Known`.
    pub known_nodes: usize,
    /// Number of nodes of type `NodeType::Unknown`.
    pub unknown_nodes: usize,
    /// Number of nodes of type `NodeType::User`.
    pub user_nodes: usize,
    /// Character length of the lattice's whitespace-stripped text.
    pub char_length: usize,
}
870
871impl Lattice {
872 #[must_use]
874 pub fn stats(&self) -> LatticeStats {
875 let mut stats = LatticeStats {
876 total_nodes: self.nodes.len(),
877 char_length: self.char_len(),
878 ..Default::default()
879 };
880
881 for node in &self.nodes {
882 match node.node_type {
883 NodeType::Known => stats.known_nodes += 1,
884 NodeType::Unknown => stats.unknown_nodes += 1,
885 NodeType::User => stats.user_nodes += 1,
886 _ => {}
887 }
888 }
889
890 stats
891 }
892
893 #[must_use]
897 pub fn memory_usage(&self) -> usize {
898 let text_bytes = self.text.len() + self.original_text.len();
900
901 let nodes_bytes = self.nodes.capacity() * std::mem::size_of::<Node>();
903
904 let index_bytes = self.starts_at.capacity() * std::mem::size_of::<Vec<u32>>()
906 + self.ends_at.capacity() * std::mem::size_of::<Vec<u32>>()
907 + self
908 .starts_at
909 .iter()
910 .map(|v| v.capacity() * 4)
911 .sum::<usize>()
912 + self.ends_at.iter().map(|v| v.capacity() * 4).sum::<usize>();
913
914 let pos_bytes = (self.char_positions.char_count() + 1) * std::mem::size_of::<usize>();
916
917 let space_bytes = self.char_len() * std::mem::size_of::<usize>() / 10; let node_strings: usize = self
922 .nodes
923 .iter()
924 .map(|n| n.surface.len() + n.feature.len())
925 .sum();
926
927 text_bytes + nodes_bytes + index_bytes + pos_bytes + space_bytes + node_strings
928 }
929}
930
// Unit tests for the lattice and its position helpers.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::needless_collect)]
mod tests {
    use super::*;

    /// A fresh lattice keeps the text and holds only the boundary nodes.
    #[test]
    fn test_lattice_creation() {
        let lattice = Lattice::new("안녕하세요");

        assert_eq!(lattice.text(), "안녕하세요");
        assert_eq!(lattice.char_len(), 5);
        // BOS + EOS
        assert_eq!(lattice.node_count(), 2);
    }

    /// Whitespace is stripped from the working text but remembered.
    #[test]
    fn test_lattice_with_spaces() {
        let lattice = Lattice::new("안녕 하세요");

        // Working text has the space removed; the original is kept verbatim.
        assert_eq!(lattice.text(), "안녕하세요");
        assert_eq!(lattice.original_text(), "안녕 하세요");
        assert_eq!(lattice.char_len(), 5);

        assert!(!lattice.has_space_at(0));
        assert!(!lattice.has_space_at(1));
        // Stripped position 2 followed the space in the original text.
        assert!(lattice.has_space_at(2));
    }

    /// Added nodes get the next free id and keep the builder's attributes.
    #[test]
    fn test_add_node() {
        let mut lattice = Lattice::new("안녕하세요");

        let node_id = lattice.add_node(
            NodeBuilder::new("안녕", 0, 2)
                .left_id(100)
                .right_id(100)
                .word_cost(1000)
                .feature("NNG,*,F,안녕,*,*,*,*"),
        );

        // Ids 0 and 1 are taken by BOS and EOS.
        assert_eq!(node_id, 2);
        assert_eq!(lattice.node_count(), 3);

        let node = lattice.node(node_id).unwrap();
        assert_eq!(node.surface.as_ref(), "안녕");
        assert_eq!(node.start_pos, 0);
        assert_eq!(node.end_pos, 2);
        assert_eq!(node.left_id, 100);
        assert_eq!(node.word_cost, 1000);
    }

    /// Nodes can be looked up by their start and end positions.
    #[test]
    fn test_nodes_at_position() {
        let mut lattice = Lattice::new("안녕하세요");

        lattice.add_node(NodeBuilder::new("안녕", 0, 2));
        lattice.add_node(NodeBuilder::new("안", 0, 1));
        lattice.add_node(NodeBuilder::new("녕하", 1, 3));

        // "안녕" and "안" start at position 0.
        let starting_at_0: Vec<_> = lattice.nodes_starting_at(0).collect();
        assert_eq!(starting_at_0.len(), 2);

        // Only "안녕" ends at position 2.
        let ending_at_2: Vec<_> = lattice.nodes_ending_at(2).collect();
        assert_eq!(ending_at_2.len(), 1);
    }

    /// Byte offsets account for multi-byte characters.
    #[test]
    fn test_char_positions() {
        let positions = CharPositions::new("한글test");

        assert_eq!(positions.char_count(), 6);
        // Each Hangul syllable takes 3 bytes, each ASCII letter 1 byte.
        assert_eq!(positions.char_to_byte(0), 0);
        assert_eq!(positions.char_to_byte(1), 3);
        assert_eq!(positions.char_to_byte(2), 6);
        assert_eq!(positions.char_to_byte(3), 7);
    }

    /// Substrings are addressed by character positions.
    #[test]
    fn test_substring() {
        let lattice = Lattice::new("안녕하세요");

        assert_eq!(lattice.substring(0, 2), "안녕");
        assert_eq!(lattice.substring(2, 5), "하세요");
        assert_eq!(lattice.substring(0, 5), "안녕하세요");
    }

    /// BOS and EOS are present and EOS sits at the end of the text.
    #[test]
    fn test_bos_eos() {
        let lattice = Lattice::new("테스트");

        let bos = lattice.bos();
        assert!(bos.is_bos());
        assert_eq!(bos.id, BOS_NODE_ID);

        let eos = lattice.eos();
        assert!(eos.is_eos());
        assert_eq!(eos.start_pos, 3);
    }

    /// Resetting swaps the text and drops all non-boundary nodes.
    #[test]
    fn test_lattice_reset() {
        let mut lattice = Lattice::new("안녕");
        lattice.add_node(NodeBuilder::new("안녕", 0, 2));
        assert_eq!(lattice.node_count(), 3);

        lattice.reset("하세요");
        assert_eq!(lattice.text(), "하세요");
        assert_eq!(lattice.char_len(), 3);
        // Only BOS + EOS remain.
        assert_eq!(lattice.node_count(), 2);
    }

    /// `add_node` derives `has_space_before` from the original text.
    #[test]
    fn test_space_before_detection() {
        let mut lattice = Lattice::new("아버지가 방에");

        // "방에" starts at stripped position 4, right after the space.
        let node_id = lattice.add_node(NodeBuilder::new("방에", 4, 6));

        let node = lattice.node(node_id).unwrap();
        assert!(node.has_space_before);

        // "아버지가" starts the text, so nothing precedes it.
        let node_id2 = lattice.add_node(NodeBuilder::new("아버지가", 0, 4));
        let node2 = lattice.node(node_id2).unwrap();
        assert!(!node2.has_space_before);
    }
}