Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 202 additions & 7 deletions src/Billing/Jobs/ReconcileAdditionalStorageJob.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,37 @@
using Bit.Billing.Services;
using Bit.Core;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Pricing;
using Bit.Core.Jobs;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Quartz;
using Stripe;
using OrganizationPlan = Bit.Core.Models.StaticStore.Plan;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;

namespace Bit.Billing.Jobs;

public class ReconcileAdditionalStorageJob(
IStripeFacade stripeFacade,
ILogger<ReconcileAdditionalStorageJob> logger,
IFeatureService featureService) : BaseJob(logger)
IFeatureService featureService,
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IPricingClient pricingClient) : BaseJob(logger)
{
private const string _storageGbMonthlyPriceId = "storage-gb-monthly";
private const string _storageGbAnnuallyPriceId = "storage-gb-annually";
private const string _personalStorageGbAnnuallyPriceId = "personal-storage-gb-annually";
private const int _storageGbToRemove = 4;
private const short _includedStorageGb = 5;

public enum SubscriptionPlanTier
{
Personal,
Organization,
Unknown
}

protected override async Task ExecuteJobAsync(IJobExecutionContext context)
{
Expand All @@ -34,10 +49,30 @@ protected override async Task ExecuteJobAsync(IJobExecutionContext context)
var subscriptionsFound = 0;
var subscriptionsUpdated = 0;
var subscriptionsWithErrors = 0;
var databaseUpdatesFailed = 0;
var failures = new List<string>();

logger.LogInformation("Starting ReconcileAdditionalStorageJob (live mode: {LiveMode})", liveMode);

// Load plans for subscription type determination
List<PremiumPlan> personalPremiumPlans;
List<OrganizationPlan> organizationPlans;
try
{
personalPremiumPlans = await pricingClient.ListPremiumPlans();
organizationPlans = await pricingClient.ListPlans();

logger.LogInformation(
"Loaded {PremiumCount} personal/premium plans and {OrgCount} organization plans from pricing client",
personalPremiumPlans.Count,
organizationPlans.Count);
}
catch (Exception ex)
{
logger.LogError(ex, "Failed to load pricing plans from pricing client. Cannot proceed with job execution.");
return;
}

var priceIds = new[] { _storageGbMonthlyPriceId, _storageGbAnnuallyPriceId, _personalStorageGbAnnuallyPriceId };
var stripeStatusesToProcess = new[] { StripeConstants.SubscriptionStatus.Active, StripeConstants.SubscriptionStatus.Trialing, StripeConstants.SubscriptionStatus.PastDue };

Expand All @@ -51,11 +86,13 @@ protected override async Task ExecuteJobAsync(IJobExecutionContext context)
{
logger.LogWarning(
"Job cancelled!! Exiting. Progress at time of cancellation: Subscriptions found: {SubscriptionsFound}, " +
"Updated: {SubscriptionsUpdated}, Errors: {SubscriptionsWithErrors}{Failures}",
"Stripe updates: {StripeUpdates}, Database updates: {DatabaseFailed} failed, " +
"Errors: {SubscriptionsWithErrors}{Failures}",
subscriptionsFound,
liveMode
? subscriptionsUpdated
: $"(In live mode, would have updated) {subscriptionsUpdated}",
databaseUpdatesFailed,
subscriptionsWithErrors,
failures.Count > 0
? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}"
Expand Down Expand Up @@ -99,20 +136,48 @@ protected override async Task ExecuteJobAsync(IJobExecutionContext context)

subscriptionsUpdated++;

// Now, prepare the database update so we can log details out if not in live mode
var subscriptionPlanTier = DetermineSubscriptionPlanTier(subscription, personalPremiumPlans, organizationPlans);
// Calculate new MaxStorageGb
var currentStorageQuantity = GetCurrentStorageQuantityFromSubscription(subscription, priceId);
var newMaxStorageGb = CalculateNewMaxStorageGb(currentStorageQuantity, updateOptions);

if (!liveMode)
{
logger.LogInformation(
"Not live mode (dry-run): Would have updated subscription {SubscriptionId} with item changes: {NewLine}{UpdateOptions}",
"Not live mode (dry-run): Would have updated subscription {SubscriptionId} with item changes: {NewLine}{UpdateOptions}" +
"{NewLine2}And would have updated database record tier: {Tier} to new MaxStorageGb: {MaxStorageGb}",
subscription.Id,
Environment.NewLine,
JsonSerializer.Serialize(updateOptions));
JsonSerializer.Serialize(updateOptions),
Environment.NewLine,
subscriptionPlanTier,
newMaxStorageGb);
continue;
}

// Live mode enabled - continue with updates to stripe and database
try
{
await stripeFacade.UpdateSubscription(subscription.Id, updateOptions);
logger.LogInformation("Successfully updated subscription: {SubscriptionId}", subscription.Id);
logger.LogInformation("Successfully updated Stripe subscription: {SubscriptionId}", subscription.Id);

logger.LogInformation(
"Updating MaxStorageGb in database for subscription {SubscriptionId} ({Type}): New MaxStorageGb: {MaxStorage}",
subscription.Id,
subscriptionPlanTier,
newMaxStorageGb);

var dbUpdateSuccess = await UpdateDatabaseMaxStorageAsync(
subscriptionPlanTier,
subscription.Id,
newMaxStorageGb);

if (!dbUpdateSuccess)
{
databaseUpdatesFailed++;
failures.Add($"Subscription {subscription.Id}: Database update failed");
}
}
catch (Exception ex)
{
Expand All @@ -125,12 +190,14 @@ protected override async Task ExecuteJobAsync(IJobExecutionContext context)
}

logger.LogInformation(
"ReconcileAdditionalStorageJob completed. Subscriptions found: {SubscriptionsFound}, " +
"Updated: {SubscriptionsUpdated}, Errors: {SubscriptionsWithErrors}{Failures}",
"ReconcileAdditionalStorageJob FINISHED. Subscriptions found: {SubscriptionsFound}, " +
"Stripe updates: {StripeUpdates}, Database updates: {DatabaseFailed} failed, " +
"Errors: {SubscriptionsWithErrors}{Failures}",
subscriptionsFound,
liveMode
? subscriptionsUpdated
: $"(In live mode, would have updated) {subscriptionsUpdated}",
databaseUpdatesFailed,
subscriptionsWithErrors,
failures.Count > 0
? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}"
Expand Down Expand Up @@ -182,6 +249,134 @@ protected override async Task ExecuteJobAsync(IJobExecutionContext context)
return hasUpdates ? updateOptions : null;
}

public SubscriptionPlanTier DetermineSubscriptionPlanTier(
Subscription subscription,
List<PremiumPlan> personalPremiumPlans,
List<OrganizationPlan> organizationPlans)
{
if (subscription.Items?.Data == null)
{
return SubscriptionPlanTier.Unknown;
}

foreach (var item in subscription.Items.Data)
{
if (item?.Price?.Id == null)
{
continue;
}

// eagerly match the first id found to determine if personal or org
if (personalPremiumPlans.Any(p => p.Seat.StripePriceId == item.Price.Id)) return SubscriptionPlanTier.Personal;

if (organizationPlans.Any(p =>
p.PasswordManager.StripeSeatPlanId == item.Price.Id ||
p.PasswordManager.StripePlanId == item.Price.Id))
return SubscriptionPlanTier.Organization;
}

return SubscriptionPlanTier.Unknown;
}

public long GetCurrentStorageQuantityFromSubscription(
Subscription subscription,
string storagePriceId)
{
return subscription.Items?.Data?.FirstOrDefault(item => item?.Price?.Id == storagePriceId)?.Quantity ?? 0;
}

public short CalculateNewMaxStorageGb(
long currentQuantity,
SubscriptionUpdateOptions? updateOptions)
{
if (updateOptions?.Items == null)
{
return (short)currentQuantity;
}

// If the update marks item as deleted, new quantity is whatever the base storage gb
if (updateOptions.Items.Any(i => i.Deleted == true))
{
return _includedStorageGb;
}

// If the update has a new quantity, use it to calculate the new max
var updatedItem = updateOptions.Items.FirstOrDefault(i => i.Quantity.HasValue);
if (updatedItem?.Quantity != null)
{
return (short)(_includedStorageGb + updatedItem.Quantity.Value);
}

// Otherwise, no change
return (short)currentQuantity;
}

public async Task<bool> UpdateDatabaseMaxStorageAsync(
SubscriptionPlanTier subscriptionPlanTier,
string subscriptionId,
short newMaxStorageGb)
{
try
{
switch (subscriptionPlanTier)
{
case SubscriptionPlanTier.Personal:
{
var user = await userRepository.GetByGatewaySubscriptionIdAsync(subscriptionId);
if (user == null)
{
logger.LogError(
"User not found for subscription {SubscriptionId}. Database not updated.",
subscriptionId);
return false;
}

user.MaxStorageGb = newMaxStorageGb;
await userRepository.ReplaceAsync(user);

logger.LogInformation(
"Successfully updated User {UserId} MaxStorageGb to {MaxStorageGb} for subscription {SubscriptionId}",
user.Id,
newMaxStorageGb,
subscriptionId);
return true;
}
case SubscriptionPlanTier.Organization:
{
var organization = await organizationRepository.GetByGatewaySubscriptionIdAsync(subscriptionId);
if (organization == null)
{
logger.LogWarning(
"Organization not found for subscription {SubscriptionId}. Database not updated.",
subscriptionId);
return false;
}

organization.MaxStorageGb = newMaxStorageGb;
await organizationRepository.ReplaceAsync(organization);

logger.LogInformation(
"Successfully updated Organization {OrganizationId} MaxStorageGb to {MaxStorageGb} for subscription {SubscriptionId}",
organization.Id,
newMaxStorageGb,
subscriptionId);
return true;
}
case SubscriptionPlanTier.Unknown:
default:
return false;
}
}
catch (Exception ex)
{
logger.LogError(ex,
"Failed to update database MaxStorageGb for subscription {SubscriptionId} (Plan Tier: {SubscriptionType})",
subscriptionId,
subscriptionPlanTier);
return false;
}
}

public static ITrigger GetTrigger()
{
return TriggerBuilder.Create()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public interface IOrganizationRepository : IRepository<Organization, Guid>
Task<ICollection<Organization>> GetManyByUserIdAsync(Guid userId);
Task<ICollection<Organization>> SearchAsync(string name, string userEmail, bool? paid, int skip, int take);
Task UpdateStorageAsync(Guid id);
Task<Organization?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId);
Task<ICollection<OrganizationAbility>> GetManyAbilitiesAsync();
Task<Organization?> GetByLicenseKeyAsync(string licenseKey);
Task<SelfHostedOrganizationDetails?> GetSelfHostedOrganizationDetailsById(Guid id);
Expand Down
1 change: 1 addition & 0 deletions src/Core/Repositories/IUserRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public interface IUserRepository : IRepository<User, Guid>
Task<DateTime> GetAccountRevisionDateAsync(Guid id);
Task UpdateStorageAsync(Guid id);
Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate);
Task<User?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId);
Task<IEnumerable<User>> GetManyAsync(IEnumerable<Guid> ids);
/// <summary>
/// Retrieves the data for the requested user IDs and includes an additional property indicating
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ await connection.ExecuteAsync(
}
}

