Skip to content

Commit

Permalink
Merge pull request #11 from koralium/10_ssl_support
Browse files Browse the repository at this point in the history
Added support for ssl traffic
  • Loading branch information
Ulimo authored Feb 28, 2021
2 parents 4f2a031 + d14b2e4 commit b745304
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 21 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ The connection string takes the following parameters:
| ExtraCredentials | Extra credentials to send. | ExtraCredentials=key1:value1,key2:value2; |
| Trino | Use trino headers (required for trino) | Trino=true; |
| Password | Password for the user | Password=test; |
| Ssl | Https or http protocol | Ssl=true; |

# SSL Traffic

If the SSL connection string option is left out, the ADO.Net provider tries to figure out the protocol by itself.
It first tries https but if that fails it tests http. This is saved as long as the application is running.
But for better first time performance if one is not using https is to set ssl=false in the connection string.

# Nuget Package

Expand Down
125 changes: 108 additions & 17 deletions src/Data.Presto/Client/PrestoClient.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Data.Presto.Models;
using Data.Presto.Utils;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Http.Headers;
Expand All @@ -15,10 +16,28 @@ class PrestoClient
private readonly HttpClient _httpClient;
private readonly PrestoConnectionStringBuilder _connectionString;

private bool? _useSsl = null;

//Dictionary used to store which http protocol to use for connections where it is not marked explicitly.
private static readonly ConcurrentDictionary<string, bool> _protocolLookup = new ConcurrentDictionary<string, bool>();

public PrestoClient(PrestoConnectionStringBuilder prestoConnectionString)
{
_connectionString = prestoConnectionString;
_httpClient = new HttpClient();

if (_connectionString.Ssl.HasValue)
{
_useSsl = _connectionString.Ssl.Value;
}
else
{
//Check if another presto client has already done connections against the host
if (_protocolLookup.TryGetValue(prestoConnectionString.Host, out var canUseSsl))
{
_useSsl = canUseSsl;
}
}
}

private void AddHeaders(HttpRequestMessage httpRequestMessage)
Expand Down Expand Up @@ -100,18 +119,98 @@ private async Task<DecodeResult> CheckResult(HttpResponseMessage httpResponseMes
return decodeResults;
}

public async Task<DecodeResult> Query(string statement, CancellationToken cancellationToken)
private string GetProtocol()
{
using var httpRequestMessage = new HttpRequestMessage()
if (!_useSsl.Value)
{
Method = HttpMethod.Post,
RequestUri = new Uri($"http://{_connectionString.Host}/v1/statement"),
Content = new StringContent(statement)
};
return "http://";
}
return "https://";
}

AddHeaders(httpRequestMessage);
private void SetHostSsl(in string host, bool canUseSsl)
{
if (_useSsl == null)
{
_useSsl = canUseSsl;
_protocolLookup.AddOrUpdate(host, canUseSsl, (key, old) => canUseSsl);
}
}

var result = await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
private async Task<HttpResponseMessage> SendMessage(HttpMethod httpMethod, string path, CancellationToken cancellationToken, string content = null)
{
//Protocol has not yet been determined
if (_useSsl == null)
{
try
{
using var httpRequestMessage = new HttpRequestMessage()
{
Method = httpMethod,
RequestUri = new Uri($"https://{_connectionString.Host}{path}"),
};

if (content != null)
{
httpRequestMessage.Content = new StringContent(content);
}

AddHeaders(httpRequestMessage);

var response = await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
SetHostSsl(_connectionString.Host, true);
return response;
}
catch (HttpRequestException requestException)
{
if (requestException?.InnerException?.Source == "System.Net.Security")
{
//Exception regarding security, test http:// instead
using var httpRequestMessage = new HttpRequestMessage()
{
Method = httpMethod,
RequestUri = new Uri($"http://{_connectionString.Host}{path}"),
};

if (content != null)
{
httpRequestMessage.Content = new StringContent(content);
}

AddHeaders(httpRequestMessage);

var response = await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
SetHostSsl(_connectionString.Host, false);
return response;
}
else
{
throw;
}
}
}
else
{
using var httpRequestMessage = new HttpRequestMessage()
{
Method = httpMethod,
RequestUri = new Uri($"{GetProtocol()}{_connectionString.Host}{path}"),
};

if (content != null)
{
httpRequestMessage.Content = new StringContent(content);
}

AddHeaders(httpRequestMessage);

return await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
}
}

