Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit f129d66

Browse files
authored
feat/enterpriseportal: implement SubscriptionLicenseChecksService (#64400)
Implements the RPC defined in https://github.com/sourcegraph/sourcegraph/pull/64396. Follow-up PRs will implement migrations for in-instance checks and for through-dotcom checks. One major change is that we now bypass the check for subscriptions that are denoted as associated with `INTERNAL` instances. Most of the diff is generated mocks. Part of https://linear.app/sourcegraph/issue/CORE-227 ## Test plan - [x] Unit tests - [x] E2E tests (`sg test enterprise-portal-e2e`) ![image](https://github.com/user-attachments/assets/56fde7dd-95a0-4d98-bb4c-943b1f155e33)
1 parent e4ee9b9 commit f129d66

File tree

18 files changed

+2036
-388
lines changed

18 files changed

+2036
-388
lines changed

client/web/src/enterprise/site-admin/dotcom/productSubscriptions/enterpriseportalgen/subscriptions_pb.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,12 +582,20 @@ export enum EnterpriseSubscriptionLicenseCondition_Status {
582582
* @generated from enum value: STATUS_REVOKED = 2;
583583
*/
584584
REVOKED = 2,
585+
586+
/**
587+
* License usage from a Sourcegraph instance was detected.
588+
*
589+
* @generated from enum value: STATUS_INSTANCE_USAGE_DETECTED = 3;
590+
*/
591+
INSTANCE_USAGE_DETECTED = 3,
585592
}
586593
// Retrieve enum metadata with: proto3.getEnumType(EnterpriseSubscriptionLicenseCondition_Status)
587594
proto3.util.setEnumType(EnterpriseSubscriptionLicenseCondition_Status, "enterpriseportal.subscriptions.v1.EnterpriseSubscriptionLicenseCondition.Status", [
588595
{ no: 0, name: "STATUS_UNSPECIFIED" },
589596
{ no: 1, name: "STATUS_CREATED" },
590597
{ no: 2, name: "STATUS_REVOKED" },
598+
{ no: 3, name: "STATUS_INSTANCE_USAGE_DETECTED" },
591599
]);
592600

593601
/**

cmd/enterprise-portal/e2e/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ go_test(
66
deps = [
77
"//internal/grpc/defaults",
88
"//internal/grpc/grpcoauth",
9+
"//internal/license",
910
"//lib/enterpriseportal/codyaccess/v1:codyaccess",
11+
"//lib/enterpriseportal/subscriptionlicensechecks/v1:subscriptionlicensechecks",
1012
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
13+
"@com_github_hexops_autogold_v2//:autogold",
1114
"@com_github_sourcegraph_log//logtest",
1215
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
1316
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",

cmd/enterprise-portal/e2e/e2e_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"google.golang.org/protobuf/types/known/fieldmaskpb"
1515
"google.golang.org/protobuf/types/known/timestamppb"
1616

17+
"github.com/hexops/autogold/v2"
1718
"github.com/sourcegraph/log/logtest"
1819
"github.com/stretchr/testify/assert"
1920
"github.com/stretchr/testify/require"
@@ -22,14 +23,17 @@ import (
2223
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
2324
"github.com/sourcegraph/sourcegraph/internal/grpc/defaults"
2425
"github.com/sourcegraph/sourcegraph/internal/grpc/grpcoauth"
26+
"github.com/sourcegraph/sourcegraph/internal/license"
2527

2628
codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1"
29+
subscriptionlicensechecks "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptionlicensechecks/v1"
2730
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
2831
)
2932

3033
type Clients struct {
3134
Subscriptions subscriptionsv1.SubscriptionsServiceClient
3235
CodyAccess codyaccessv1.CodyAccessServiceClient
36+
LicenseChecks subscriptionlicensechecks.SubscriptionLicenseChecksServiceClient
3337
}
3438

3539
func newE2EClients(t *testing.T) *Clients {
@@ -80,9 +84,15 @@ func newE2EClients(t *testing.T) *Clients {
8084
require.NoError(t, err)
8185
t.Cleanup(func() { _ = client.Close() })
8286

87+
clientWithoutCreds, err := grpc.NewClient("dns:///"+addr.Host,
88+
defaults.DialOptions(logtest.Scoped(t).Scoped("grpc"))...)
89+
require.NoError(t, err)
90+
t.Cleanup(func() { _ = clientWithoutCreds.Close() })
91+
8392
return &Clients{
8493
Subscriptions: subscriptionsv1.NewSubscriptionsServiceClient(client),
8594
CodyAccess: codyaccessv1.NewCodyAccessServiceClient(client),
95+
LicenseChecks: subscriptionlicensechecks.NewSubscriptionLicenseChecksServiceClient(clientWithoutCreds),
8696
}
8797
}
8898

@@ -160,6 +170,7 @@ func runLifecycleTest(t *testing.T, ctx context.Context, clients *Clients, runID
160170
})
161171

162172
var createdLicenseID string
173+
var createdLicenseKey string
163174
t.Run("Create license", func(t *testing.T) {
164175
got, err := clients.Subscriptions.CreateEnterpriseSubscriptionLicense(ctx, &subscriptionsv1.CreateEnterpriseSubscriptionLicenseRequest{
165176
License: &subscriptionsv1.EnterpriseSubscriptionLicense{
@@ -178,6 +189,7 @@ func runLifecycleTest(t *testing.T, ctx context.Context, clients *Clients, runID
178189
})
179190
require.NoError(t, err)
180191
createdLicenseID = got.GetLicense().GetId()
192+
createdLicenseKey = got.GetLicense().GetKey().GetLicenseKey()
181193
prettyPrint(t, got)
182194
})
183195

@@ -224,6 +236,34 @@ func runLifecycleTest(t *testing.T, ctx context.Context, clients *Clients, runID
224236
prettyPrint(t, got)
225237
})
226238

239+
t.Run("Check license", func(t *testing.T) {
240+
got, err := clients.LicenseChecks.CheckLicenseKey(ctx, &subscriptionlicensechecks.CheckLicenseKeyRequest{
241+
InstanceId: "test-instance-id",
242+
LicenseKey: createdLicenseKey,
243+
})
244+
require.NoError(t, err)
245+
assert.True(t, got.GetValid())
246+
247+
t.Run("back-compat with license key token", func(t *testing.T) {
248+
got, err := clients.LicenseChecks.CheckLicenseKey(ctx, &subscriptionlicensechecks.CheckLicenseKeyRequest{
249+
InstanceId: "test-instance-id",
250+
LicenseKey: license.GenerateLicenseKeyBasedAccessToken(createdLicenseKey),
251+
})
252+
require.NoError(t, err)
253+
assert.True(t, got.GetValid())
254+
})
255+
256+
t.Run("with wrong site ID", func(t *testing.T) {
257+
got, err := clients.LicenseChecks.CheckLicenseKey(ctx, &subscriptionlicensechecks.CheckLicenseKeyRequest{
258+
InstanceId: "wrong-instance-id",
259+
LicenseKey: createdLicenseKey,
260+
})
261+
require.NoError(t, err)
262+
assert.False(t, got.GetValid())
263+
autogold.Expect("license has already been used by another instance").Equal(t, got.GetReason())
264+
})
265+
})
266+
227267
t.Run("Revoke license", func(t *testing.T) {
228268
got, err := clients.Subscriptions.RevokeEnterpriseSubscriptionLicense(ctx, &subscriptionsv1.RevokeEnterpriseSubscriptionLicenseRequest{
229269
LicenseId: createdLicenseID,
@@ -242,6 +282,16 @@ func runLifecycleTest(t *testing.T, ctx context.Context, clients *Clients, runID
242282
})
243283
})
244284

285+
t.Run("Check revoked license", func(t *testing.T) {
286+
got, err := clients.LicenseChecks.CheckLicenseKey(ctx, &subscriptionlicensechecks.CheckLicenseKeyRequest{
287+
InstanceId: "test-instance-id",
288+
LicenseKey: createdLicenseKey,
289+
})
290+
require.NoError(t, err)
291+
assert.False(t, got.GetValid())
292+
autogold.Expect("license has been revoked").Equal(t, got.GetReason())
293+
})
294+
245295
t.Run("Archive subscription", func(t *testing.T) {
246296
got, err := clients.Subscriptions.ArchiveEnterpriseSubscription(ctx, &subscriptionsv1.ArchiveEnterpriseSubscriptionRequest{
247297
SubscriptionId: createdSubscriptionID,

cmd/enterprise-portal/internal/database/subscriptions/licenses.go

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ type SubscriptionLicense struct {
7777
// Value shapes correspond to API types appropriate for each
7878
// 'EnterpriseSubscriptionLicenseType'.
7979
LicenseData json.RawMessage `gorm:"type:jsonb"`
80+
81+
// DetectedInstanceID is the identifier of the Sourcegraph instance that has
82+
// been automatically detected via onlince license checks (subscriptionlicensechecks).
83+
// It should only be used internally for reference.
84+
DetectedInstanceID *string
8085
}
8186

8287
// subscriptionLicenseWithConditionsColumns must match scanSubscriptionLicense()
@@ -93,6 +98,8 @@ func subscriptionLicenseWithConditionsColumns() []string {
9398
"license_type",
9499
"license_data",
95100

101+
"detected_instance_id",
102+
96103
subscriptionLicenseConditionJSONBAgg(),
97104
}
98105
}
@@ -113,6 +120,7 @@ func scanSubscriptionLicenseWithConditions(row pgx.Row) (*LicenseWithConditions,
113120
&l.ExpireAt,
114121
&l.LicenseType,
115122
&l.LicenseData,
123+
&l.DetectedInstanceID,
116124
&l.Conditions, // see subscriptionLicenseConditionJSONBAgg docstring
117125
)
118126
return &l, err
@@ -132,10 +140,19 @@ func NewLicensesStore(db *pgxpool.Pool) *LicensesStore {
132140
}
133141

134142
type ListLicensesOpts struct {
135-
SubscriptionID string
136-
LicenseType subscriptionsv1.EnterpriseSubscriptionLicenseType
137-
LicenseKeySubstring string
143+
SubscriptionID string
144+
LicenseType subscriptionsv1.EnterpriseSubscriptionLicenseType
145+
// LicenseKey is an exact match on the signed key.
146+
LicenseKey string
147+
// LicenseKeySubstring is a substring match on the signed key.
148+
LicenseKeySubstring string
149+
138150
SalesforceOpportunityID string
151+
152+
// LicenseKeyHash should be removed once subscriptionlicensechecks no longer
153+
// supports the old key hash format
154+
LicenseKeyHash []byte
155+
139156
// PageSize is the maximum number of licenses to return.
140157
PageSize int
141158
}
@@ -155,6 +172,11 @@ func (opts ListLicensesOpts) toQueryConditions() (where, limitClause string, _ p
155172

156173
switch opts.LicenseType {
157174
case subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY:
175+
if opts.LicenseKey != "" {
176+
whereConds = append(whereConds,
177+
"license_data->>'SignedKey' = @licenseKey")
178+
namedArgs["licenseKey"] = opts.LicenseKey
179+
}
158180
if opts.LicenseKeySubstring != "" {
159181
whereConds = append(whereConds,
160182
"license_data->>'SignedKey' LIKE '%' || @licenseKeySubstring || '%'")
@@ -165,6 +187,11 @@ func (opts ListLicensesOpts) toQueryConditions() (where, limitClause string, _ p
165187
"license_data->'Info'->>'sf_opp_id' = @salesforceOpportunityID")
166188
namedArgs["salesforceOpportunityID"] = opts.SalesforceOpportunityID
167189
}
190+
if opts.LicenseKeyHash != nil {
191+
whereConds = append(whereConds,
192+
"DIGEST(license_data->>'SignedKey','sha256') = @licenseKeyHash")
193+
namedArgs["licenseKeyHash"] = opts.LicenseKeyHash
194+
}
168195
}
169196

170197
where = strings.Join(whereConds, " AND ")
@@ -409,7 +436,7 @@ func (s *LicensesStore) Revoke(ctx context.Context, licenseID string, opts Revok
409436
return nil, errors.Wrap(err, "begin transaction")
410437
}
411438
defer func() {
412-
if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil {
439+
if rollbackErr := tx.Rollback(context.WithoutCancel(ctx)); rollbackErr != nil {
413440
err = errors.Append(err, rollbackErr)
414441
}
415442
}()
@@ -442,3 +469,59 @@ WHERE id = @licenseID
442469

443470
return s.Get(ctx, licenseID)
444471
}
472+
473+
type SetDetectedInstanceOpts struct {
474+
// InstanceID is the ID of the instance that was detected to be using this
475+
// license.
476+
InstanceID string
477+
// Message to associate with the detection event.
478+
Message string
479+
// If nil, the detection time will be set to the current time.
480+
Time *utctime.Time
481+
}
482+
483+
// SetDetectedInstance sets the instance ID that was detected to be using this
484+
// license.
485+
func (s *LicensesStore) SetDetectedInstance(ctx context.Context, licenseID string, opts SetDetectedInstanceOpts) error {
486+
if opts.Time == nil {
487+
opts.Time = pointers.Ptr(utctime.Now())
488+
}
489+
490+
tx, err := s.db.Begin(ctx)
491+
if err != nil {
492+
return errors.Wrap(err, "begin transaction")
493+
}
494+
defer func() {
495+
if rollbackErr := tx.Rollback(context.WithoutCancel(ctx)); rollbackErr != nil {
496+
err = errors.Append(err, rollbackErr)
497+
}
498+
}()
499+
500+
if _, err := tx.Exec(ctx, `
501+
UPDATE enterprise_portal_subscription_licenses
502+
SET detected_instance_id = @instanceID
503+
WHERE id = @licenseID
504+
`, pgx.NamedArgs{
505+
"instanceID": opts.InstanceID,
506+
"licenseID": licenseID,
507+
}); err != nil {
508+
if errors.Is(err, pgx.ErrNoRows) {
509+
return ErrSubscriptionLicenseNotFound
510+
}
511+
return errors.Wrap(err, "update detected instance for license")
512+
}
513+
514+
if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID, createLicenseConditionOpts{
515+
Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_INSTANCE_USAGE_DETECTED,
516+
Message: opts.Message,
517+
TransitionTime: *opts.Time,
518+
}); err != nil {
519+
return errors.Wrap(err, "create license condition")
520+
}
521+
522+
if err := tx.Commit(ctx); err != nil {
523+
return errors.Wrap(err, "commit transaction")
524+
}
525+
526+
return nil
527+
}

cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,20 @@ func TestLicensesStore(t *testing.T) {
242242
})
243243
})
244244

245+
t.Run("list by license key hash token", func(t *testing.T) {
246+
hash, err := license.ExtractLicenseKeyBasedAccessTokenContents(
247+
license.GenerateLicenseKeyBasedAccessToken(signedKeyExample),
248+
)
249+
require.NoError(t, err)
250+
listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{
251+
LicenseType: subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY,
252+
LicenseKeyHash: []byte(hash),
253+
})
254+
require.NoError(t, err)
255+
require.Len(t, listedLicenses, 1)
256+
assert.Equal(t, subscriptionID1, listedLicenses[0].SubscriptionID)
257+
})
258+
245259
t.Run("List by salesforce opportunity ID", func(t *testing.T) {
246260
listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{
247261
LicenseType: subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY,
@@ -270,20 +284,41 @@ func TestLicensesStore(t *testing.T) {
270284
}
271285
})
272286

287+
t.Run("SetDetectedInstance", func(t *testing.T) {
288+
require.NoError(t, licenses.SetDetectedInstance(ctx, createdLicenses[0].ID, subscriptions.SetDetectedInstanceOpts{
289+
InstanceID: "instance-id",
290+
Message: t.Name(),
291+
}))
292+
got, err := licenses.Get(ctx, createdLicenses[0].ID)
293+
require.NoError(t, err)
294+
assert.Equal(t, "instance-id", *got.DetectedInstanceID)
295+
})
296+
273297
t.Run("Revoke", func(t *testing.T) {
274298
for idx, license := range createdLicenses {
275-
revokeTime := utctime.FromTime(time.Now().Add(-time.Second))
299+
revokeTime := utctime.FromTime(time.Now())
276300
got, err := licenses.Revoke(ctx, license.ID, subscriptions.RevokeLicenseOpts{
277301
Message: fmt.Sprintf("%s %d", t.Name(), idx),
278302
Time: pointers.Ptr(revokeTime),
279303
})
280304
require.NoError(t, err)
281305
assert.Equal(t, revokeTime.AsTime(), got.RevokedAt.AsTime())
282-
require.Len(t, got.Conditions, 2)
283-
// Most recent condition is sorted first, and should be the revocation
284-
assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status)
285-
assert.Equal(t, revokeTime.AsTime(), got.Conditions[0].TransitionTime.AsTime())
286-
assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status)
306+
if idx > 0 {
307+
require.Len(t, got.Conditions, 2)
308+
// Most recent condition is sorted first, and should be the revocation
309+
assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status)
310+
assert.Equal(t, revokeTime.AsTime(), got.Conditions[0].TransitionTime.AsTime())
311+
assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status)
312+
} else {
313+
require.Len(t, got.Conditions, 3)
314+
// Most recent condition is sorted first, and should be the revocation
315+
assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status)
316+
assert.Equal(t, revokeTime.AsTime(), got.Conditions[0].TransitionTime.AsTime())
317+
// Then, the condition from SetDetectedInstance test
318+
assert.Equal(t, "STATUS_INSTANCE_USAGE_DETECTED", got.Conditions[1].Status)
319+
// Finally, the subscription creation event
320+
assert.Equal(t, "STATUS_CREATED", got.Conditions[2].Status)
321+
}
287322
}
288323
})
289324
}

0 commit comments

Comments
 (0)