diff --git a/src/merge/mod.rs b/src/merge/mod.rs new file mode 100644 index 0000000..30fabe1 --- /dev/null +++ b/src/merge/mod.rs @@ -0,0 +1,331 @@ +use crate::cli::MergeStrategy; +use anyhow::{Context, Result}; +use dialoguer::{Select, Editor}; +use std::fs; +use std::path::{Path, PathBuf}; +use std::io::{self, Write, Read}; + +#[derive(Debug, Clone)] +pub struct Conflict { + pub path: PathBuf, + pub base_content: Option, + pub local_content: Option, + pub remote_content: Option, +} + +#[derive(Debug, Clone)] +pub enum MergeResult { + Success, + Conflict(Vec), + AutoMerged, +} + +pub struct Merger { + conflicts: Vec, +} + +impl Merger { + pub fn new() -> Self { + Self { + conflicts: Vec::new(), + } + } + + pub fn diff(&self, local: &PathBuf, remote: &PathBuf) -> Result> { + let local_content = self.read_file(local)?; + let remote_content = self.read_file(remote)?; + + self.compute_diff(&local_content, &remote_content) + } + + pub fn three_way_merge( + &mut self, + base: &PathBuf, + local: &PathBuf, + remote: &PathBuf, + output: &PathBuf, + strategy: Option, + ) -> Result { + let base_content = if base.exists() { Some(self.read_file(base)?) } else { None }; + let local_content = self.read_file(local)?; + let remote_content = self.read_file(remote)?; + + let conflicts = self.find_conflicts(&base_content, &local_content, &remote_content); + + if !conflicts.is_empty() { + self.conflicts = conflicts; + + match strategy { + Some(MergeStrategy::Ours) => { + fs::write(output, local_content)?; + Ok(MergeResult::Conflict(self.conflicts.clone())) + } + Some(MergeStrategy::Theirs) => { + fs::write(output, remote_content)?; + Ok(MergeResult::Conflict(self.conflicts.clone())) + } + Some(MergeStrategy::Ask) => { + self.resolve_conflicts_interactively(output, &local_content, &remote_content) + } + Some(MergeStrategy::Diff3) | None => { + self.three_way_merge_content(base_content, local_content, remote_content, output) + } + } + } else { + self.auto_merge(base_content, local_content, remote_content, output) + } + } + + fn three_way_merge_content( + &mut self, + base_content: Option, + local_content: String, + remote_content: String, + output: &PathBuf, + ) -> Result { + let mut lines: Vec = Vec::new(); + + let base_lines = base_content.as_ref() + .map(|s| s.lines().map(String::from).collect::>()) + .unwrap_or_default(); + let local_lines: Vec<&str> = local_content.lines().collect(); + let remote_lines: Vec<&str> = remote_content.lines().collect(); + + let base_len = base_lines.len(); + let local_len = local_lines.len(); + let remote_len = remote_lines.len(); + + let mut i = 0; + let mut j = 0; + let mut k = 0; + + while i < base_len || j < local_len || k < remote_len { + let base_line = base_lines.get(i).map(String::as_str); + let local_line = local_lines.get(j).copied(); + let remote_line = remote_lines.get(k).copied(); + + if local_line == remote_line { + if let Some(line) = local_line { + lines.push(line.to_string()); + } + i += 1; + j += 1; + k += 1; + } else if base_line == local_line { + if let Some(line) = remote_line { + lines.push(format!("+ {}", line)); + } + i += 1; + k += 1; + } else if base_line == remote_line { + if let Some(line) = local_line { + lines.push(format!("+ {}", line)); + } + i += 1; + j += 1; + } else { + if let Some(line) = local_line { + lines.push(format!("+ {}", line)); + } + if let Some(line) = remote_line { + lines.push(format!("+ {}", line)); + } + self.conflicts.push(Conflict { + path: output.clone(), + base_content: base_line.map(String::from), + local_content: local_line.map(String::from), + remote_content: remote_line.map(String::from), + }); + i += 1; + j += 1; + k += 1; + } + } + + fs::write(output, lines.join("\n"))?; + + if self.conflicts.is_empty() { + Ok(MergeResult::AutoMerged) + } else { + Ok(MergeResult::Conflict(self.conflicts.clone())) + } + } + + fn auto_merge( + &mut self, + base_content: Option, + local_content: String, + remote_content: String, + output: &PathBuf, + ) -> Result { + let base_lines = base_content.as_ref() + .map(|s| s.lines().map(String::from).collect::>()) + .unwrap_or_default(); + + let merged = self.merge_lines(&base_lines, &local_content, &remote_content); + + fs::write(output, merged)?; + + Ok(MergeResult::AutoMerged) + } + + fn merge_lines(&self, base: &[String], local: &str, remote: &str) -> String { + let mut result = String::new(); + + let base_lines: Vec<&str> = base.iter().map(|s| s.as_str()).collect(); + let local_lines: Vec<&str> = local.lines().collect(); + let remote_lines: Vec<&str> = remote.lines().collect(); + + let mut local_set: std::collections::HashSet<&str> = local_lines.iter().copied().collect(); + let mut remote_set: std::collections::HashSet<&str> = remote_lines.iter().copied().collect(); + + for line in &local_set { + if !remote_set.contains(line) && !base_lines.contains(line) { + result.push_str(line); + result.push('\n'); + } + } + + for line in &remote_set { + if !local_set.contains(line) && !base_lines.contains(line) { + result.push_str(line); + result.push('\n'); + } + } + + result + } + + fn find_conflicts( + &mut self, + base: &Option, + local: &String, + remote: &String, + ) -> Vec { + let mut conflicts = Vec::new(); + + if base.is_none() { + return conflicts; + } + + let base_lines: Vec<&str> = base.as_ref().unwrap().lines().collect(); + let local_lines: Vec<&str> = local.lines().collect(); + let remote_lines: Vec<&str> = remote.lines().collect(); + + for i in 0..std::cmp::min(std::cmp::min(base_lines.len(), local_lines.len()), remote_lines.len()) { + if base_lines[i] != local_lines[i] && base_lines[i] != remote_lines[i] && local_lines[i] != remote_lines[i] { + conflicts.push(Conflict { + path: PathBuf::from("conflict"), + base_content: Some(base_lines[i].to_string()), + local_content: Some(local_lines[i].to_string()), + remote_content: Some(remote_lines[i].to_string()), + }); + } + } + + conflicts + } + + fn resolve_conflicts_interactively( + &mut self, + output: &PathBuf, + local_content: &String, + remote_content: &String, + ) -> Result { + let conflict_text = format!( + "<<<<<<< LOCAL\n{}\n=======\n{}\n>>>>>>> REMOTE\n", + local_content, remote_content + ); + + println!("\nConflict detected in {:?}", output); + + let options = &[ + "Keep local version", + "Keep remote version", + "Edit merge manually", + "Abort merge", + ]; + + let selection = Select::new() + .with_prompt("How would you like to resolve this conflict?") + .items(options) + .default(0) + .interact()?; + + match selection { + 0 => { + fs::write(output, local_content)?; + } + 1 => { + fs::write(output, remote_content)?; + } + 2 => { + if Editor::new() + .edit(&conflict_text) + .unwrap_or(Some(conflict_text.clone())) + .map(|content| fs::write(output, &content)) + .is_err() + { + fs::write(output, &conflict_text)?; + } + } + 3 => { + anyhow::bail!("Merge aborted by user"); + } + _ => { + fs::write(output, &conflict_text)?; + } + } + + if self.conflicts.is_empty() { + Ok(MergeResult::Success) + } else { + Ok(MergeResult::Conflict(self.conflicts.clone())) + } + } + + fn compute_diff(&self, a: &str, b: &str) -> Result> { + let mut diff = Vec::new(); + + let a_lines: Vec<&str> = a.lines().collect(); + let b_lines: Vec<&str> = b.lines().collect(); + + let mut i = 0; + let mut j = 0; + + while i < a_lines.len() || j < b_lines.len() { + if i >= a_lines.len() { + diff.push(format!("+ {}", b_lines[j])); + j += 1; + } else if j >= b_lines.len() { + diff.push(format!("- {}", a_lines[i])); + i += 1; + } else if a_lines[i] == b_lines[j] { + diff.push(format!(" {}", a_lines[i])); + i += 1; + j += 1; + } else { + diff.push(format!("- {}", a_lines[i])); + diff.push(format!("+ {}", b_lines[j])); + i += 1; + j += 1; + } + } + + Ok(diff) + } + + fn read_file(&self, path: &PathBuf) -> Result { + let mut file = fs::File::open(path) + .with_context(|| format!("Failed to open {:?}", path))?; + let mut content = String::new(); + file.read_to_string(&mut content)?; + Ok(content) + } +} + +impl Default for Merger { + fn default() -> Self { + Self::new() + } +}