public async Task<DecodeResult> Query(string statement, CancellationToken cancellationToken)
{
var result = await SendMessage(HttpMethod.Post, "/v1/statement", cancellationToken, statement);

switch (result.StatusCode)
{
Expand All @@ -126,15 +225,7 @@ public async Task<DecodeResult> Query(string statement, CancellationToken cancel

public async Task KillQuery(string queryId, CancellationToken cancellationToken)
{
using var httpRequestMessage = new HttpRequestMessage()
{
Method = HttpMethod.Delete,
RequestUri = new Uri($"http://{_connectionString.Host}/v1/query/{queryId}")
};

AddHeaders(httpRequestMessage);

var result = await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
await SendMessage(HttpMethod.Delete, $"/v1/query/{queryId}", cancellationToken).ConfigureAwait(false);
}
}
}
27 changes: 23 additions & 4 deletions src/Data.Presto/PrestoConnectionStringBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class PrestoConnectionStringBuilder : DbConnectionStringBuilder
private const string StreamingKeyword = "Streaming";
private const string TrinoKeyword = "Trino";
private const string PasswordKeyword = "Password";
private const string SslKeyword = "Ssl";

private static readonly IReadOnlyList<string> _validKeywords;
private static readonly IReadOnlyDictionary<string, Keywords> _keywords;
Expand All @@ -35,6 +36,7 @@ public class PrestoConnectionStringBuilder : DbConnectionStringBuilder
private bool _streaming = true;
private bool _trino = false;
private string _password = string.Empty;
private bool? _ssl = null;
private ImmutableList<KeyValuePair<string, string>> _extraCredentials = ImmutableList.Create<KeyValuePair<string, string>>();

private enum Keywords
Expand All @@ -46,12 +48,13 @@ private enum Keywords
ExtraCredentials,
Streaming,
Trino,
Password
Password,
Ssl,
}

static PrestoConnectionStringBuilder()
{
var validKeywords = new string[8];
var validKeywords = new string[9];
validKeywords[(int)Keywords.DataSource] = DataSourceKeyword;
validKeywords[(int)Keywords.User] = UserKeyword;
validKeywords[(int)Keywords.Catalog] = CatalogKeyword;
Expand All @@ -60,9 +63,10 @@ static PrestoConnectionStringBuilder()
validKeywords[(int)Keywords.Streaming] = StreamingKeyword;
validKeywords[(int)Keywords.Trino] = TrinoKeyword;
validKeywords[(int)Keywords.Password] = PasswordKeyword;
validKeywords[(int)Keywords.Ssl] = SslKeyword;
_validKeywords = validKeywords;

_keywords = new Dictionary<string, Keywords>(9, StringComparer.OrdinalIgnoreCase)
_keywords = new Dictionary<string, Keywords>(10, StringComparer.OrdinalIgnoreCase)
{
[DataSourceKeyword] = Keywords.DataSource,
[DataSourceNoSpaceKeyword] = Keywords.DataSource,
Expand All @@ -72,7 +76,8 @@ static PrestoConnectionStringBuilder()
[ExtraCredentialKeyword] = Keywords.ExtraCredentials,
[StreamingKeyword] = Keywords.Streaming,
[TrinoKeyword] = Keywords.Trino,
[PasswordKeyword] = Keywords.Password
[PasswordKeyword] = Keywords.Password,
[SslKeyword] = Keywords.Ssl
};
}

Expand Down Expand Up @@ -129,6 +134,12 @@ public virtual bool Trino
set => base[TrinoKeyword] = _trino = value;
}

public virtual bool? Ssl
{
get => _ssl;
set => base[SslKeyword] = _ssl = value;
}

public virtual string Password
{
get => _password;
Expand Down Expand Up @@ -195,6 +206,9 @@ public override object this[string keyword]
case Keywords.Password:
Password = Convert.ToString(value, CultureInfo.InvariantCulture);
return;
case Keywords.Ssl:
Ssl = Convert.ToBoolean(value, CultureInfo.InvariantCulture);
return;
default:
Debug.Assert(false, "Unexpected keyword: " + keyword);
return;
Expand Down Expand Up @@ -290,6 +304,8 @@ private object GetAt(Keywords index)
return Trino;
case Keywords.Password:
return Password;
case Keywords.Ssl:
return Ssl;
default:
Debug.Assert(false, "Unexpected keyword: " + index);
return null;
Expand Down Expand Up @@ -329,6 +345,9 @@ private void Reset(Keywords index)
case Keywords.Password:
_password = string.Empty;
return;
case Keywords.Ssl:
_ssl = null;
return;
default:
Debug.Assert(false, "Unexpected keyword: " + index);
return;
Expand Down

0 comments on commit b745304

Please sign in to comment.