/*
* Copyright (C) 2004-2005 Jonathan Bindel
* Copyright (C) 2006-2007 Eskil Bylund
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/
using System;
using System.Diagnostics;
using System.IO;
using System.Text.RegularExpressions;
using DCSharp.Backend.Connections;
using DCSharp.Backend.Managers;
using DCSharp.Backend.Objects;
using DCSharp.Extras;
using DCSharp.Hashing;
using FileSystem.IO.File;
using DirectorySystem.IO.Directory;
namespace DCSharp.Backend.Protocols.Nmdc{
public class UserProtocol : Protocol, IUserProtocol
{
/// <summary>
/// Optional protocol extensions a remote client can announce in $Supports.
/// Member names must match the announced tokens exactly, because incoming
/// $Supports tokens are looked up by name.
/// </summary>
[Flags]
public enum Features
{
    MiniSlots = 1 << 0,
    BZList = 1 << 1,
    XmlBZList = 1 << 2,
    ADCGet = 1 << 3,
    GetZBlock = 1 << 4,
    GetTestZBlock = 1 << 5,
    ZLIG = 1 << 6,
    TTHL = 1 << 7,
    TTHF = 1 << 8,
    CDM = 1 << 9
}

/// <summary>
/// Position of a connection in the protocol state machine. The first four
/// states belong to the handshake, Get/Send to uploading and FileLength to
/// downloading.
/// </summary>
protected enum States
{
    // Handshake
    Nick = 0,
    Lock = 1,
    Direction = 2,
    Key = 3,
    // Upload
    Get = 4,
    Send = 5,
    // Download
    FileLength = 6
}

/// <summary>
/// Error conditions reported to the remote client; GetErrorDescription
/// gives the text sent for each one.
/// </summary>
protected enum Errors
{
    FileNotAvailable = 0,
    InvalidCommand = 1,
    InvalidKey = 2,
    NoSlots = 3
}
// Features the remote client announced via $Supports.
private Features features;
// Handler object whose methods process the commands received from the remote client.
private Incoming incoming;
// Transfer direction for this connection; $Direction may flip it.
private TransferDirection direction;

// Download: the file being fetched and where it comes from (set by DownloadFile).
private DownloadFileInfo download;
private SourceInfo source;

/// <summary>
/// Creates the protocol handler and registers a handler for every
/// client-to-client command it understands.
/// </summary>
/// <param name="helper">Supplies share access, user lookup and upload
/// decisions. Must not be null.</param>
/// <exception cref="ArgumentNullException">helper is null.</exception>
public UserProtocol(IUserProtocolHelper helper)
{
    if (helper == null)
    {
        throw new ArgumentNullException("helper");
    }
    this.helper = helper;

    Incoming handler = new Incoming(this);
    this.incoming = handler;

    // Handshake
    commands.Add("$mynick", new StringEventHandler(handler.MyNick));
    commands.Add("$lock", new StringEventHandler(handler.Lock));
    commands.Add("$supports", new StringEventHandler(handler.Supports));
    commands.Add("$direction", new StringEventHandler(handler.Direction));
    commands.Add("$key", new StringEventHandler(handler.Key));

    // Upload
    commands.Add("$get", new StringEventHandler(handler.Get));
    commands.Add("$send", new BlankEventHandler(handler.Send));
    commands.Add("$getlistlen", new BlankEventHandler(handler.GetListLen));
    commands.Add("$getzblock", new StringEventHandler(handler.GetZBlock));
    commands.Add("$ugetblock", new StringEventHandler(handler.UGetBlock));
    commands.Add("$ugetzblock", new StringEventHandler(handler.UGetZBlock));
    commands.Add("$adcget", new StringEventHandler(handler.AdcGet));

    // Download
    commands.Add("$filelength", new StringEventHandler(handler.FileLength));
    commands.Add("$maxedout", new BlankEventHandler(handler.MaxedOut));
    commands.Add("$sending", new StringEventHandler(handler.Sending));
    commands.Add("$failed", new StringEventHandler(handler.Failed));
    commands.Add("$adcsnd", new StringEventHandler(handler.AdcSnd));

    // Other
    commands.Add("$error", new StringEventHandler(handler.Error));
}
#region Classes
protected class Incoming
{
private UserProtocol protocol;
// Current protocol state; UserProtocol.Invoke may also reset it after a failed command.
public States state;
// Key the remote side must answer with, derived from the $Lock we sent.
public string keyStr;
// Random number sent in $Direction to decide who downloads when both sides want to.
private int directionNumber;
// Upload accepted by $Get, started when the remote sends $Send.
private UploadFileInfo upload;

/// <summary>Share manager of the protocol's helper.</summary>
private IShareManager ShareManager
{
    get
    {
        return protocol.Helper.ShareManager;
    }
}

/// <summary>
/// Creates the command handler for protocol. It starts in the Nick state and
/// picks a random $Direction number in the range [0, 32767).
/// </summary>
public Incoming(UserProtocol protocol)
{
    this.protocol = protocol;
    this.state = States.Nick;
    this.directionNumber = new Random().Next(32767);
}
#region Methods
#region Handshake
/// <summary>
/// Handles $MyNick. It resolves the remote user and hub, binds them to the
/// connection and asks the helper which way the transfer should go. It
/// disconnects when the user cannot be identified.
/// </summary>
/// <param name="nick">Nick announced by the remote client.</param>
public void MyNick(string nick)
{
    CheckState(States.Nick);

    HubConnection hub = protocol.Connection.Hub;
    Identity identity = null;

    // Look the nick up on the known hub, or let the helper resolve both
    // the hub and the user when the connection has no hub yet.
    if (hub == null)
    {
        protocol.Helper.IdentifyUser(nick, out hub, out identity);
    }
    else
    {
        identity = protocol.GetUser(hub, nick);
    }

    bool unknownUser = hub == null || identity == null;
    if (unknownUser)
    {
        Debug.WriteLine("Unknown user " + nick);
        protocol.Connection.Disconnect();
        return;
    }

    // Bind the hub and both identities to the connection.
    protocol.Connection.Hub = hub;
    protocol.Connection.LocalIdentity = hub.LocalIdentity;
    protocol.Connection.RemoteIdentity = identity;

    protocol.direction = protocol.Helper.GetTransferDirection(identity.User);

    // Incoming connections answer with our own $MyNick and $Lock here;
    // outgoing ones already sent them from Initialize.
    if (protocol.Connection.Incoming)
    {
        protocol.MyNick();
        protocol.Lock();
    }
    state = States.Lock;
}
/// <summary>
/// Handles $Lock. It replies with $Supports, $Direction and the $Key
/// computed from the remote lock, then waits for the remote $Direction.
/// </summary>
/// <param name="argument">The remote lock string.</param>
public void Lock(string argument)
{
    CheckState(States.Lock);

    protocol.Supports();
    protocol.Direction(directionNumber);

    string reply = CalculateKey(argument, protocol.Encoding);
    protocol.Key(reply);

    state = States.Direction;
}
/// <summary>
/// Handles $Supports. It records every announced token that names a
/// Features member. Unknown tokens are logged and ignored.
/// </summary>
/// <param name="argument">Space-separated feature names.</param>
public void Supports(string argument)
{
    foreach (string feature in argument.Split(' '))
    {
        if (feature.Length == 0)
        {
            continue;
        }

        // Accept exact member names only. Enum.Parse alone also accepts
        // numeric values and comma-separated lists, which would let a peer
        // set arbitrary flag combinations. Checking first also avoids using
        // exceptions for control flow.
        if (Enum.IsDefined(typeof(Features), feature))
        {
            protocol.features |= (Features)Enum.Parse(typeof(Features), feature);
        }
        else
        {
            Debug.WriteLine(feature + " needs to be added to features.");
        }
    }
}
/// <summary>
/// Handles $Direction, whose argument is "Upload"/"Download" followed by a
/// number.
/// <list type="bullet">
/// <item>Both sides want to upload: the connection is dropped.</item>
/// <item>Both want to download: the side with the higher number downloads;
/// equal numbers drop the connection.</item>
/// <item>Otherwise the local direction stands.</item>
/// </list>
/// </summary>
/// <param name="argument">"Upload|Download number".</param>
public void Direction(string argument)
{
    CheckState(States.Direction);
    string[] args = argument.Split(' ');

    if (args[0].Equals("Upload") &&
        protocol.direction == TransferDirection.Up)
    {
        Debug.WriteLine("No one wants to download");
        protocol.Connection.Disconnect();
        // Stay out of the Key state so a following $Key cannot mark the
        // dropped connection ready for transfer.
        return;
    }

    if (args[0].Equals("Download") &&
        protocol.direction == TransferDirection.Down)
    {
        // The one with the highest number gets to download
        int number = int.Parse(args[1]);
        if (number > directionNumber)
        {
            protocol.direction = TransferDirection.Up;
        }
        else if (number == directionNumber)
        {
            Debug.WriteLine("$Direction collision");
            protocol.Connection.Disconnect();
            return;
        }
    }

    state = States.Key;
}
/// <summary>
/// Handles $Key by comparing the remote key with the one expected for our
/// $Lock.
/// <list type="bullet">
/// <item>Mismatch: an InvalidKey error is sent and the connection dropped.</item>
/// <item>Match: the state moves to the transfer phase and the negotiated
/// direction is set on the connection.</item>
/// </list>
/// </summary>
/// <param name="argument">Key sent by the remote client.</param>
public void Key(string argument)
{
    CheckState(States.Key);

    // Treat every char of keyStr as a raw byte and decode with the
    // connection encoding, so the result compares to the decoded argument.
    // TODO: Compare with the original incoming byte array instead?
    byte[] raw = new byte[keyStr.Length];
    int index = 0;
    foreach (char c in keyStr)
    {
        raw[index] = Convert.ToByte(c);
        index++;
    }
    string expected = protocol.Encoding.GetString(raw);

    if (!String.Equals(argument, expected))
    {
        protocol.Error(Errors.InvalidKey);
        protocol.Connection.Disconnect();
        return;
    }

    // Start downloading/uploading
    state = protocol.direction == TransferDirection.Down
        ? States.FileLength
        : States.Get;

    // Connection ready for file transfer
    protocol.Connection.Direction = protocol.direction;
}
#endregion
#region Upload
/// <summary>
/// Handles $Get "path$offset", where the offset is 1-based. Possible replies:
/// <list type="bullet">
/// <item>Success: $FileLength, then the handler waits for $Send.</item>
/// <item>Upload refused by the helper: $MaxedOut.</item>
/// <item>Malformed request or unavailable file: $Error.</item>
/// </list>
/// </summary>
/// <param name="argument">"virtual path$1-based offset".</param>
public void Get(string argument)
{
    CheckState(States.Get);

    string[] strings = argument.Split('$');
    if (strings.Length != 2)
    {
        // Answer malformed requests instead of leaving the peer waiting.
        protocol.Error(Errors.InvalidCommand);
        return;
    }

    string virtualPath = strings[0].Trim();
    string path = GetRealPath(virtualPath, true);
    // $Get offsets are 1-based.
    long startingByte = long.Parse(strings[1].Trim()) - 1;
    if (startingByte < 0)
    {
        protocol.Error(Errors.InvalidCommand);
        return;
    }

    try
    {
        UploadFileInfo upload = CreateUploadFileInfo(path,
            startingByte, -1);
        if (!protocol.Helper.CanUpload(protocol.Connection, upload))
        {
            protocol.MaxedOut();
            return;
        }
        protocol.Connection.Send("$FileLength " + upload.Size + "|");
        this.upload = upload;
    }
    catch (Exception e)
    {
        // Rejected path, missing file, I/O error: report and keep waiting
        // for the next request.
        Debug.WriteLine(e);
        protocol.Error(Errors.FileNotAvailable);
        return;
    }
    state = States.Send;
}
/// <summary>
/// Handles $Send. It starts uploading the file accepted by the preceding
/// $Get and goes back to waiting for the next request.
/// </summary>
public void Send()
{
    CheckState(States.Send);
    protocol.Connection.UploadFile(this.upload);
    this.upload = null;
    this.state = States.Get;
}
/// <summary>
/// Handles $GetListLen. It replies with $ListLen and the size of the DcLst
/// style file list, or with $Error when that file cannot be read.
/// </summary>
public void GetListLen()
{
    CheckState(States.Get);
    try
    {
        long length = new FileInfo(ShareManager.LegacyFileListPath).Length;
        protocol.Connection.Send("$ListLen " + length + "|");
    }
    catch
    {
        protocol.Error(Errors.FileNotAvailable);
    }
}
/// <summary>
/// Handles $GetZBlock: a compressed block request whose path is left in the
/// connection encoding.
/// </summary>
public void GetZBlock(string argument)
{
    GetBlock(argument, true);
}

/// <summary>Handles $UGetBlock: an uncompressed block request with a UTF-8 path.</summary>
public void UGetBlock(string argument)
{
    UGetBlock(argument, false);
}

/// <summary>Handles $UGetZBlock: a compressed block request with a UTF-8 path.</summary>
public void UGetZBlock(string argument)
{
    UGetBlock(argument, true);
}

// Support for UGetBlock is indicated by XmlBZList in $Supports.
// $Failed is sent in response to $GetZBlock, $UGetBlock or $UGetZBlock.
private void UGetBlock(string argument, bool compress)
{
    string utf8Argument = argument;
    EnsureUTF8(ref utf8Argument);
    GetBlock(utf8Argument, compress);
}
/// <summary>
/// Serves a block request of the form "start count path". Possible replies:
/// <list type="bullet">
/// <item>Success: $Sending followed by the data.</item>
/// <item>Upload refused by the helper: $MaxedOut.</item>
/// <item>Malformed request or unavailable file: $Failed.</item>
/// </list>
/// </summary>
/// <param name="argument">"start count path". The path is already in the
/// expected text encoding.</param>
/// <param name="compress">Whether the data is uploaded compressed.</param>
private void GetBlock(string argument, bool compress)
{
    CheckState(States.Get);

    int i = argument.IndexOf(' ');
    int j = argument.IndexOf(' ', i + 1);
    long startingByte;
    long count;
    // Bad numbers get the same $Failed reply as any other malformed block
    // request. -1 is the only negative count allowed (it means "to the end").
    // The trailing '|' terminates the $Failed command.
    if (i <= 0 || j <= i + 1 ||
        !long.TryParse(argument.Substring(0, i), out startingByte) ||
        !long.TryParse(argument.Substring(i + 1, j - i - 1), out count) ||
        startingByte < 0 || count < -1)
    {
        protocol.Failed("Wrong UGetBlock parameters|");
        return;
    }

    try
    {
        string path = GetRealPath(argument.Substring(j + 1), true);
        UploadFileInfo upload = CreateUploadFileInfo(path,
            startingByte, count);
        if (!protocol.Helper.CanUpload(protocol.Connection, upload))
        {
            protocol.MaxedOut();
            return;
        }
        protocol.Connection.Send("$Sending " + upload.RequestedBytes + "|");
        protocol.Connection.UploadFile(upload, compress);
    }
    catch (Exception e)
    {
        Debug.WriteLine(e);
        protocol.Failed(Errors.FileNotAvailable);
    }
}
/// <summary>
/// Handles $ADCGET "type identifier start count [ZL1]".
/// <list type="bullet">
/// <item>"file" requests are served by AdcGetFile.</item>
/// <item>"tthl" requests for a "TTH/..." identifier are answered with
/// $ADCSND followed by the hash tree leaves.</item>
/// <item>Anything else, including fewer than four parameters, gets $Error.</item>
/// </list>
/// </summary>
/// <param name="argument">The ADCGET parameters, decoded as UTF-8 below.</param>
public void AdcGet(string argument)
{
    CheckState(States.Get);
    EnsureUTF8(ref argument);

    // Split on spaces not escaped with a backslash.
    string[] parameters = Regex.Split(argument, @"(?<!\\) ");
    if (parameters.Length < 4)
    {
        // Error(string) appends the '|' terminator itself.
        protocol.Error("Wrong ADCGET parameters");
        return;
    }

    string type = parameters[0];
    string path = parameters[1];
    long startingByte = long.Parse(parameters[2]);
    long count = long.Parse(parameters[3]);
    bool compress = parameters.Length >= 5 && parameters[4] == "ZL1";

    if (type == "file")
    {
        AdcGetFile(type, path, startingByte, count, compress);
    }
    else if (type == "tthl" && path.StartsWith("TTH/", StringComparison.Ordinal))
    {
        string tth = path.Substring(4);
        HashTree hashTree = ShareManager.GetHashTree(tth);
        if (hashTree != null)
        {
            byte[][] leaves = hashTree.Leaves;
            string command = String.Format("$ADCSND {0} {1} {2} {3}|",
                type, path, startingByte,
                leaves.Length * hashTree.HashSize);
            protocol.Connection.Send(command, System.Text.Encoding.UTF8);
            foreach (byte[] leaf in leaves)
            {
                protocol.Connection.Send(leaf);
            }
        }
        else
        {
            protocol.Error(Errors.FileNotAvailable);
        }
    }
    else
    {
        protocol.Error("Wrong ADCGET parameters");
    }
}
/// <summary>
/// Serves an ADCGET "file" request. The identifier is either "TTH/root" or
/// an ADC-escaped share path, optionally starting with '/'. Possible replies:
/// <list type="bullet">
/// <item>Success: $ADCSND followed by the data.</item>
/// <item>Upload refused by the helper: $MaxedOut.</item>
/// <item>File unavailable: $Error.</item>
/// </list>
/// </summary>
private void AdcGetFile(string type, string path, long startingByte,
    long count, bool compress)
{
    string realPath;
    if (!path.StartsWith("TTH/"))
    {
        // Share path: drop a leading '/' and undo the ADC escaping.
        string sharePath = path;
        if (sharePath.StartsWith("/"))
        {
            sharePath = sharePath.Substring(1);
        }
        realPath = GetRealPath(UserProtocol.UnescapeADCParameter(sharePath),
            false);
    }
    else
    {
        // Hash identifier: look the file up by its TTH root.
        realPath = ShareManager.GetFullPathFromTTH(path.Substring(4));
    }

    try
    {
        UploadFileInfo upload = CreateUploadFileInfo(realPath,
            startingByte, count);
        bool allowed = protocol.Helper.CanUpload(protocol.Connection, upload);
        if (!allowed)
        {
            protocol.MaxedOut();
            return;
        }

        string compressionFlag = compress ? " ZL1" : null;
        string command = String.Format("$ADCSND {0} {1} {2} {3}{4}|",
            type, path, startingByte, upload.RequestedBytes, compressionFlag);
        protocol.Connection.Send(command, System.Text.Encoding.UTF8);
        protocol.Connection.UploadFile(upload, compress);
    }
    catch (Exception e)
    {
        Debug.WriteLine(e);
        protocol.Error(Errors.FileNotAvailable);
    }
}
#endregion
#region Download
/// <summary>
/// Handles $FileLength, the reply to our $Get. It stores the file size,
/// asks the remote to start with $Send and begins receiving.
/// </summary>
/// <param name="argument">File size in bytes.</param>
public void FileLength(string argument)
{
    CheckState(States.FileLength);
    long size = long.Parse(argument);
    protocol.download.Size = size;
    protocol.Connection.Send("$Send|");
    protocol.Connection.BeginReceive();
}
/// <summary>Handles $MaxedOut by reporting a NoSlots download error.</summary>
public void MaxedOut()
{
    CheckState(States.FileLength);
    const string message = "No slots available";
    protocol.Connection.DownloadError(TransferError.NoSlots, message);
}
/// <summary>
/// Handles $Sending, the reply to a block request. If the download size is
/// still 0 (treated as unknown), the size is taken from the argument. Then
/// receiving begins.
/// </summary>
/// <param name="argument">Number of bytes the remote will send.</param>
public void Sending(string argument)
{
    CheckState(States.FileLength);
    bool sizeUnknown = protocol.download.Size == 0;
    if (sizeUnknown)
    {
        protocol.download.Size = long.Parse(argument);
    }
    protocol.Connection.BeginReceive();
}
/// <summary>Handles $Failed exactly like $Error.</summary>
/// <param name="argument">Failure text sent by the remote client.</param>
public void Failed(string argument)
{
    this.Error(argument);
}
/// <summary>
/// Handles $ADCSND "type identifier start count". When start equals the
/// current download position, it stores count as the size (if the size is
/// still 0) and begins receiving. Replies with any other parameter count
/// are ignored.
/// </summary>
public void AdcSnd(string argument)
{
    CheckState(States.FileLength);
    EnsureUTF8(ref argument);

    // Split on spaces not escaped with a backslash.
    string[] args = Regex.Split(argument, @"(?<!\\) ");
    if (args.Length != 4)
    {
        // NOTE(review): a reply carrying a fifth "ZL1" flag is ignored here.
        return;
    }

    long startingByte = long.Parse(args[2]);
    long count = long.Parse(args[3]);
    if (protocol.download.Position != startingByte)
    {
        // TODO: Error
        return;
    }

    if (protocol.download.Size == 0)
    {
        protocol.download.Size = count;
    }
    protocol.Connection.BeginReceive();
}
#endregion
#region Other
/// <summary>
/// Handles $Error (and $Failed). While waiting for download data for a
/// known source, the message is mapped to a TransferError and reported on
/// the connection. Otherwise it is ignored.
/// </summary>
/// <param name="argument">Error text sent by the remote client.</param>
public void Error(string argument)
{
    if (state != States.FileLength || protocol.source == null)
    {
        return;
    }
    string message = argument == null ? String.Empty : argument;

    // Both "File Not Available" and "<path> no more exists" say the file is
    // gone, so both map to NotAvailable. This is not a slot shortage;
    // slot shortages arrive as $MaxedOut.
    if (String.Equals(message, "File Not Available", StringComparison.OrdinalIgnoreCase) ||
        message.IndexOf(" no more exists", StringComparison.OrdinalIgnoreCase) >= 0)
    {
        protocol.Connection.DownloadError(TransferError.NotAvailable,
            "File not available");
    }
    else
    {
        protocol.Connection.DownloadError(TransferError.Unknown,
            message);
    }
}
#endregion
/// <summary>
/// Throws InvalidOperationException unless the handler is in assumedState.
/// UserProtocol.Invoke turns that exception into an InvalidCommand error.
/// </summary>
protected void CheckState(States assumedState)
{
    if (state == assumedState)
    {
        return;
    }
    throw new InvalidOperationException();
}
/// <summary>
/// Reinterprets text decoded with the connection encoding as UTF-8. The
/// text is encoded back to bytes with the connection encoding and decoded
/// again as UTF-8. Nothing changes when the connection encoding is already
/// UTF-8 (code page 65001).
/// </summary>
/// <remarks>NOTE(review): presumably only lossless when the connection
/// encoding maps every byte to its own char — confirm.</remarks>
protected void EnsureUTF8(ref string argument)
{
    const int Utf8CodePage = 65001;
    if (protocol.Encoding.CodePage == Utf8CodePage)
    {
        return;
    }
    byte[] raw = protocol.Encoding.GetBytes(argument);
    argument = System.Text.Encoding.UTF8.GetString(raw);
}
/// <summary>
/// Builds the UploadFileInfo for a request on the file at path. The request
/// starts at startingByte and covers count bytes; -1 means "to the end of
/// the file".
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">startingByte lies outside
/// the file, or count is negative but not -1.</exception>
/// <remarks>A null path or a missing file surfaces as the exception thrown
/// by FileInfo. Every caller in this class catches these and answers with an
/// error reply.</remarks>
protected UploadFileInfo CreateUploadFileInfo(string path,
    long startingByte, long count)
{
    FileInfo file = new FileInfo(path);
    long length = file.Length;

    // Reject ranges that would give the upload a negative position or byte
    // count instead of passing them on.
    if (startingByte < 0 || startingByte > length)
    {
        throw new ArgumentOutOfRangeException("startingByte");
    }
    if (count == -1)
    {
        count = length - startingByte;
    }
    else if (count < 0)
    {
        throw new ArgumentOutOfRangeException("count");
    }

    UploadFileInfo upload = new UploadFileInfo(
        protocol.Connection.RemoteIdentity, file.FullName, length,
        null);
    upload.Position = startingByte;
    upload.RequestedBytes = count;
    upload.Size = length;
    upload.IsFileList = IsFileList(path);
    return upload;
}
/// <summary>
/// Maps a virtual share path requested by the remote client to a real
/// path. "files.xml.bz2" and "MyList.DcLst" map to the file lists, which are
/// saved first. Every other path is resolved by the share manager.
/// </summary>
/// <param name="virtualPath">Path as requested by the remote client.</param>
/// <param name="replaceSeparators">Convert '\' and the alternative separator
/// to the platform directory separator first.</param>
/// <returns>The real path, or null when the path is null or refers to a
/// parent directory ("..").</returns>
protected string GetRealPath(string virtualPath, bool replaceSeparators)
{
    // TODO: Implement proper security check
    if (virtualPath == null || ContainsParentReference(virtualPath))
    {
        return null;
    }

    if (replaceSeparators)
    {
        // Different path separators is a mess.
        virtualPath = virtualPath.Replace('\\', Path.DirectorySeparatorChar);
        virtualPath = virtualPath.Replace(Path.AltDirectorySeparatorChar,
            Path.DirectorySeparatorChar);
    }

    string path;
    if (virtualPath == "files.xml.bz2")
    {
        ShareManager.SaveFileList();
        path = ShareManager.CompressedFileListPath;
    }
    else if (virtualPath == "MyList.DcLst")
    {
        ShareManager.SaveFileList();
        path = ShareManager.LegacyFileListPath;
    }
    else
    {
        path = ShareManager.GetFullPath(virtualPath);
    }
    return path;
}

// True when path contains a ".." parent reference with either separator.
// Checking only "..\" would let "../" (the separator ADCGET uses) and a
// trailing ".." through.
private static bool ContainsParentReference(string path)
{
    if (path.IndexOf("..\\", StringComparison.Ordinal) >= 0 ||
        path.IndexOf("../", StringComparison.Ordinal) >= 0)
    {
        return true;
    }
    foreach (string segment in path.Split('/', '\\'))
    {
        if (segment == "..")
        {
            return true;
        }
    }
    return false;
}
/// <summary>
/// True when path names a file list, either by its virtual name or by the
/// real path the share manager stores it at.
/// </summary>
protected bool IsFileList(string path)
{
    return path == "files.xml.bz2" ||
        path == "MyList.DcLst" ||
        path == ShareManager.CompressedFileListPath ||
        path == ShareManager.LegacyFileListPath;
}
}
#endregion
#region Properties
/// <summary>The connection this protocol runs on.</summary>
public UserConnection Connection
{
    get
    {
        return this.connection;
    }
    set
    {
        this.connection = value;
    }
}
private UserConnection connection;

/// <summary>Helper passed to the constructor; never null.</summary>
public IUserProtocolHelper Helper
{
    get
    {
        return this.helper;
    }
}
private IUserProtocolHelper helper;
#endregion
/// <summary>
/// Initializes the base protocol. For a connection that is not incoming,
/// it also opens the handshake with $MyNick and $Lock. Incoming connections
/// send these after the remote $MyNick arrives.
/// </summary>
public override void Initialize()
{
    base.Initialize();
    if (connection.Incoming)
    {
        return;
    }
    MyNick();
    Lock();
}
/// <summary>
/// Dispatches a received command to its handler. Any exception from a
/// handler, including a state mismatch detected by CheckState, is answered
/// with an InvalidCommand error. After the handshake has completed, the
/// state is then reset to wait for the next request or reply in the
/// negotiated direction.
/// </summary>
protected override void Invoke(string command, string argument)
{
    try
    {
        base.Invoke(command, argument);
    }
    catch (Exception e)
    {
        Debug.WriteLine(e);
        Error(Errors.InvalidCommand);

        // Only the transfer states can be reached after $Key has been
        // verified. Resetting into them during the handshake would let a
        // peer skip $Lock/$Key verification by sending an out-of-order command.
        bool handshakeDone = incoming.state == States.Get ||
            incoming.state == States.Send ||
            incoming.state == States.FileLength;
        if (handshakeDone)
        {
            // TODO: Hold the previous download
            incoming.state = direction == TransferDirection.Down
                ? States.FileLength
                : States.Get;
        }
    }
}
/// <summary>True when every flag in feature was announced by the remote client.</summary>
protected bool Supports(Features feature)
{
    return (features & feature) == feature;
}
#region Handshake
/// <summary>Sends $MyNick with the local identity's nick.</summary>
protected void MyNick()
{
    connection.Send(String.Concat("$MyNick ", connection.LocalIdentity.Nick, "|"));
}
/// <summary>
/// Sends our $Lock and remembers the key the remote side must answer with.
/// That key is checked when the remote $Key arrives.
/// </summary>
protected void Lock()
{
    const string lockString = "EXTENDEDPROTOCOLABCABCABCABCABCABC Pk=DCPLUSPLUS0.674ABCABC";
    incoming.keyStr = CalculateKey(lockString, Encoding);
    connection.Send(String.Concat("$Lock ", lockString, "|"));
}
/// <summary>Announces the features this client supports with $Supports.</summary>
protected void Supports()
{
    // A lock starting with EXTENDEDPROTOCOL obliges us to support at least
    // XmlBZList and UGetBlock. GetZBlock and ZLIG are not announced yet.
    string[] advertised = { "MiniSlots", "XmlBZList", "ADCGet", "TTHL", "TTHF" };

    // Every feature, including the last, is followed by a space.
    string supports = String.Join(" ", advertised) + " ";
    connection.Send("$Supports " + supports + "|");
}
/// <summary>
/// Sends $Direction with the locally wanted direction and the tie-break
/// number.
/// </summary>
protected void Direction(int number)
{
    string wanted;
    if (direction == TransferDirection.Down)
    {
        wanted = "Download";
    }
    else
    {
        wanted = "Upload";
    }
    connection.Send(String.Format("$Direction {0} {1}|", wanted, number));
}
/// <summary>
/// Sends $Key with the key computed from the remote lock.
/// NOTE(review): this uses ConnectionSend rather than Send — presumably to
/// bypass the usual text encoding; confirm.
/// </summary>
protected void Key(string argument)
{
    string message = "$Key " + argument + "|";
    connection.ConnectionSend(message);
}
#endregion
#region Download
/// <summary>
/// Requests a file, or the remote file list, using the best command the
/// remote client supports. $ADCGET is preferred; $UGetBlock is used when
/// XmlBZList is supported; plain $Get is the fallback.
/// </summary>
/// <param name="download">The download; its position and size give the
/// requested range.</param>
/// <param name="source">Location of the file on the remote client.</param>
/// <exception cref="ArgumentNullException">download or source is null.</exception>
public void DownloadFile(DownloadFileInfo download, SourceInfo source)
{
    if (download == null)
    {
        throw new ArgumentNullException("download");
    }
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }

    this.download = download;
    this.source = source;

    string path = source.Path;
    if (download.IsFileList)
    {
        // Ask for the XML list when the peer supports it, else the DcLst one.
        bool xmlList = Supports(Features.XmlBZList);
        path = xmlList ? "files.xml.bz2" : "MyList.DcLst";
        // TODO: Change the name?
        download.Name = connection.RemoteIdentity.Cid + (xmlList ? ".xml.bz2" : ".DcLst");
    }

    // Request what is left of the file. When nothing positive remains, -1
    // asks for everything from the position (the upload side treats -1 as
    // "to the end").
    long remaining = download.Size - download.Position;
    long numBytes = remaining > 0 ? remaining : -1;

    if (!Supports(Features.ADCGet))
    {
        if (Supports(Features.XmlBZList))
        {
            UGetBlock(path, download.Position, numBytes);
        }
        else
        {
            Get(path, download.Position);
        }
        return;
    }

    if (download.TTH != null && Supports(Features.TTHF))
    {
        path = "TTH/" + download.TTH;
    }
    else if (!download.IsFileList)
    {
        // All files are relative to the root '/'.
        // NOTE(review): the path is escaped here, before ADCGET swaps the
        // platform separator for '/'. Where the separator is '\' (Windows),
        // that swap would also rewrite the escape sequences — verify.
        path = "/" + EscapeADCParameter(path);
    }
    ADCGET(path, download.Position, numBytes);
}
/// <summary>
/// Sends $Get for path. The 0-based startingByte is converted to the
/// 1-based offset $Get uses.
/// </summary>
protected void Get(string path, long startingByte)
{
    // Most clients expect '\' as the directory separator.
    string remotePath = path.Replace(Path.DirectorySeparatorChar, '\\');
    long offset = startingByte + 1;
    connection.Send("$Get " + remotePath + "$" + offset + "|");
}
/// <summary>Sends $UGetBlock "start count path", encoded as UTF-8.</summary>
protected void UGetBlock(string path, long startingByte, long numBytes)
{
    // Most clients expect '\' as the directory separator.
    string remotePath = path.Replace(Path.DirectorySeparatorChar, '\\');
    connection.Send(
        String.Format("$UGetBlock {0} {1} {2}|", startingByte, numBytes, remotePath),
        System.Text.Encoding.UTF8);
}
/// <summary>Sends $ADCGET "file identifier start count", encoded as UTF-8.</summary>
protected void ADCGET(string identifier, long startingByte, long numBytes)
{
    // ADC uses '/' as the directory separator.
    string adcIdentifier = identifier.Replace(Path.DirectorySeparatorChar, '/');
    connection.Send(
        String.Format("$ADCGET file {0} {1} {2}|", adcIdentifier, startingByte, numBytes),
        System.Text.Encoding.UTF8);
}
#endregion
/// <summary>Tells the remote client its request was refused for lack of slots.</summary>
protected void MaxedOut()
{
    const string command = "$MaxedOut|";
    connection.Send(command);
}
/// <summary>Text sent to the remote client for error; null for unknown values.</summary>
protected string GetErrorDescription(Errors error)
{
    string description = null;
    switch (error)
    {
        case Errors.FileNotAvailable:
            description = "File Not Available";
            break;
        case Errors.InvalidCommand:
            description = "Invalid Command";
            break;
        case Errors.InvalidKey:
            description = "Invalid Key";
            break;
        case Errors.NoSlots:
            description = "No slots available";
            break;
    }
    return description;
}
/// <summary>Sends $Error with the standard description of error.</summary>
protected void Error(Errors error)
{
    Error(GetErrorDescription(error));
}

/// <summary>
/// Sends "$Error argument|". If argument already ends with the '|'
/// terminator, no second '|' is added, so no empty command follows.
/// </summary>
protected void Error(string argument)
{
    string message = "$Error " + argument;
    if (!message.EndsWith("|", StringComparison.Ordinal))
    {
        message += "|";
    }
    connection.Send(message);
}
/// <summary>Sends $Failed with the standard description of error.</summary>
protected void Failed(Errors error)
{
    Failed(GetErrorDescription(error));
}

/// <summary>
/// Sends "$Failed argument|". The '|' terminator is added unless argument
/// already ends with one.
/// </summary>
protected void Failed(string argument)
{
    string message = "$Failed " + argument;
    // Every NMDC command must end with '|'. Without it the peer would merge
    // this reply with whatever is sent next.
    if (!message.EndsWith("|", StringComparison.Ordinal))
    {
        message += "|";
    }
    connection.Send(message);
}
/// <summary>
/// Escapes a parameter for an ADC-style command. '\' becomes "\\", a space
/// becomes "\s" and a newline becomes "\n".
/// </summary>
/// <remarks>The backslash is escaped first, so the backslashes introduced
/// by the other escapes are not doubled.</remarks>
protected static string EscapeADCParameter(string param)
{
    param = param.Replace("\\", "\\\\");
    param = param.Replace(" ", "\\s");
    param = param.Replace("\n", "\\n");
    // The old-spec "\ " form is never produced, because every space has
    // already become "\s". UnescapeADCParameter still accepts "\ ".
    return param;
}
/// <summary>
/// Reverses ADC parameter escaping: "\\" becomes '\', "\s" and the old-spec
/// "\ " become a space, and "\n" becomes a newline. Unknown escapes and a
/// trailing lone backslash are kept as they are.
/// </summary>
/// <remarks>
/// Decoding is a single left-to-right pass. Sequential Replace calls would
/// turn an escaped backslash followed by 's' ("\\s") into a space instead
/// of "\s".
/// </remarks>
/// <returns>The unescaped text; null when param is null.</returns>
protected static string UnescapeADCParameter(string param)
{
    if (param == null || param.IndexOf('\\') < 0)
    {
        return param;
    }

    System.Text.StringBuilder result = new System.Text.StringBuilder(param.Length);
    for (int i = 0; i < param.Length; i++)
    {
        char c = param[i];
        if (c == '\\' && i + 1 < param.Length)
        {
            char next = param[i + 1];
            if (next == '\\')
            {
                result.Append('\\');
                i++;
                continue;
            }
            if (next == 's' || next == ' ')
            {
                // "\ " is the form used in the old spec.
                result.Append(' ');
                i++;
                continue;
            }
            if (next == 'n')
            {
                result.Append('\n');
                i++;
                continue;
            }
        }
        result.Append(c);
    }
    return result.ToString();
}
#endregion
}
}
|