public async Task<Organization?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Organization>(
"[dbo].[Organization_ReadByGatewaySubscriptionId]",
new { GatewaySubscriptionId = gatewaySubscriptionId },
commandType: CommandType.StoredProcedure);

return results.SingleOrDefault();
}
}

public async Task<ICollection<OrganizationAbility>> GetManyAbilitiesAsync()
{
using (var connection = new SqlConnection(ConnectionString))
Expand Down
14 changes: 14 additions & 0 deletions src/Infrastructure.Dapper/Repositories/UserRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@ public UserRepository(
}
}

public async Task<User?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<User>(
$"[{Schema}].[{Table}_ReadByGatewaySubscriptionId]",
new { GatewaySubscriptionId = gatewaySubscriptionId },
commandType: CommandType.StoredProcedure);

UnprotectData(results);
return results.SingleOrDefault();
}
}

public async Task<IEnumerable<User>> GetManyByEmailsAsync(IEnumerable<string> emails)
{
var emailTable = new DataTable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ public OrganizationRepository(
}
}

public async Task<Core.AdminConsole.Entities.Organization> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var organization = await GetDbSet(dbContext).Where(e => e.GatewaySubscriptionId == gatewaySubscriptionId)
.FirstOrDefaultAsync();
return organization;
}
}

public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetManyByEnabledAsync()
{
using (var scope = ServiceScopeFactory.CreateScope())
Expand Down
10 changes: 10 additions & 0 deletions src/Infrastructure.EntityFramework/Repositories/UserRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ public UserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper)
}
}

public async Task<Core.Entities.User?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var entity = await GetDbSet(dbContext).FirstOrDefaultAsync(e => e.GatewaySubscriptionId == gatewaySubscriptionId);
return Mapper.Map<Core.Entities.User>(entity);
}
}

public async Task<IEnumerable<Core.Entities.User>> GetManyByEmailsAsync(IEnumerable<string> emails)
{
using (var scope = ServiceScopeFactory.CreateScope())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
CREATE PROCEDURE [dbo].[Organization_ReadByGatewaySubscriptionId]
@GatewaySubscriptionId NVARCHAR(50)
AS
BEGIN
SET NOCOUNT ON;

SELECT
*
FROM
[dbo].[OrganizationView]
WHERE
[GatewaySubscriptionId] = @GatewaySubscriptionId
END
GO
Loading
Loading