Compare commits
100 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e60bde9b9 | |||
| 3db5558603 | |||
| 61961db41d | |||
| d2d7c0ee40 | |||
| d25d95231f | |||
| 3a62a8e70e | |||
| 7fc84ecf0b | |||
| 0ebe8e57ad | |||
| 3894edbcc3 | |||
| d5296a4855 | |||
| 5073493850 | |||
| 32354261d3 | |||
| 6683d807ac | |||
| 7c2479ce92 | |||
| e1156b050f | |||
| 0712faef4f | |||
| 7d5cd06f83 | |||
| 8d6a202ee4 | |||
| ffa83a4ebc | |||
| b3a81be1aa | |||
| 0c5809726d | |||
| 000bc334c9 | |||
| 8dd7d8b882 | |||
| 74b6d12a8a | |||
| 64e7a77983 | |||
| 7d558e76e9 | |||
| 40adf91cb0 | |||
| 49a42eff5c | |||
| 61ae5b81ab | |||
| cc2efe9e1f | |||
| 2b448c7178 | |||
| 2730e29105 | |||
| 150763720d | |||
| 8b995e3e06 | |||
| 2c2c67665f | |||
| 4e8e158ee4 | |||
| 6ca70d3618 | |||
| a581431bc8 | |||
| d5100543ea | |||
| 091d31224d | |||
| 1bfd776cb4 | |||
| a09d85cc26 | |||
| 60b3fd0783 | |||
| d2044c2ee9 | |||
| 89f4d60e7b | |||
| 4bc49ed6eb | |||
| 1e8c292855 | |||
| 960c892413 | |||
| ba499d84af | |||
| b116d22c5f | |||
| 1081d42760 | |||
| 8ea9f587e8 | |||
| bddb808b25 | |||
| b20d1bf159 | |||
| 0f446f99dd | |||
| 49b34a716a | |||
| d1b0722034 | |||
| 1a9a1106ca | |||
| 17ba151ed2 | |||
| 646e9cc6a9 | |||
| c77c0fce52 | |||
| 9a0024c45f | |||
| 6bd2d1c85f | |||
| c3e3249a2a | |||
| fa561bcd0a | |||
| 989def7a94 | |||
| 467c8bbd6b | |||
| ef45ce4dfb | |||
| 6a40fb0e2c | |||
| 2a7a33bb46 | |||
| ed6d41a5ef | |||
| 41a966c284 | |||
| f792f0b162 | |||
| 4a97df3768 | |||
| 5691d38db7 | |||
| 172cd13b24 | |||
| e10fceb23c | |||
| 55cc6b807c | |||
| 13668d82d6 | |||
| 32f3481634 | |||
| f524e00df7 | |||
| 874f3994b5 | |||
| 07924037e7 | |||
| 21241abc4e | |||
| e5377fbd93 | |||
| 9ac865b72f | |||
| 39bf9ed18a | |||
| 733b6b7db9 | |||
| a173c38715 | |||
| b522c9471a | |||
| ed1b9a9897 | |||
| 3517772e92 | |||
| b97572285a | |||
| 5655760f1d | |||
| 61c379dba6 | |||
| 4f40b78185 | |||
| 2ab17b1634 | |||
| 7a259ffd39 | |||
| d9e155113b | |||
| 1d530a3ab2 |
@@ -7,8 +7,6 @@ runs:
|
||||
- name: go install tools
|
||||
shell: bash
|
||||
run: |
|
||||
go install tool
|
||||
# NOTE: protoc-gen-go cannot be installed with `go get`
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
|
||||
go install storj.io/drpc/cmd/protoc-gen-go-drpc@v0.0.34
|
||||
go install golang.org/x/tools/cmd/goimports@v0.31.0
|
||||
go install github.com/mikefarah/yq/v4@v4.44.3
|
||||
go install go.uber.org/mock/mockgen@v0.5.0
|
||||
|
||||
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.24.10"
|
||||
default: "1.24.11"
|
||||
use-preinstalled-go:
|
||||
description: "Whether to use preinstalled Go."
|
||||
default: "false"
|
||||
|
||||
@@ -71,6 +71,7 @@ runs:
|
||||
|
||||
if [[ ${RACE_DETECTION} == true ]]; then
|
||||
gotestsum --junitfile="gotests.xml" --packages="${TEST_PACKAGES}" -- \
|
||||
-tags=testsmallbatch \
|
||||
-race \
|
||||
-parallel "${TEST_NUM_PARALLEL_TESTS}" \
|
||||
-p "${TEST_NUM_PARALLEL_PACKAGES}"
|
||||
|
||||
@@ -1373,7 +1373,7 @@ jobs:
|
||||
id: attest_main
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@7667f588f2f73a90cea6c7ac70e78266c4f76616 # v3.1.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:main"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -1410,7 +1410,7 @@ jobs:
|
||||
id: attest_latest
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@7667f588f2f73a90cea6c7ac70e78266c4f76616 # v3.1.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:latest"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -1447,7 +1447,7 @@ jobs:
|
||||
id: attest_version
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@7667f588f2f73a90cea6c7ac70e78266c4f76616 # v3.1.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
steps:
|
||||
- name: Dependabot metadata
|
||||
id: metadata
|
||||
uses: dependabot/fetch-metadata@08eff52bf64351f401fb50d4972fa95b9f2c2d1b # v2.4.0
|
||||
uses: dependabot/fetch-metadata@21025c705c08248db411dc16f3619e6b5f9ea21a # v2.5.0
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
# on version 2.29 and above.
|
||||
nix_version: "2.28.5"
|
||||
|
||||
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
|
||||
- uses: nix-community/cache-nix-action@b426b118b6dc86d6952988d396aa7c6b09776d08 # v7.0.0
|
||||
with:
|
||||
# restore and save a cache using this key
|
||||
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock') }}
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
||||
|
||||
- name: Login to DockerHub
|
||||
if: github.ref == 'refs/heads/main'
|
||||
|
||||
@@ -20,4 +20,4 @@ jobs:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Assign author
|
||||
uses: toshimaru/auto-author-assign@16f0022cf3d7970c106d8d1105f75a1165edb516 # v2.1.1
|
||||
uses: toshimaru/auto-author-assign@4d585cc37690897bd9015942ed6e766aa7cdb97f # v3.0.1
|
||||
|
||||
@@ -454,7 +454,7 @@ jobs:
|
||||
id: attest_base
|
||||
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@7667f588f2f73a90cea6c7ac70e78266c4f76616 # v3.1.0
|
||||
with:
|
||||
subject-name: ${{ steps.image-base-tag.outputs.tag }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -570,7 +570,7 @@ jobs:
|
||||
id: attest_main
|
||||
if: ${{ !inputs.dry_run }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@7667f588f2f73a90cea6c7ac70e78266c4f76616 # v3.1.0
|
||||
with:
|
||||
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -614,7 +614,7 @@ jobs:
|
||||
id: attest_latest
|
||||
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@7667f588f2f73a90cea6c7ac70e78266c4f76616 # v3.1.0
|
||||
with:
|
||||
subject-name: ${{ steps.latest_tag.outputs.tag }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
|
||||
@@ -47,6 +47,6 @@ jobs:
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@fe4161a26a8629af62121b670040955b330f9af2 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@fe4161a26a8629af62121b670040955b330f9af2 # v3.29.5
|
||||
uses: github/codeql-action/init@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
with:
|
||||
languages: go, javascript
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
rm Makefile
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@fe4161a26a8629af62121b670040955b330f9af2 # v3.29.5
|
||||
uses: github/codeql-action/analyze@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
@@ -154,7 +154,7 @@ jobs:
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results to GitHub Security tab
|
||||
uses: github/codeql-action/upload-sarif@fe4161a26a8629af62121b670040955b330f9af2 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
with:
|
||||
sarif_file: trivy-results.sarif
|
||||
category: "Trivy"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
.eslintcache
|
||||
.gitpod.yml
|
||||
.idea
|
||||
.run
|
||||
**/*.swp
|
||||
gotests.coverage
|
||||
gotests.xml
|
||||
|
||||
@@ -69,6 +69,9 @@ MOST_GO_SRC_FILES := $(shell \
|
||||
# All the shell files in the repo, excluding ignored files.
|
||||
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
|
||||
|
||||
MIGRATION_FILES := $(shell find ./coderd/database/migrations/ -maxdepth 1 $(FIND_EXCLUSIONS) -type f -name '*.sql')
|
||||
FIXTURE_FILES := $(shell find ./coderd/database/migrations/testdata/fixtures/ $(FIND_EXCLUSIONS) -type f -name '*.sql')
|
||||
|
||||
# Ensure we don't use the user's git configs which might cause side-effects
|
||||
GIT_FLAGS = GIT_CONFIG_GLOBAL=/dev/null GIT_CONFIG_SYSTEM=/dev/null
|
||||
|
||||
@@ -464,7 +467,7 @@ ifdef FILE
|
||||
# Format single file
|
||||
if [[ -f "$(FILE)" ]] && [[ "$(FILE)" == *.go ]] && ! grep -q "DO NOT EDIT" "$(FILE)"; then \
|
||||
echo "$(GREEN)==>$(RESET) $(BOLD)fmt/go$(RESET) $(FILE)"; \
|
||||
go run mvdan.cc/gofumpt@v0.8.0 -w -l "$(FILE)"; \
|
||||
./scripts/format_go_file.sh "$(FILE)"; \
|
||||
fi
|
||||
else
|
||||
go mod tidy
|
||||
@@ -473,7 +476,7 @@ else
|
||||
# https://github.com/mvdan/gofumpt#visual-studio-code
|
||||
find . $(FIND_EXCLUSIONS) -type f -name '*.go' -print0 | \
|
||||
xargs -0 grep -E --null -L '^// Code generated .* DO NOT EDIT\.$$' | \
|
||||
xargs -0 go run mvdan.cc/gofumpt@v0.8.0 -w -l
|
||||
xargs -0 ./scripts/format_go_file.sh
|
||||
endif
|
||||
.PHONY: fmt/go
|
||||
|
||||
@@ -561,7 +564,7 @@ endif
|
||||
|
||||
# Note: we don't run zizmor in the lint target because it takes a while. CI
|
||||
# runs it explicitly.
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint lint/check-scopes
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint lint/check-scopes lint/migrations
|
||||
.PHONY: lint
|
||||
|
||||
lint/site-icons:
|
||||
@@ -578,7 +581,7 @@ lint/go:
|
||||
./scripts/check_codersdk_imports.sh
|
||||
linter_ver=$(shell egrep -o 'GOLANGCI_LINT_VERSION=\S+' dogfood/coder/Dockerfile | cut -d '=' -f 2)
|
||||
go run github.com/golangci/golangci-lint/cmd/golangci-lint@v$$linter_ver run
|
||||
go run github.com/coder/paralleltestctx/cmd/paralleltestctx@v0.0.1 -custom-funcs="testutil.Context" ./...
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
|
||||
lint/examples:
|
||||
@@ -604,7 +607,7 @@ lint/actions: lint/actions/actionlint lint/actions/zizmor
|
||||
.PHONY: lint/actions
|
||||
|
||||
lint/actions/actionlint:
|
||||
go run github.com/rhysd/actionlint/cmd/actionlint@v1.7.7
|
||||
go tool github.com/rhysd/actionlint/cmd/actionlint
|
||||
.PHONY: lint/actions/actionlint
|
||||
|
||||
lint/actions/zizmor:
|
||||
@@ -619,6 +622,12 @@ lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
lint/migrations:
|
||||
./scripts/check_pg_schema.sh "Migrations" $(MIGRATION_FILES)
|
||||
./scripts/check_pg_schema.sh "Fixtures" $(FIXTURE_FILES)
|
||||
.PHONY: lint/migrations
|
||||
|
||||
# All files generated by the database should be added here, and this can be used
|
||||
# as a target for jobs that need to run after the database is generated.
|
||||
DB_GEN_FILES := \
|
||||
@@ -1018,7 +1027,8 @@ endif
|
||||
|
||||
# default to 8x8 parallelism to avoid overwhelming our workspaces. Hopefully we can remove these defaults
|
||||
# when we get our test suite's resource utilization under control.
|
||||
GOTEST_FLAGS := -v -p $(or $(TEST_NUM_PARALLEL_PACKAGES),"8") -parallel=$(or $(TEST_NUM_PARALLEL_TESTS),"8")
|
||||
# Use testsmallbatch tag to reduce wireguard memory allocation in tests (from ~18GB to negligible).
|
||||
GOTEST_FLAGS := -tags=testsmallbatch -v -p $(or $(TEST_NUM_PARALLEL_PACKAGES),"8") -parallel=$(or $(TEST_NUM_PARALLEL_TESTS),"8")
|
||||
|
||||
# The most common use is to set TEST_COUNT=1 to avoid Go's test cache.
|
||||
ifdef TEST_COUNT
|
||||
@@ -1033,6 +1043,14 @@ ifdef RUN
|
||||
GOTEST_FLAGS += -run $(RUN)
|
||||
endif
|
||||
|
||||
ifdef TEST_CPUPROFILE
|
||||
GOTEST_FLAGS += -cpuprofile=$(TEST_CPUPROFILE)
|
||||
endif
|
||||
|
||||
ifdef TEST_MEMPROFILE
|
||||
GOTEST_FLAGS += -memprofile=$(TEST_MEMPROFILE)
|
||||
endif
|
||||
|
||||
TEST_PACKAGES ?= ./...
|
||||
|
||||
test:
|
||||
@@ -1081,6 +1099,7 @@ test-postgres: test-postgres-docker
|
||||
--jsonfile="gotests.json" \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="./..." -- \
|
||||
-tags=testsmallbatch \
|
||||
-timeout=20m \
|
||||
-count=1
|
||||
.PHONY: test-postgres
|
||||
@@ -1153,7 +1172,7 @@ test-postgres-docker:
|
||||
|
||||
# Make sure to keep this in sync with test-go-race from .github/workflows/ci.yaml.
|
||||
test-race:
|
||||
$(GIT_FLAGS) gotestsum --junitfile="gotests.xml" -- -race -count=1 -parallel 4 -p 4 ./...
|
||||
$(GIT_FLAGS) gotestsum --junitfile="gotests.xml" -- -tags=testsmallbatch -race -count=1 -parallel 4 -p 4 ./...
|
||||
.PHONY: test-race
|
||||
|
||||
test-tailnet-integration:
|
||||
@@ -1163,6 +1182,7 @@ test-tailnet-integration:
|
||||
TS_DEBUG_NETCHECK=true \
|
||||
GOTRACEBACK=single \
|
||||
go test \
|
||||
-tags=testsmallbatch \
|
||||
-exec "sudo -E" \
|
||||
-timeout=5m \
|
||||
-count=1 \
|
||||
|
||||
+43
-5
@@ -36,13 +36,14 @@ import (
|
||||
"tailscale.com/types/netlogtype"
|
||||
"tailscale.com/util/clientmetric"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/clistat"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
|
||||
"github.com/coder/coder/v2/agent/reconnectingpty"
|
||||
@@ -102,6 +103,7 @@ type Options struct {
|
||||
Clock quartz.Clock
|
||||
SocketServerEnabled bool
|
||||
SocketPath string // Path for the agent socket server socket
|
||||
BoundaryLogProxySocketPath string
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -205,10 +207,11 @@ func New(options Options) Agent {
|
||||
metrics: newAgentMetrics(prometheusRegistry),
|
||||
execer: options.Execer,
|
||||
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
socketServerEnabled: options.SocketServerEnabled,
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
socketServerEnabled: options.SocketServerEnabled,
|
||||
boundaryLogProxySocketPath: options.BoundaryLogProxySocketPath,
|
||||
}
|
||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||
@@ -277,6 +280,11 @@ type agent struct {
|
||||
|
||||
logSender *agentsdk.LogSender
|
||||
|
||||
// boundaryLogProxy is a socket server that forwards boundary audit logs to coderd.
|
||||
// It may be nil if there is a problem starting the server.
|
||||
boundaryLogProxy *boundarylogproxy.Server
|
||||
boundaryLogProxySocketPath string
|
||||
|
||||
prometheusRegistry *prometheus.Registry
|
||||
// metrics are prometheus registered metrics that will be collected and
|
||||
// labeled in Coder with the agent + workspace.
|
||||
@@ -371,6 +379,7 @@ func (a *agent) init() {
|
||||
)
|
||||
|
||||
a.initSocketServer()
|
||||
a.startBoundaryLogProxyServer()
|
||||
|
||||
go a.runLoop()
|
||||
}
|
||||
@@ -395,6 +404,19 @@ func (a *agent) initSocketServer() {
|
||||
a.logger.Debug(a.hardCtx, "socket server started", slog.F("path", a.socketPath))
|
||||
}
|
||||
|
||||
// startBoundaryLogProxyServer starts the boundary log proxy socket server.
|
||||
func (a *agent) startBoundaryLogProxyServer() {
|
||||
proxy := boundarylogproxy.NewServer(a.logger, a.boundaryLogProxySocketPath)
|
||||
if err := proxy.Start(); err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to start boundary log proxy", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
a.boundaryLogProxy = proxy
|
||||
a.logger.Info(a.hardCtx, "boundary log proxy server started",
|
||||
slog.F("socket_path", a.boundaryLogProxySocketPath))
|
||||
}
|
||||
|
||||
// runLoop attempts to start the agent in a retry loop.
|
||||
// Coder may be offline temporarily, a connection issue
|
||||
// may be happening, but regardless after the intermittent
|
||||
@@ -1012,6 +1034,15 @@ func (a *agent) run() (retErr error) {
|
||||
return err
|
||||
})
|
||||
|
||||
// Forward boundary audit logs to coderd if boundary log forwarding is enabled.
|
||||
// These are audit logs so they should continue during graceful shutdown.
|
||||
if a.boundaryLogProxy != nil {
|
||||
proxyFunc := func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
|
||||
return a.boundaryLogProxy.RunForwarder(ctx, aAPI)
|
||||
}
|
||||
connMan.startAgentAPI("boundary log proxy", gracefulShutdownBehaviorRemain, proxyFunc)
|
||||
}
|
||||
|
||||
// part of graceful shut down is reporting the final lifecycle states, e.g "ShuttingDown" so the
|
||||
// lifecycle reporting has to be via gracefulShutdownBehaviorRemain
|
||||
connMan.startAgentAPI("report lifecycle", gracefulShutdownBehaviorRemain, a.reportLifecycle)
|
||||
@@ -1982,6 +2013,13 @@ func (a *agent) Close() error {
|
||||
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if a.boundaryLogProxy != nil {
|
||||
err = a.boundaryLogProxy.Close()
|
||||
if err != nil {
|
||||
a.logger.Warn(context.Background(), "close boundary log proxy", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the graceful shutdown to complete, but don't wait forever so
|
||||
// that we don't break user expectations.
|
||||
go func() {
|
||||
|
||||
@@ -6,9 +6,8 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
+5
-7
@@ -25,10 +25,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/goleak"
|
||||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"github.com/bramvdbogaerde/go-scp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/ory/dockertest/v3"
|
||||
@@ -40,12 +36,14 @@ import (
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers/ignore"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers/watcher"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
|
||||
@@ -27,9 +27,9 @@ import (
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers/acmock"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers/watcher"
|
||||
|
||||
+1
-2
@@ -10,11 +10,10 @@ package dcspec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
func UnmarshalDevContainer(data []byte) (DevContainer, error) {
|
||||
var r DevContainer
|
||||
err := json.Unmarshal(data, &r)
|
||||
|
||||
@@ -61,7 +61,7 @@ fi
|
||||
exec 3>&-
|
||||
|
||||
# Format the generated code.
|
||||
go run mvdan.cc/gofumpt@v0.8.0 -w -l "${TMPDIR}/${DEST_FILENAME}"
|
||||
"${PROJECT_ROOT}/scripts/format_go_file.sh" "${TMPDIR}/${DEST_FILENAME}"
|
||||
|
||||
# Add a header so that Go recognizes this as a generated file.
|
||||
if grep -q -- "\[-i extension\]" < <(sed -h 2>&1); then
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -21,8 +21,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/usershell"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,8 +7,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -20,8 +20,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
func cmdSysProcAttr() *syscall.SysProcAttr {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
func cmdSysProcAttr() *syscall.SysProcAttr {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
|
||||
@@ -27,8 +27,7 @@ import (
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentrsa"
|
||||
|
||||
@@ -24,9 +24,8 @@ import (
|
||||
"go.uber.org/goleak"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
func cmdSysProcAttr() *syscall.SysProcAttr {
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
func cmdSysProcAttr() *syscall.SysProcAttr {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// streamLocalForwardPayload describes the extra data sent in a
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"go.uber.org/atomic"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// localForwardChannelData is copied from the ssh package.
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"storj.io/drpc/drpcserver"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
|
||||
+1
-1
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/quartz"
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package agent_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy"
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// logSink captures structured log entries for testing.
|
||||
type logSink struct {
|
||||
mu sync.Mutex
|
||||
entries []slog.SinkEntry
|
||||
}
|
||||
|
||||
func (s *logSink) LogEntry(_ context.Context, e slog.SinkEntry) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.entries = append(s.entries, e)
|
||||
}
|
||||
|
||||
func (*logSink) Sync() {}
|
||||
|
||||
func (s *logSink) getEntries() []slog.SinkEntry {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return append([]slog.SinkEntry{}, s.entries...)
|
||||
}
|
||||
|
||||
// getField returns the value of a field by name from a slog.Map.
|
||||
func getField(fields slog.Map, name string) interface{} {
|
||||
for _, f := range fields {
|
||||
if f.Name == name {
|
||||
return f.Value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendBoundaryLogsRequest(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
|
||||
t.Helper()
|
||||
|
||||
data, err := proto.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = codec.WriteFrame(conn, codec.TagV1, data)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestBoundaryLogs_EndToEnd is an end-to-end test that sends a protobuf
|
||||
// message over the agent's unix socket (as boundary would) and verifies
|
||||
// it is ultimately logged by coderd with the correct structured fields.
|
||||
func TestBoundaryLogs_EndToEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
|
||||
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
|
||||
|
||||
err := srv.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, srv.Close()) })
|
||||
|
||||
sink := &logSink{}
|
||||
logger := slog.Make(sink)
|
||||
workspaceID := uuid.New()
|
||||
reporter := &agentapi.BoundaryLogsAPI{
|
||||
Log: logger,
|
||||
WorkspaceID: workspaceID,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
forwarderDone := make(chan error, 1)
|
||||
go func() {
|
||||
forwarderDone <- srv.RunForwarder(ctx, reporter)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Allowed HTTP request.
|
||||
req := &agentproto.ReportBoundaryLogsRequest{
|
||||
Logs: []*agentproto.BoundaryLog{
|
||||
{
|
||||
Allowed: true,
|
||||
Time: timestamppb.Now(),
|
||||
Resource: &agentproto.BoundaryLog_HttpRequest_{
|
||||
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
|
||||
Method: "GET",
|
||||
Url: "https://example.com/allowed",
|
||||
MatchedRule: "*.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
sendBoundaryLogsRequest(t, conn, req)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return len(sink.getEntries()) >= 1
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
entries := sink.getEntries()
|
||||
require.Len(t, entries, 1)
|
||||
entry := entries[0]
|
||||
require.Equal(t, slog.LevelInfo, entry.Level)
|
||||
require.Equal(t, "boundary_request", entry.Message)
|
||||
require.Equal(t, "allow", getField(entry.Fields, "decision"))
|
||||
require.Equal(t, workspaceID.String(), getField(entry.Fields, "workspace_id"))
|
||||
require.Equal(t, "GET", getField(entry.Fields, "http_method"))
|
||||
require.Equal(t, "https://example.com/allowed", getField(entry.Fields, "http_url"))
|
||||
require.Equal(t, "*.example.com", getField(entry.Fields, "matched_rule"))
|
||||
|
||||
// Denied HTTP request.
|
||||
req2 := &agentproto.ReportBoundaryLogsRequest{
|
||||
Logs: []*agentproto.BoundaryLog{
|
||||
{
|
||||
Allowed: false,
|
||||
Time: timestamppb.Now(),
|
||||
Resource: &agentproto.BoundaryLog_HttpRequest_{
|
||||
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
|
||||
Method: "POST",
|
||||
Url: "https://blocked.com/denied",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
sendBoundaryLogsRequest(t, conn, req2)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return len(sink.getEntries()) >= 2
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
entries = sink.getEntries()
|
||||
entry = entries[1]
|
||||
require.Len(t, entries, 2)
|
||||
require.Equal(t, slog.LevelInfo, entry.Level)
|
||||
require.Equal(t, "boundary_request", entry.Message)
|
||||
require.Equal(t, "deny", getField(entry.Fields, "decision"))
|
||||
require.Equal(t, workspaceID.String(), getField(entry.Fields, "workspace_id"))
|
||||
require.Equal(t, "POST", getField(entry.Fields, "http_method"))
|
||||
require.Equal(t, "https://blocked.com/denied", getField(entry.Fields, "http_url"))
|
||||
require.Equal(t, nil, getField(entry.Fields, "matched_rule"))
|
||||
|
||||
cancel()
|
||||
<-forwarderDone
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
// Package codec implements the wire format for agent <-> boundary communication.
|
||||
//
|
||||
// Wire Format:
|
||||
// - 8 bits: big-endian tag
|
||||
// - 24 bits: big-endian length of the protobuf data (bit usage depends on tag)
|
||||
// - length bytes: encoded protobuf data
|
||||
//
|
||||
// Note that while there are 24 bits available for the length, the actual maximum
|
||||
// length depends on the tag. For TagV1, only 15 bits are used (MaxMessageSizeV1).
|
||||
package codec
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Tag identifies the protocol revision of a framed message. It occupies the
// top tagLength bits of the 32-bit frame header.
type Tag uint8

const (
	// TagV1 identifies the first revision of the protocol. This version has a maximum
	// data length of MaxMessageSizeV1.
	TagV1 Tag = 1
)

const (
	// DataLength is the number of bits used for the length of encoded protobuf data.
	// The length occupies the low DataLength bits of the header.
	DataLength = 24

	// tagLength is the number of bits used for the tag.
	tagLength = 8

	// MaxMessageSizeV1 is the maximum size of the encoded protobuf messages sent
	// over the wire for the TagV1 tag. While the wire format allows 24 bits for
	// length, TagV1 only uses 15 bits.
	MaxMessageSizeV1 uint32 = 1 << 15
)

var (
	// ErrMessageTooLarge is returned when the message exceeds the maximum size
	// allowed for the tag.
	ErrMessageTooLarge = xerrors.New("message too large")
	// ErrUnsupportedTag is returned when an unrecognized tag is encountered.
	ErrUnsupportedTag = xerrors.New("unsupported tag")
)
|
||||
|
||||
// WriteFrame writes a framed message with the given tag and data. The data
|
||||
// must not exceed 2^DataLength in length.
|
||||
func WriteFrame(w io.Writer, tag Tag, data []byte) error {
|
||||
var maxSize uint32
|
||||
switch tag {
|
||||
case TagV1:
|
||||
maxSize = MaxMessageSizeV1
|
||||
default:
|
||||
return xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
|
||||
}
|
||||
|
||||
if len(data) > int(maxSize) {
|
||||
return xerrors.Errorf("%w for tag %d: %d > %d", ErrMessageTooLarge, tag, len(data), maxSize)
|
||||
}
|
||||
|
||||
var header uint32
|
||||
//nolint:gosec // The length check above ensures there's no overflow.
|
||||
header |= uint32(len(data))
|
||||
header |= uint32(tag) << DataLength
|
||||
|
||||
if err := binary.Write(w, binary.BigEndian, header); err != nil {
|
||||
return xerrors.Errorf("write header error: %w", err)
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return xerrors.Errorf("write data error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadFrame reads a framed message, returning the decoded tag and data. If the
|
||||
// message size exceeds MaxMessageSizeV1, ErrMessageTooLarge is returned. The
|
||||
// provided buf is used if it has sufficient capacity; otherwise a new buffer is
|
||||
// allocated. To reuse the buffer across calls, pass in the returned data slice:
|
||||
//
|
||||
// buf := make([]byte, initialSize)
|
||||
// for {
|
||||
// _, buf, _ = ReadFrame(r, buf)
|
||||
// }
|
||||
func ReadFrame(r io.Reader, buf []byte) (Tag, []byte, error) {
|
||||
var header uint32
|
||||
if err := binary.Read(r, binary.BigEndian, &header); err != nil {
|
||||
return 0, nil, xerrors.Errorf("read header error: %w", err)
|
||||
}
|
||||
|
||||
const lengthMask = (1 << DataLength) - 1
|
||||
length := header & lengthMask
|
||||
const tagMask = (1 << tagLength) - 1 // 0xFF
|
||||
shifted := (header >> DataLength) & tagMask
|
||||
if shifted > tagMask {
|
||||
// This is really only here to satisfy the gosec linter. We know from above that
|
||||
// shifted <= tagMask.
|
||||
return 0, nil, xerrors.Errorf("invalid tag: %d", shifted)
|
||||
}
|
||||
tag := Tag(shifted)
|
||||
|
||||
var maxSize uint32
|
||||
switch tag {
|
||||
case TagV1:
|
||||
maxSize = MaxMessageSizeV1
|
||||
default:
|
||||
return 0, nil, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
|
||||
}
|
||||
|
||||
if length > maxSize {
|
||||
return 0, nil, ErrMessageTooLarge
|
||||
}
|
||||
|
||||
if cap(buf) < int(length) {
|
||||
buf = make([]byte, length)
|
||||
} else {
|
||||
buf = buf[:length:cap(buf)]
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r, buf[:length]); err != nil {
|
||||
return 0, nil, xerrors.Errorf("read full error: %w", err)
|
||||
}
|
||||
|
||||
return tag, buf[:length], nil
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package codec_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
|
||||
)
|
||||
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tag codec.Tag
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "empty data",
|
||||
tag: codec.TagV1,
|
||||
data: []byte{},
|
||||
},
|
||||
{
|
||||
name: "simple data",
|
||||
tag: codec.TagV1,
|
||||
data: []byte("hello world"),
|
||||
},
|
||||
{
|
||||
name: "binary data",
|
||||
tag: codec.TagV1,
|
||||
data: []byte{0x00, 0x01, 0x02, 0xff, 0xfe},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := codec.WriteFrame(&buf, tt.tag, tt.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, codec.MaxMessageSizeV1)
|
||||
tag, data, err := codec.ReadFrame(&buf, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.tag, tag)
|
||||
require.Equal(t, tt.data, data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameTooLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Hand construct a header that indicates the message size exceeds the maximum
|
||||
// message size for codec.TagV1 by one. We just write the header to buf because
|
||||
// we expect codec.ReadFrame to bail out when reading the invalid length.
|
||||
header := uint32(codec.TagV1)<<codec.DataLength | (codec.MaxMessageSizeV1 + 1)
|
||||
data := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(data, header)
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err := buf.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 1)
|
||||
_, _, err = codec.ReadFrame(&buf, readBuf)
|
||||
require.ErrorIs(t, err, codec.ErrMessageTooLarge)
|
||||
}
|
||||
|
||||
func TestReadFrameEmptyReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
readBuf := make([]byte, codec.MaxMessageSizeV1)
|
||||
_, _, err := codec.ReadFrame(&buf, readBuf)
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
}
|
||||
|
||||
func TestReadFrameInvalidTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Hand construct a header that indicates a tag we don't know about. We just
|
||||
// write the header to buf because we expect codec.ReadFrame to bail out when
|
||||
// reading the invalid tag.
|
||||
const (
|
||||
dataLength uint32 = 10
|
||||
bogusTag uint32 = 2
|
||||
)
|
||||
header := bogusTag<<codec.DataLength | dataLength
|
||||
data := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(data, header)
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err := buf.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 1)
|
||||
_, _, err = codec.ReadFrame(&buf, readBuf)
|
||||
require.ErrorIs(t, err, codec.ErrUnsupportedTag)
|
||||
}
|
||||
|
||||
func TestReadFrameAllocatesWhenNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
data := []byte("this message is longer than the buffer")
|
||||
err := codec.WriteFrame(&buf, codec.TagV1, data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Buffer with insufficient capacity triggers allocation.
|
||||
readBuf := make([]byte, 4)
|
||||
tag, got, err := codec.ReadFrame(&buf, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codec.TagV1, tag)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestWriteFrameDataSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
data := make([]byte, codec.MaxMessageSizeV1)
|
||||
err := codec.WriteFrame(&buf, codec.TagV1, data)
|
||||
require.NoError(t, err)
|
||||
|
||||
//nolint: makezero // This intentionally increases the slice length.
|
||||
data = append(data, 0) // One byte over the maximum
|
||||
err = codec.WriteFrame(&buf, codec.TagV1, data)
|
||||
require.ErrorIs(t, err, codec.ErrMessageTooLarge)
|
||||
}
|
||||
|
||||
func TestWriteFrameInvalidTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
data := make([]byte, 1)
|
||||
const bogusTag = 2
|
||||
err := codec.WriteFrame(&buf, codec.Tag(bogusTag), data)
|
||||
require.ErrorIs(t, err, codec.ErrUnsupportedTag)
|
||||
}
|
||||
@@ -2,6 +2,204 @@
|
||||
// audit logs and forwards them to coderd via the agent API.
|
||||
package boundarylogproxy
|
||||
|
||||
// Server a placeholder for the server that will listen on a Unix socket for
|
||||
// boundary logs to be forwarded.
|
||||
type Server struct{}
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
)
|
||||
|
||||
const (
	// logBufferSize is the size of the channel buffer for incoming log requests
	// from workspaces. This buffer size is intended to handle short bursts of workspaces
	// forwarding batches of logs in parallel. When the buffer is full, new
	// batches are dropped by handleConnection rather than blocking the reader.
	logBufferSize = 100
)
|
||||
|
||||
// DefaultSocketPath returns the default path for the boundary audit log socket.
func DefaultSocketPath() string {
	const socketName = "boundary-audit.sock"
	return filepath.Join(os.TempDir(), socketName)
}
|
||||
|
||||
// Reporter reports boundary logs from workspaces. RunForwarder uses it to
// deliver drained batches; tests supply a fake implementation.
type Reporter interface {
	ReportBoundaryLogs(ctx context.Context, req *agentproto.ReportBoundaryLogsRequest) (*agentproto.ReportBoundaryLogsResponse, error)
}
|
||||
|
||||
// Server listens on a Unix socket for boundary log messages and buffers them
// for forwarding to coderd. The socket server and the forwarder are decoupled:
//   - Start() creates the socket and accepts a connection from boundary
//   - RunForwarder() drains the buffer and sends logs to coderd via AgentAPI
type Server struct {
	logger     slog.Logger
	socketPath string

	// listener and cancel are set by Start and torn down by Close.
	listener net.Listener
	cancel   context.CancelFunc
	// wg tracks the accept loop and per-connection goroutines so Close can
	// wait for them all to exit.
	wg sync.WaitGroup

	// logs buffers incoming log requests for the forwarder to drain.
	logs chan *agentproto.ReportBoundaryLogsRequest
}
|
||||
|
||||
// NewServer creates a new boundary log proxy server.
|
||||
func NewServer(logger slog.Logger, socketPath string) *Server {
|
||||
return &Server{
|
||||
logger: logger.Named("boundary-log-proxy"),
|
||||
socketPath: socketPath,
|
||||
logs: make(chan *agentproto.ReportBoundaryLogsRequest, logBufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening for connections on the Unix socket, and handles new
|
||||
// connections in a separate goroutine. Incoming logs from connections are
|
||||
// buffered until RunForwarder drains them.
|
||||
func (s *Server) Start() error {
|
||||
if err := os.Remove(s.socketPath); err != nil && !os.IsNotExist(err) {
|
||||
return xerrors.Errorf("remove existing socket: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", s.socketPath)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("listen on socket: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.cancel = cancel
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.acceptLoop(ctx)
|
||||
|
||||
s.logger.Info(ctx, "boundary log proxy started", slog.F("socket_path", s.socketPath))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunForwarder drains the log buffer and forwards logs to coderd.
|
||||
// It blocks until ctx is canceled.
|
||||
func (s *Server) RunForwarder(ctx context.Context, sender Reporter) error {
|
||||
s.logger.Debug(ctx, "boundary log forwarder started")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case req := <-s.logs:
|
||||
_, err := sender.ReportBoundaryLogs(ctx, req)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to forward boundary logs",
|
||||
slog.Error(err),
|
||||
slog.F("log_count", len(req.Logs)))
|
||||
// Continue forwarding other logs. The current batch is lost,
|
||||
// but the socket stays alive.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// acceptLoop accepts connections on the Unix socket until the listener is
// closed, spawning a handler goroutine per connection. It is started by
// Start and tracked by s.wg.
func (s *Server) acceptLoop(ctx context.Context) {
	defer s.wg.Done()

	for {
		conn, err := s.listener.Accept()
		if err != nil {
			// A canceled context means Close was called; any accept error at
			// that point is an expected part of shutdown.
			if ctx.Err() != nil {
				s.logger.Warn(ctx, "accept loop terminated", slog.Error(ctx.Err()))
				return
			}
			// Transient accept failure; keep serving other connections.
			s.logger.Warn(ctx, "socket accept error", slog.Error(err))
			continue
		}

		s.wg.Add(1)
		go s.handleConnection(ctx, conn)
	}
}
|
||||
|
||||
// handleConnection reads framed protobuf messages from conn and queues them
// on s.logs until the connection closes, an unrecoverable read error occurs,
// or ctx is canceled.
func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
	defer s.wg.Done()

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Close the connection when ctx is canceled so the blocking ReadFrame
	// below unblocks during shutdown.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		<-ctx.Done()
		_ = conn.Close()
	}()

	// This is intended to be a sane starting point for the read buffer size. It may be
	// grown by codec.ReadFrame if necessary.
	const initBufSize = 1 << 10
	buf := make([]byte, initBufSize)

	for {
		// Bail out promptly between frames if we're shutting down.
		select {
		case <-ctx.Done():
			return
		default:
		}

		var (
			tag codec.Tag
			err error
		)
		// Reassigning buf lets ReadFrame's (possibly reallocated) buffer be
		// reused on the next iteration.
		tag, buf, err = codec.ReadFrame(conn, buf)
		switch {
		case errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed):
			// Clean disconnect, or conn was closed by the goroutine above.
			return
		case err != nil:
			s.logger.Warn(ctx, "read frame error", slog.Error(err))
			return
		}

		if tag != codec.TagV1 {
			s.logger.Warn(ctx, "invalid tag value", slog.F("tag", tag))
			return
		}

		var req agentproto.ReportBoundaryLogsRequest
		if err := proto.Unmarshal(buf, &req); err != nil {
			// Bad payload: skip this frame but keep the connection open.
			s.logger.Warn(ctx, "proto unmarshal error", slog.Error(err))
			continue
		}

		// Non-blocking send: if the forwarder can't keep up, drop the batch
		// rather than stalling the reader.
		select {
		case s.logs <- &req:
		default:
			s.logger.Warn(ctx, "dropping boundary logs, buffer full",
				slog.F("log_count", len(req.Logs)))
		}
	}
}
|
||||
|
||||
// Close stops the server and blocks until resources have been cleaned up.
|
||||
// It must be called after Start.
|
||||
func (s *Server) Close() error {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
if s.listener != nil {
|
||||
_ = s.listener.Close()
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
|
||||
err := os.Remove(s.socketPath)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,578 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package boundarylogproxy_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy"
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// sendMessage writes a framed protobuf message to the connection.
|
||||
func sendMessage(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
|
||||
t.Helper()
|
||||
|
||||
data, err := proto.Marshal(req)
|
||||
if err != nil {
|
||||
//nolint:gocritic // In tests we're not worried about conn being nil.
|
||||
t.Errorf("%s marshal req: %s", conn.LocalAddr().String(), err)
|
||||
}
|
||||
|
||||
err = codec.WriteFrame(conn, codec.TagV1, data)
|
||||
if err != nil {
|
||||
//nolint:gocritic // In tests we're not worried about conn being nil.
|
||||
t.Errorf("%s write frame: %s", conn.LocalAddr().String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// fakeReporter implements boundarylogproxy.Reporter for testing. All fields
// are guarded by mu.
type fakeReporter struct {
	mu sync.Mutex
	// logs records the successfully reported requests.
	logs []*agentproto.ReportBoundaryLogsRequest
	// err, when non-nil, is returned from ReportBoundaryLogs instead of
	// recording the request.
	err     error
	errOnce bool // only error once, then succeed

	// reportCb is called when a ReportBoundaryLogsRequest is processed. It must not
	// block.
	reportCb func()
}
|
||||
|
||||
func (f *fakeReporter) ReportBoundaryLogs(_ context.Context, req *agentproto.ReportBoundaryLogsRequest) (*agentproto.ReportBoundaryLogsResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.reportCb != nil {
|
||||
f.reportCb()
|
||||
}
|
||||
|
||||
if f.err != nil {
|
||||
if f.errOnce {
|
||||
err := f.err
|
||||
f.err = nil
|
||||
return nil, err
|
||||
}
|
||||
return nil, f.err
|
||||
}
|
||||
f.logs = append(f.logs, req)
|
||||
return &agentproto.ReportBoundaryLogsResponse{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeReporter) getLogs() []*agentproto.ReportBoundaryLogsRequest {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return append([]*agentproto.ReportBoundaryLogsRequest{}, f.logs...)
|
||||
}
|
||||
|
||||
func TestServer_StartAndClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
|
||||
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
|
||||
|
||||
err := srv.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify socket exists and is connectable.
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = srv.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestServer_ReceiveAndForwardLogs verifies the end-to-end happy path: a log
// batch written to the Unix socket is decoded and handed to the Reporter
// with its fields intact.
func TestServer_ReceiveAndForwardLogs(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	err := srv.Start()
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, srv.Close()) })

	reporter := &fakeReporter{}

	// Start forwarder in background.
	forwarderDone := make(chan error, 1)
	go func() {
		forwarderDone <- srv.RunForwarder(ctx, reporter)
	}()

	// Connect and send a log message.
	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer conn.Close()

	req := &agentproto.ReportBoundaryLogsRequest{
		Logs: []*agentproto.BoundaryLog{
			{
				Allowed: true,
				Time:    timestamppb.Now(),
				Resource: &agentproto.BoundaryLog_HttpRequest_{
					HttpRequest: &agentproto.BoundaryLog_HttpRequest{
						Method: "GET",
						Url:    "https://example.com",
					},
				},
			},
		},
	}

	sendMessage(t, conn, req)

	// Wait for the reporter to receive the log.
	require.Eventually(t, func() bool {
		logs := reporter.getLogs()
		return len(logs) == 1
	}, testutil.WaitShort, testutil.IntervalFast)

	// The batch should arrive intact with its fields preserved.
	logs := reporter.getLogs()
	require.Len(t, logs, 1)
	require.Len(t, logs[0].Logs, 1)
	require.True(t, logs[0].Logs[0].Allowed)
	require.Equal(t, "GET", logs[0].Logs[0].GetHttpRequest().Method)
	require.Equal(t, "https://example.com", logs[0].Logs[0].GetHttpRequest().Url)

	// Stop the forwarder and wait for it to exit so no goroutine leaks.
	cancel()
	<-forwarderDone
}
|
||||
|
||||
func TestServer_MultipleMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
|
||||
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
err := srv.Start()
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
reporter := &fakeReporter{}
|
||||
|
||||
forwarderDone := make(chan error, 1)
|
||||
go func() {
|
||||
forwarderDone <- srv.RunForwarder(ctx, reporter)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send multiple messages and verify they are all received.
|
||||
for range 5 {
|
||||
req := &agentproto.ReportBoundaryLogsRequest{
|
||||
Logs: []*agentproto.BoundaryLog{
|
||||
{
|
||||
Allowed: true,
|
||||
Time: timestamppb.Now(),
|
||||
Resource: &agentproto.BoundaryLog_HttpRequest_{
|
||||
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
|
||||
Method: "POST",
|
||||
Url: "https://example.com/api",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
sendMessage(t, conn, req)
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
logs := reporter.getLogs()
|
||||
return len(logs) == 5
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
cancel()
|
||||
<-forwarderDone
|
||||
}
|
||||
|
||||
func TestServer_MultipleConnections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
|
||||
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
err := srv.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, srv.Close()) })
|
||||
|
||||
reporter := &fakeReporter{}
|
||||
|
||||
forwarderDone := make(chan error, 1)
|
||||
go func() {
|
||||
forwarderDone <- srv.RunForwarder(ctx, reporter)
|
||||
}()
|
||||
|
||||
// Create multiple connections and send from each.
|
||||
const numConns = 3
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numConns)
|
||||
for i := range numConns {
|
||||
go func(connID int) {
|
||||
defer wg.Done()
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Errorf("conn %d dial: %s", connID, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
req := &agentproto.ReportBoundaryLogsRequest{
|
||||
Logs: []*agentproto.BoundaryLog{
|
||||
{
|
||||
Allowed: true,
|
||||
Time: timestamppb.Now(),
|
||||
Resource: &agentproto.BoundaryLog_HttpRequest_{
|
||||
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
|
||||
Method: "GET",
|
||||
Url: "https://example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
sendMessage(t, conn, req)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
logs := reporter.getLogs()
|
||||
return len(logs) == numConns
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
cancel()
|
||||
<-forwarderDone
|
||||
}
|
||||
|
||||
func TestServer_MessageTooLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
|
||||
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
|
||||
|
||||
err := srv.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, srv.Close()) })
|
||||
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send a message claiming to be larger than the max message size.
|
||||
var length uint32 = codec.MaxMessageSizeV1 + 1
|
||||
err = binary.Write(conn, binary.BigEndian, length)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The server should close the connection after receiving an oversized
|
||||
// message length.
|
||||
buf := make([]byte, 1)
|
||||
err = conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Read(buf)
|
||||
require.Error(t, err) // Should get EOF or closed connection.
|
||||
}
|
||||
|
||||
// TestServer_ForwarderContinuesAfterError verifies that a Reporter failure
// loses only the current batch: the forwarder keeps running and the next
// batch is delivered.
func TestServer_ForwarderContinuesAfterError(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, srv.Close()) })

	// reportNotify is signaled on each Reporter call so the test can wait
	// for the failing first delivery before sending the second message.
	reportNotify := make(chan struct{}, 1)
	reporter := &fakeReporter{
		// Simulate an error on the first call.
		err:     context.DeadlineExceeded,
		errOnce: true,
		reportCb: func() {
			reportNotify <- struct{}{}
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	forwarderDone := make(chan error, 1)
	go func() {
		forwarderDone <- srv.RunForwarder(ctx, reporter)
	}()

	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer conn.Close()

	// Send the first message to be processed and wait for failure.
	req1 := &agentproto.ReportBoundaryLogsRequest{
		Logs: []*agentproto.BoundaryLog{
			{
				Allowed: true,
				Time:    timestamppb.Now(),
				Resource: &agentproto.BoundaryLog_HttpRequest_{
					HttpRequest: &agentproto.BoundaryLog_HttpRequest{
						Method: "GET",
						Url:    "https://example.com/first",
					},
				},
			},
		},
	}
	sendMessage(t, conn, req1)

	select {
	case <-reportNotify:
	case <-time.After(testutil.WaitShort):
		t.Fatal("timed out waiting for first message to be processed")
	}

	// Send the second message, which should succeed.
	req2 := &agentproto.ReportBoundaryLogsRequest{
		Logs: []*agentproto.BoundaryLog{
			{
				Allowed: false,
				Time:    timestamppb.Now(),
				Resource: &agentproto.BoundaryLog_HttpRequest_{
					HttpRequest: &agentproto.BoundaryLog_HttpRequest{
						Method: "POST",
						Url:    "https://example.com/second",
					},
				},
			},
		},
	}
	sendMessage(t, conn, req2)

	// Only the second message should be recorded.
	require.Eventually(t, func() bool {
		logs := reporter.getLogs()
		return len(logs) == 1
	}, testutil.WaitShort, testutil.IntervalFast)

	logs := reporter.getLogs()
	require.Len(t, logs, 1)
	require.Equal(t, "https://example.com/second", logs[0].Logs[0].GetHttpRequest().Url)

	// Stop the forwarder and wait for it to exit so no goroutine leaks.
	cancel()
	<-forwarderDone
}
|
||||
|
||||
func TestServer_CloseStopsForwarder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
|
||||
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
|
||||
|
||||
err := srv.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, srv.Close()) })
|
||||
|
||||
reporter := &fakeReporter{}
|
||||
|
||||
forwarderCtx, forwarderCancel := context.WithCancel(context.Background())
|
||||
forwarderDone := make(chan error, 1)
|
||||
go func() {
|
||||
forwarderDone <- srv.RunForwarder(forwarderCtx, reporter)
|
||||
}()
|
||||
|
||||
// Cancel the forwarder context and verify it stops.
|
||||
forwarderCancel()
|
||||
|
||||
select {
|
||||
case err := <-forwarderDone:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("forwarder did not stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_InvalidProtobuf verifies that a frame with a valid header but
// undecodable protobuf payload is skipped without killing the connection:
// a subsequent valid frame on the same connection still gets through.
func TestServer_InvalidProtobuf(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, srv.Close()) })

	reporter := &fakeReporter{}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	forwarderDone := make(chan error, 1)
	go func() {
		forwarderDone <- srv.RunForwarder(ctx, reporter)
	}()

	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer conn.Close()

	// Send a valid header with garbage protobuf data.
	// The server should log an unmarshal error but continue processing.
	invalidProto := []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF}
	//nolint: gosec // codec.DataLength is always less than the size of the header.
	header := (uint32(codec.TagV1) << codec.DataLength) | uint32(len(invalidProto))
	err = binary.Write(conn, binary.BigEndian, header)
	require.NoError(t, err)
	_, err = conn.Write(invalidProto)
	require.NoError(t, err)

	// Now send a valid message. The server should continue processing.
	req := &agentproto.ReportBoundaryLogsRequest{
		Logs: []*agentproto.BoundaryLog{
			{
				Allowed: true,
				Time:    timestamppb.Now(),
				Resource: &agentproto.BoundaryLog_HttpRequest_{
					HttpRequest: &agentproto.BoundaryLog_HttpRequest{
						Method: "GET",
						Url:    "https://example.com/valid",
					},
				},
			},
		},
	}
	sendMessage(t, conn, req)

	// Only the valid message should reach the reporter.
	require.Eventually(t, func() bool {
		logs := reporter.getLogs()
		return len(logs) == 1
	}, testutil.WaitShort, testutil.IntervalFast)

	// Stop the forwarder and wait for it to exit so no goroutine leaks.
	cancel()
	<-forwarderDone
}
|
||||
|
||||
// TestServer_InvalidHeader verifies the server closes a connection after
// receiving a header that is oversized for TagV1 or carries an unknown tag.
func TestServer_InvalidHeader(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, srv.Close()) })

	reporter := &fakeReporter{}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	forwarderDone := make(chan error, 1)
	go func() {
		forwarderDone <- srv.RunForwarder(ctx, reporter)
	}()

	// sendInvalidHeader sends a header and verifies the server closes the
	// connection.
	sendInvalidHeader := func(t *testing.T, name string, header uint32) {
		t.Helper()

		conn, err := net.Dial("unix", socketPath)
		require.NoError(t, err)
		defer conn.Close()

		err = binary.Write(conn, binary.BigEndian, header)
		require.NoError(t, err, name)

		// The server closes the connection on invalid header, so the next
		// write should fail with a broken pipe error. Eventually is used
		// because the close is observed asynchronously.
		require.Eventually(t, func() bool {
			_, err := conn.Write([]byte{0x00})
			return err != nil
		}, testutil.WaitShort, testutil.IntervalFast, name)
	}

	// TagV1 with length exceeding MaxMessageSizeV1.
	sendInvalidHeader(t, "v1 too large", (uint32(codec.TagV1)<<codec.DataLength)|(codec.MaxMessageSizeV1+1))

	// Unknown tag.
	const bogusTag = 0xFF
	sendInvalidHeader(t, "unknown tag too large", (bogusTag<<codec.DataLength)|(codec.MaxMessageSizeV1+1))

	// Stop the forwarder and wait for it to exit so no goroutine leaks.
	cancel()
	<-forwarderDone
}
|
||||
|
||||
func TestServer_AllowRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
|
||||
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
|
||||
|
||||
err := srv.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, srv.Close()) })
|
||||
|
||||
reporter := &fakeReporter{}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
forwarderDone := make(chan error, 1)
|
||||
go func() {
|
||||
forwarderDone <- srv.RunForwarder(ctx, reporter)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send an allowed request with a matched rule.
|
||||
logTime := timestamppb.Now()
|
||||
req := &agentproto.ReportBoundaryLogsRequest{
|
||||
Logs: []*agentproto.BoundaryLog{
|
||||
{
|
||||
Allowed: true,
|
||||
Time: logTime,
|
||||
Resource: &agentproto.BoundaryLog_HttpRequest_{
|
||||
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
|
||||
Method: "GET",
|
||||
Url: "https://malicious.com/attack",
|
||||
MatchedRule: "*.malicious.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
sendMessage(t, conn, req)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
logs := reporter.getLogs()
|
||||
return len(logs) == 1
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
logs := reporter.getLogs()
|
||||
require.Len(t, logs, 1)
|
||||
require.True(t, logs[0].Logs[0].Allowed)
|
||||
require.Equal(t, logTime.Seconds, logs[0].Logs[0].Time.Seconds)
|
||||
require.Equal(t, logTime.Nanos, logs[0].Logs[0].Time.Nanos)
|
||||
require.Equal(t, "*.malicious.com", logs[0].Logs[0].GetHttpRequest().MatchedRule)
|
||||
|
||||
cancel()
|
||||
<-forwarderDone
|
||||
}
|
||||
+1
-1
@@ -5,7 +5,7 @@ import (
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// checkpoint allows a goroutine to communicate when it is OK to proceed beyond some async condition
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@ import (
|
||||
"golang.org/x/text/transform"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
|
||||
+1
-1
@@ -9,7 +9,7 @@ import (
|
||||
prompb "github.com/prometheus/client_model/go"
|
||||
"tailscale.com/util/clientmetric"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
)
|
||||
|
||||
|
||||
@@ -4538,8 +4538,9 @@ type BoundaryLog_HttpRequest struct {
|
||||
|
||||
Method string `protobuf:"bytes,1,opt,name=method,proto3" json:"method,omitempty"`
|
||||
Url string `protobuf:"bytes,2,opt,name=url,proto3" json:"url,omitempty"`
|
||||
// The rule that resulted in this HTTP request not being allowed.
|
||||
// Only populated when allowed = false.
|
||||
// The rule that resulted in this HTTP request being allowed. Only populated
|
||||
// when allowed = true because boundary denies requests by default and
|
||||
// requires rule(s) that allow requests.
|
||||
MatchedRule string `protobuf:"bytes,3,opt,name=matched_rule,json=matchedRule,proto3" json:"matched_rule,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -466,8 +466,9 @@ message BoundaryLog {
|
||||
message HttpRequest {
|
||||
string method = 1;
|
||||
string url = 2;
|
||||
// The rule that resulted in this HTTP request not being allowed.
|
||||
// Only populated when allowed = false.
|
||||
// The rule that resulted in this HTTP request being allowed. Only populated
|
||||
// when allowed = true because boundary denies requests by default and
|
||||
// requires rule(s) that allow requests.
|
||||
string matched_rule = 3;
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
|
||||
"github.com/coder/quartz"
|
||||
|
||||
@@ -12,8 +12,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/usershell"
|
||||
|
||||
+1
-1
@@ -9,7 +9,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/types/netlogtype"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"tailscale.com/types/ipproto"
|
||||
|
||||
"tailscale.com/types/netlogtype"
|
||||
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
|
||||
+18
-10
@@ -16,26 +16,25 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"cdr.dev/slog/sloggers/slogjson"
|
||||
"cdr.dev/slog/sloggers/slogstackdriver"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"cdr.dev/slog/v3/sloggers/slogjson"
|
||||
"cdr.dev/slog/v3/sloggers/slogstackdriver"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy"
|
||||
"github.com/coder/coder/v2/agent/reaper"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/cli/clilog"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func workspaceAgent() *serpent.Command {
|
||||
@@ -59,6 +58,7 @@ func workspaceAgent() *serpent.Command {
|
||||
devcontainerDiscoveryAutostart bool
|
||||
socketServerEnabled bool
|
||||
socketPath string
|
||||
boundaryLogProxySocketPath string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
@@ -319,8 +319,9 @@ func workspaceAgent() *serpent.Command {
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart),
|
||||
},
|
||||
SocketPath: socketPath,
|
||||
SocketServerEnabled: socketServerEnabled,
|
||||
SocketPath: socketPath,
|
||||
SocketServerEnabled: socketServerEnabled,
|
||||
BoundaryLogProxySocketPath: boundaryLogProxySocketPath,
|
||||
})
|
||||
|
||||
if debugAddress != "" {
|
||||
@@ -494,6 +495,13 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: "Specify the path for the agent socket.",
|
||||
Value: serpent.StringOf(&socketPath),
|
||||
},
|
||||
{
|
||||
Flag: "boundary-log-proxy-socket-path",
|
||||
Default: boundarylogproxy.DefaultSocketPath(),
|
||||
Env: "CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH",
|
||||
Description: "The path for the boundary log proxy server Unix socket. Boundary should write audit logs to this socket.",
|
||||
Value: serpent.StringOf(&boundaryLogProxySocketPath),
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"cdr.dev/slog/sloggers/slogjson"
|
||||
"cdr.dev/slog/sloggers/slogstackdriver"
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"cdr.dev/slog/v3/sloggers/slogjson"
|
||||
"cdr.dev/slog/v3/sloggers/slogstackdriver"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clilog"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuilder(t *testing.T) {
|
||||
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/cli"
|
||||
"github.com/coder/coder/v2/cli/config"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
||||
+11
-1
@@ -138,6 +138,17 @@ func normalizeGoldenFile(t *testing.T, byt []byte) []byte {
|
||||
|
||||
// The home directory changes depending on the test environment.
|
||||
byt = bytes.ReplaceAll(byt, []byte(homeDir), []byte("~"))
|
||||
|
||||
// Normalize the temp directory. os.TempDir() may include a trailing slash
|
||||
// (macOS) or not (Linux/Windows), and the temp directory may be followed by
|
||||
// more filepath elements with an OS-specific separator. We handle all cases
|
||||
// by replacing tempdir+separator first, then tempdir alone.
|
||||
tempDir := filepath.Clean(os.TempDir())
|
||||
byt = bytes.ReplaceAll(byt, []byte(tempDir+string(filepath.Separator)), []byte("/tmp/"))
|
||||
byt = bytes.ReplaceAll(byt, []byte(tempDir), []byte("/tmp"))
|
||||
// Clean up trailing slash when temp dir is used standalone (e.g., "/tmp/)" -> "/tmp)").
|
||||
byt = bytes.ReplaceAll(byt, []byte("/tmp/)"), []byte("/tmp)"))
|
||||
|
||||
for _, r := range []struct {
|
||||
old string
|
||||
new string
|
||||
@@ -145,7 +156,6 @@ func normalizeGoldenFile(t *testing.T, byt []byte) []byte {
|
||||
{"\r\n", "\n"},
|
||||
{`~\.cache\coder`, "~/.cache/coder"},
|
||||
{`C:\Users\RUNNER~1\AppData\Local\Temp`, "/tmp"},
|
||||
{os.TempDir(), "/tmp"},
|
||||
} {
|
||||
byt = bytes.ReplaceAll(byt, []byte(r.old), []byte(r.new))
|
||||
}
|
||||
|
||||
@@ -13,12 +13,11 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
|
||||
+1
-2
@@ -22,10 +22,9 @@ import (
|
||||
"golang.org/x/exp/constraints"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
+1
-2
@@ -1,9 +1,8 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) connectCmd() *serpent.Command {
|
||||
|
||||
+1
-2
@@ -9,11 +9,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/cli"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func TestConnectExists_Running(t *testing.T) {
|
||||
|
||||
+1
-2
@@ -12,13 +12,12 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/cli/cliutil"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
|
||||
+4
-6
@@ -10,23 +10,21 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/quartz"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
|
||||
+1
-2
@@ -13,9 +13,8 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
@@ -15,9 +17,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExpRpty(t *testing.T) {
|
||||
|
||||
@@ -24,9 +24,8 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
@@ -69,6 +68,8 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
|
||||
r.scaletestTaskStatus(),
|
||||
r.scaletestSMTP(),
|
||||
r.scaletestPrebuilds(),
|
||||
r.scaletestBridge(),
|
||||
r.scaletestLLMMock(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/bridge"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) scaletestBridge() *serpent.Command {
|
||||
var (
|
||||
concurrentUsers int64
|
||||
noCleanup bool
|
||||
mode string
|
||||
upstreamURL string
|
||||
provider string
|
||||
requestsPerUser int64
|
||||
useStreamingAPI bool
|
||||
requestPayloadSize int64
|
||||
numMessages int64
|
||||
httpTimeout time.Duration
|
||||
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
cleanupStrategy = newScaletestCleanupStrategy()
|
||||
output = &scaletestOutputFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "bridge",
|
||||
Short: "Generate load on the AI Bridge service.",
|
||||
Long: `Generate load for AI Bridge testing. Supports two modes: 'bridge' mode routes requests through the Coder AI Bridge, 'direct' mode makes requests directly to an upstream URL (useful for baseline comparisons).
|
||||
|
||||
Examples:
|
||||
# Test OpenAI API through bridge
|
||||
coder scaletest bridge --mode bridge --provider openai --concurrent-users 10 --request-count 5 --num-messages 10
|
||||
|
||||
# Test Anthropic API through bridge
|
||||
coder scaletest bridge --mode bridge --provider anthropic --concurrent-users 10 --request-count 5 --num-messages 10
|
||||
|
||||
# Test directly against mock server
|
||||
coder scaletest bridge --mode direct --provider openai --upstream-url http://localhost:8080/v1/chat/completions
|
||||
`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := bridge.NewMetrics(reg)
|
||||
|
||||
logger := inv.Logger
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
defer func() {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
|
||||
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
ctx = notifyCtx
|
||||
|
||||
var userConfig createusers.Config
|
||||
if bridge.RequestMode(mode) == bridge.RequestModeBridge {
|
||||
me, err := requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(me.OrganizationIDs) == 0 {
|
||||
return xerrors.Errorf("admin user must have at least one organization")
|
||||
}
|
||||
userConfig = createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
}
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Bridge mode: creating users and making requests through AI Bridge...")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Direct mode: making requests directly to %s\n", upstreamURL)
|
||||
}
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse output flags: %w", err)
|
||||
}
|
||||
|
||||
config := bridge.Config{
|
||||
Mode: bridge.RequestMode(mode),
|
||||
Metrics: metrics,
|
||||
Provider: provider,
|
||||
RequestCount: int(requestsPerUser),
|
||||
Stream: useStreamingAPI,
|
||||
RequestPayloadSize: int(requestPayloadSize),
|
||||
NumMessages: int(numMessages),
|
||||
HTTPTimeout: httpTimeout,
|
||||
UpstreamURL: upstreamURL,
|
||||
User: userConfig,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
}
|
||||
if err := config.PrepareRequestBody(); err != nil {
|
||||
return xerrors.Errorf("prepare request body: %w", err)
|
||||
}
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
|
||||
for i := range concurrentUsers {
|
||||
id := strconv.Itoa(int(i))
|
||||
name := fmt.Sprintf("bridge-%s", id)
|
||||
var runner harness.Runnable = bridge.NewRunner(client, config)
|
||||
th.AddRun(name, id, runner)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Bridge scaletest configuration:")
|
||||
tw := tabwriter.NewWriter(inv.Stderr, 0, 0, 2, ' ', 0)
|
||||
for _, opt := range inv.Command.Options {
|
||||
if opt.Hidden || opt.ValueSource == serpent.ValueSourceNone {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(tw, " %s:\t%s", opt.Name, opt.Value.String())
|
||||
if opt.ValueSource != serpent.ValueSourceDefault {
|
||||
_, _ = fmt.Fprintf(tw, "\t(from %s)", opt.ValueSource)
|
||||
}
|
||||
_, _ = fmt.Fprintln(tw)
|
||||
}
|
||||
_ = tw.Flush()
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nRunning bridge scaletest...")
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
err = th.Run(testCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
|
||||
}
|
||||
|
||||
// If the command was interrupted, skip stats.
|
||||
if notifyCtx.Err() != nil {
|
||||
return notifyCtx.Err()
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCleanup {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
|
||||
defer cleanupCancel()
|
||||
err = th.Cleanup(cleanupCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cleanup tests: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "concurrent-users",
|
||||
FlagShorthand: "c",
|
||||
Env: "CODER_SCALETEST_BRIDGE_CONCURRENT_USERS",
|
||||
Description: "Required: Number of concurrent users.",
|
||||
Value: serpent.Validate(serpent.Int64Of(&concurrentUsers), func(value *serpent.Int64) error {
|
||||
if value == nil || value.Value() <= 0 {
|
||||
return xerrors.Errorf("--concurrent-users must be greater than 0")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Flag: "mode",
|
||||
Env: "CODER_SCALETEST_BRIDGE_MODE",
|
||||
Default: "direct",
|
||||
Description: "Request mode: 'bridge' (create users and use AI Bridge) or 'direct' (make requests directly to upstream-url).",
|
||||
Value: serpent.EnumOf(&mode, string(bridge.RequestModeBridge), string(bridge.RequestModeDirect)),
|
||||
},
|
||||
{
|
||||
Flag: "upstream-url",
|
||||
Env: "CODER_SCALETEST_BRIDGE_UPSTREAM_URL",
|
||||
Description: "URL to make requests to directly (required in direct mode, e.g., http://localhost:8080/v1/chat/completions).",
|
||||
Value: serpent.StringOf(&upstreamURL),
|
||||
},
|
||||
{
|
||||
Flag: "provider",
|
||||
Env: "CODER_SCALETEST_BRIDGE_PROVIDER",
|
||||
Default: "openai",
|
||||
Description: "API provider to use.",
|
||||
Value: serpent.EnumOf(&provider, "openai", "anthropic"),
|
||||
},
|
||||
{
|
||||
Flag: "request-count",
|
||||
Env: "CODER_SCALETEST_BRIDGE_REQUEST_COUNT",
|
||||
Default: "1",
|
||||
Description: "Number of sequential requests to make per runner.",
|
||||
Value: serpent.Validate(serpent.Int64Of(&requestsPerUser), func(value *serpent.Int64) error {
|
||||
if value == nil || value.Value() <= 0 {
|
||||
return xerrors.Errorf("--request-count must be greater than 0")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
{
|
||||
Flag: "stream",
|
||||
Env: "CODER_SCALETEST_BRIDGE_STREAM",
|
||||
Description: "Enable streaming requests.",
|
||||
Value: serpent.BoolOf(&useStreamingAPI),
|
||||
},
|
||||
{
|
||||
Flag: "request-payload-size",
|
||||
Env: "CODER_SCALETEST_BRIDGE_REQUEST_PAYLOAD_SIZE",
|
||||
Default: "1024",
|
||||
Description: "Size in bytes of the request payload (user message content). If 0, uses default message content.",
|
||||
Value: serpent.Int64Of(&requestPayloadSize),
|
||||
},
|
||||
{
|
||||
Flag: "num-messages",
|
||||
Env: "CODER_SCALETEST_BRIDGE_NUM_MESSAGES",
|
||||
Default: "1",
|
||||
Description: "Number of messages to include in the conversation.",
|
||||
Value: serpent.Int64Of(&numMessages),
|
||||
},
|
||||
{
|
||||
Flag: "no-cleanup",
|
||||
Env: "CODER_SCALETEST_NO_CLEANUP",
|
||||
Description: "Do not clean up resources after the test completes.",
|
||||
Value: serpent.BoolOf(&noCleanup),
|
||||
},
|
||||
{
|
||||
Flag: "http-timeout",
|
||||
Env: "CODER_SCALETEST_BRIDGE_HTTP_TIMEOUT",
|
||||
Default: "30s",
|
||||
Description: "Timeout for individual HTTP requests to the upstream provider.",
|
||||
Value: serpent.DurationOf(&httpTimeout),
|
||||
},
|
||||
}
|
||||
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
cleanupStrategy.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
@@ -10,14 +10,12 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/scaletest/dynamicparameters"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/scaletest/llmmock"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) scaletestLLMMock() *serpent.Command {
|
||||
var (
|
||||
address string
|
||||
artificialLatency time.Duration
|
||||
responsePayloadSize int64
|
||||
|
||||
pprofEnable bool
|
||||
pprofAddress string
|
||||
|
||||
traceEnable bool
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "llm-mock",
|
||||
Short: "Start a mock LLM API server for testing",
|
||||
Long: `Start a mock LLM API server that simulates OpenAI and Anthropic APIs`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx, stop := signal.NotifyContext(inv.Context(), StopSignals...)
|
||||
defer stop()
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo)
|
||||
|
||||
if pprofEnable {
|
||||
closePprof := ServeHandler(ctx, logger, nil, pprofAddress, "pprof")
|
||||
defer closePprof()
|
||||
logger.Info(ctx, "pprof server started", slog.F("address", pprofAddress))
|
||||
}
|
||||
|
||||
config := llmmock.Config{
|
||||
Address: address,
|
||||
Logger: logger,
|
||||
ArtificialLatency: artificialLatency,
|
||||
ResponsePayloadSize: int(responsePayloadSize),
|
||||
PprofEnable: pprofEnable,
|
||||
PprofAddress: pprofAddress,
|
||||
TraceEnable: traceEnable,
|
||||
}
|
||||
srv := new(llmmock.Server)
|
||||
|
||||
if err := srv.Start(ctx, config); err != nil {
|
||||
return xerrors.Errorf("start mock LLM server: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = srv.Stop()
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Mock LLM API server started on %s\n", srv.APIAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " OpenAI endpoint: %s/v1/chat/completions\n", srv.APIAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " Anthropic endpoint: %s/v1/messages\n", srv.APIAddress())
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = []serpent.Option{
|
||||
{
|
||||
Flag: "address",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_ADDRESS",
|
||||
Default: "localhost",
|
||||
Description: "Address to bind the mock LLM API server. Can include a port (e.g., 'localhost:8080' or ':8080'). Uses a random port if no port is specified.",
|
||||
Value: serpent.StringOf(&address),
|
||||
},
|
||||
{
|
||||
Flag: "artificial-latency",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_ARTIFICIAL_LATENCY",
|
||||
Default: "0s",
|
||||
Description: "Artificial latency to add to each response (e.g., 100ms, 1s). Simulates slow upstream processing.",
|
||||
Value: serpent.DurationOf(&artificialLatency),
|
||||
},
|
||||
{
|
||||
Flag: "response-payload-size",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_RESPONSE_PAYLOAD_SIZE",
|
||||
Default: "0",
|
||||
Description: "Size in bytes of the response payload. If 0, uses default context-aware responses.",
|
||||
Value: serpent.Int64Of(&responsePayloadSize),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-enable",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ENABLE",
|
||||
Default: "false",
|
||||
Description: "Serve pprof metrics on the address defined by pprof-address.",
|
||||
Value: serpent.BoolOf(&pprofEnable),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-address",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ADDRESS",
|
||||
Default: "127.0.0.1:6060",
|
||||
Description: "The bind address to serve pprof.",
|
||||
Value: serpent.StringOf(&pprofAddress),
|
||||
},
|
||||
{
|
||||
Flag: "trace-enable",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_TRACE_ENABLE",
|
||||
Default: "false",
|
||||
Description: "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md.",
|
||||
Value: serpent.BoolOf(&traceEnable),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -18,14 +18,12 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
"github.com/coder/coder/v2/scaletest/notifications"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
@@ -13,10 +13,9 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
"github.com/coder/coder/v2/scaletest/prebuilds"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/scaletest/smtpmock"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
@@ -14,15 +14,13 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
"github.com/coder/coder/v2/scaletest/taskstatus"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,8 +7,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFavoriteUnfavorite(t *testing.T) {
|
||||
|
||||
+1
-2
@@ -16,12 +16,11 @@ import (
|
||||
"github.com/pkg/browser"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/cli/sessionstore"
|
||||
"github.com/coder/coder/v2/coderd/userpassword"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
|
||||
+1
-2
@@ -11,14 +11,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/pretty"
|
||||
)
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
|
||||
+270
@@ -0,0 +1,270 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) logs() *serpent.Command {
|
||||
var (
|
||||
buildNumberArg int64
|
||||
followArg bool
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "logs <workspace>",
|
||||
Short: "View logs for a workspace",
|
||||
Long: "View logs for a workspace",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "Build Number",
|
||||
Flag: "build-number",
|
||||
FlagShorthand: "n",
|
||||
Description: "Only show logs for a specific build number. Defaults to 0, which maps to the most recent build (build numbers start at 1). Negative values are treated as offsets—for example, -1 refers to the previous build.",
|
||||
Value: serpent.Int64Of(&buildNumberArg),
|
||||
Default: "0",
|
||||
},
|
||||
{
|
||||
Name: "Follow",
|
||||
Flag: "follow",
|
||||
FlagShorthand: "f",
|
||||
Description: "Follow logs as they are emitted.",
|
||||
Value: serpent.BoolOf(&followArg),
|
||||
Default: "false",
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ws, err := namedWorkspace(inv.Context(), client, inv.Args[0])
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get workspace: %w", err)
|
||||
}
|
||||
bld := ws.LatestBuild
|
||||
buildNumber := buildNumberArg
|
||||
|
||||
// User supplied a negative build number, treat it as an offset from the latest build
|
||||
if buildNumber < 0 {
|
||||
buildNumber = int64(ws.LatestBuild.BuildNumber) + buildNumberArg
|
||||
if buildNumber < 1 {
|
||||
return xerrors.Errorf("invalid build number offset: %d latest build number: %d", buildNumberArg, ws.LatestBuild.BuildNumber)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch specific build if requested
|
||||
if buildNumber > 0 {
|
||||
wb, err := client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(ctx, ws.OwnerName, ws.Name, strconv.FormatInt(buildNumber, 10))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get build %d: %w", buildNumberArg, err)
|
||||
}
|
||||
bld = wb
|
||||
}
|
||||
cliui.Infof(inv.Stdout, "--- Logs for workspace build #%d (ID: %s Template Version: %s) ---", bld.BuildNumber, bld.ID, bld.TemplateVersionName)
|
||||
logs, logsCh, err := workspaceLogs(ctx, client, bld, followArg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, log := range logs {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, log.String())
|
||||
}
|
||||
if followArg {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "--- Streaming logs ---")
|
||||
for log := range logsCh {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, log.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
type logLine struct {
|
||||
ts time.Time
|
||||
Content string
|
||||
}
|
||||
|
||||
func (l *logLine) String() string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString(l.ts.Format(time.RFC3339))
|
||||
_, _ = sb.WriteString(l.Content)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// workspaceLogs fetches logs for the given workspace build. If follow is true,
|
||||
// the returned channel will stream new logs as they are emitted. Otherwise,
|
||||
// the channel will be closed immediately.
|
||||
// nolint: revive // control flag is appropriate here
|
||||
func workspaceLogs(ctx context.Context, client *codersdk.Client, wb codersdk.WorkspaceBuild, follow bool) ([]logLine, <-chan logLine, error) {
|
||||
logs := make([]logLine, 0)
|
||||
logsCh := make(chan logLine)
|
||||
followCh := make(chan logLine)
|
||||
|
||||
var fetchGroup, followGroup errgroup.Group
|
||||
|
||||
buildLogsAfterCh := make(chan int64)
|
||||
fetchGroup.Go(func() error {
|
||||
var afterID int64
|
||||
defer func() {
|
||||
if !follow {
|
||||
return
|
||||
}
|
||||
buildLogsAfterCh <- afterID
|
||||
}()
|
||||
buildLogsC, closer, err := client.WorkspaceBuildLogsAfter(ctx, wb.ID, 0)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get build logs: %w", err)
|
||||
}
|
||||
defer closer.Close()
|
||||
for log := range buildLogsC {
|
||||
afterID = log.ID
|
||||
logsCh <- logLine{
|
||||
ts: log.CreatedAt,
|
||||
Content: buildLogToString(log),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if follow {
|
||||
followGroup.Go(func() error {
|
||||
afterID := <-buildLogsAfterCh
|
||||
buildLogsC, closer, err := client.WorkspaceBuildLogsAfter(ctx, wb.ID, afterID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to follow build logs: %w", err)
|
||||
}
|
||||
defer closer.Close()
|
||||
for log := range buildLogsC {
|
||||
followCh <- logLine{
|
||||
ts: log.CreatedAt,
|
||||
Content: buildLogToString(log),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
for _, res := range wb.Resources {
|
||||
for _, agt := range res.Agents {
|
||||
logSrcNames := make(map[uuid.UUID]string)
|
||||
for _, src := range agt.LogSources {
|
||||
logSrcNames[src.ID] = src.DisplayName
|
||||
}
|
||||
agentLogsAfterCh := make(chan int64)
|
||||
var afterID int64
|
||||
fetchGroup.Go(func() error {
|
||||
defer func() {
|
||||
if !follow {
|
||||
return
|
||||
}
|
||||
agentLogsAfterCh <- afterID
|
||||
}()
|
||||
agentLogsCh, closer, err := client.WorkspaceAgentLogsAfter(ctx, agt.ID, 0, false)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get agent logs: %w", err)
|
||||
}
|
||||
defer closer.Close()
|
||||
for logChunk := range agentLogsCh {
|
||||
for _, log := range logChunk {
|
||||
afterID = log.ID
|
||||
logsCh <- logLine{
|
||||
ts: log.CreatedAt,
|
||||
Content: workspaceAgentLogToString(log, agt.Name, logSrcNames[log.SourceID]),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if follow {
|
||||
followGroup.Go(func() error {
|
||||
afterID := <-agentLogsAfterCh
|
||||
agentLogsCh, closer, err := client.WorkspaceAgentLogsAfter(ctx, agt.ID, afterID, true)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to follow agent logs: %w", err)
|
||||
}
|
||||
defer closer.Close()
|
||||
for logChunk := range agentLogsCh {
|
||||
for _, log := range logChunk {
|
||||
followCh <- logLine{
|
||||
ts: log.CreatedAt,
|
||||
Content: workspaceAgentLogToString(log, agt.Name, logSrcNames[log.SourceID]),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logsDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(logsDone)
|
||||
for log := range logsCh {
|
||||
logs = append(logs, log)
|
||||
}
|
||||
}()
|
||||
|
||||
err := fetchGroup.Wait()
|
||||
close(logsCh)
|
||||
<-logsDone
|
||||
|
||||
slices.SortFunc(logs, func(a, b logLine) int {
|
||||
return a.ts.Compare(b.ts)
|
||||
})
|
||||
|
||||
if follow {
|
||||
go func() {
|
||||
_ = followGroup.Wait()
|
||||
close(followCh)
|
||||
}()
|
||||
} else {
|
||||
close(followCh)
|
||||
}
|
||||
|
||||
return logs, followCh, err
|
||||
}
|
||||
|
||||
func buildLogToString(log codersdk.ProvisionerJobLog) string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString(" [")
|
||||
_, _ = sb.WriteString(string(log.Level))
|
||||
_, _ = sb.WriteString("] [")
|
||||
_, _ = sb.WriteString("provisioner|")
|
||||
_, _ = sb.WriteString(log.Stage)
|
||||
_, _ = sb.WriteString("] ")
|
||||
_, _ = sb.WriteString(log.Output)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func workspaceAgentLogToString(log codersdk.WorkspaceAgentLog, agtName, srcName string) string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString(" [")
|
||||
_, _ = sb.WriteString(string(log.Level))
|
||||
_, _ = sb.WriteString("] [")
|
||||
_, _ = sb.WriteString("agent.")
|
||||
_, _ = sb.WriteString(agtName)
|
||||
_, _ = sb.WriteString("|")
|
||||
_, _ = sb.WriteString(srcName)
|
||||
_, _ = sb.WriteString("] ")
|
||||
_, _ = sb.WriteString(log.Output)
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestLogsCmd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
testWorkspace := func(t testing.TB, db database.Store, ownerID, orgID uuid.UUID) dbfake.WorkspaceResponse {
|
||||
wb := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: memberUser.ID,
|
||||
OrganizationID: owner.OrganizationID,
|
||||
}).WithAgent().Do()
|
||||
_ = dbgen.ProvisionerJobLog(t, db, database.ProvisionerJobLog{
|
||||
JobID: wb.Build.JobID,
|
||||
Output: "test provisioner log for build " + wb.Build.ID.String(),
|
||||
})
|
||||
for _, agt := range wb.Agents {
|
||||
_ = dbgen.WorkspaceAgentLog(t, db, database.WorkspaceAgentLog{
|
||||
AgentID: agt.ID,
|
||||
Output: "test agent log for agent " + agt.ID.String(),
|
||||
})
|
||||
}
|
||||
return wb
|
||||
}
|
||||
|
||||
assertLogOutput := func(t testing.TB, wb dbfake.WorkspaceResponse, output string) {
|
||||
t.Helper()
|
||||
require.Contains(t, output, "test provisioner log for build "+wb.Build.ID.String())
|
||||
for _, agt := range wb.Agents {
|
||||
require.Contains(t, output, "test agent log for agent "+agt.ID.String())
|
||||
}
|
||||
}
|
||||
|
||||
assertAntagonist := func(t testing.TB, wb dbfake.WorkspaceResponse, output string) {
|
||||
t.Helper()
|
||||
require.NotContains(t, output, "test provisioner log for build "+wb.Build.ID.String())
|
||||
for _, agt := range wb.Agents {
|
||||
require.NotContains(t, output, "test agent log for agent "+agt.ID.String())
|
||||
}
|
||||
}
|
||||
|
||||
wb1 := testWorkspace(t, db, memberUser.ID, owner.OrganizationID)
|
||||
wb2 := testWorkspace(t, db, owner.UserID, owner.OrganizationID)
|
||||
|
||||
t.Run("workspace not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, root := clitest.New(t, "logs", "doesnotexist")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
var stdout strings.Builder
|
||||
inv.Stdout = &stdout
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "Resource not found or you do not have access to this resource")
|
||||
})
|
||||
|
||||
// Note: not testing with --follow as it is inherently racy.
|
||||
t.Run("current build", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, root := clitest.New(t, "logs", wb1.Workspace.Name)
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
var stdout strings.Builder
|
||||
inv.Stdout = &stdout
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err, "failed to fetch logs for current build")
|
||||
assertLogOutput(t, wb1, stdout.String())
|
||||
assertAntagonist(t, wb2, stdout.String())
|
||||
})
|
||||
|
||||
t.Run("specific build", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, root := clitest.New(t, "logs", wb1.Workspace.Name, "-n", fmt.Sprintf("%d", wb1.Build.BuildNumber))
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
var stdout strings.Builder
|
||||
inv.Stdout = &stdout
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err, "failed to fetch logs for specific build")
|
||||
assertLogOutput(t, wb1, stdout.String())
|
||||
assertAntagonist(t, wb2, stdout.String())
|
||||
})
|
||||
|
||||
t.Run("build out of range", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, root := clitest.New(t, "logs", wb1.Workspace.Name, "-n", "-9999")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
var stdout strings.Builder
|
||||
inv.Stdout = &stdout
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "invalid build number offset")
|
||||
})
|
||||
}
|
||||
@@ -5,9 +5,8 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) notifications() *serpent.Command {
|
||||
|
||||
+9
-3
@@ -169,8 +169,8 @@ func (r *RootCmd) openVSCode() *serpent.Command {
|
||||
// Note that this is irrelevant for devcontainer sub agents, as
|
||||
// they always have a directory set.
|
||||
if workspaceAgent.Directory != "" {
|
||||
workspace, workspaceAgent, err = waitForAgentCond(ctx, client, workspace, workspaceAgent, func(_ codersdk.WorkspaceAgent) bool {
|
||||
return workspaceAgent.LifecycleState != codersdk.WorkspaceAgentLifecycleCreated
|
||||
workspace, workspaceAgent, err = waitForAgentCond(ctx, client, workspace, workspaceAgent, func(wa codersdk.WorkspaceAgent) bool {
|
||||
return wa.LifecycleState != codersdk.WorkspaceAgentLifecycleCreated
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("wait for agent: %w", err)
|
||||
@@ -183,7 +183,13 @@ func (r *RootCmd) openVSCode() *serpent.Command {
|
||||
directory = inv.Args[1]
|
||||
}
|
||||
|
||||
directory, err = resolveAgentAbsPath(workspaceAgent.ExpandedDirectory, directory, workspaceAgent.OperatingSystem, insideThisWorkspace)
|
||||
// If we're opening into a dev container, we should use the directory of the dev container.
|
||||
workingDirectory := workspaceAgent.ExpandedDirectory
|
||||
if workingDirectory == "" && devcontainer.Agent != nil {
|
||||
workingDirectory = devcontainer.Agent.Directory
|
||||
}
|
||||
|
||||
directory, err = resolveAgentAbsPath(workingDirectory, directory, workspaceAgent.OperatingSystem, insideThisWorkspace)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve agent path: %w", err)
|
||||
}
|
||||
|
||||
@@ -65,6 +65,22 @@ func (r *RootCmd) organizationSettings(orgContext *OrganizationContext) *serpent
|
||||
return cli.OrganizationIDPSyncSettings(ctx)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "workspace-sharing",
|
||||
Aliases: []string{"workspacesharing"},
|
||||
Short: "Workspace sharing settings for the organization.",
|
||||
Patch: func(ctx context.Context, cli *codersdk.Client, org uuid.UUID, input json.RawMessage) (any, error) {
|
||||
var req codersdk.WorkspaceSharingSettings
|
||||
err := json.Unmarshal(input, &req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshalling workspace sharing settings: %w", err)
|
||||
}
|
||||
return cli.PatchWorkspaceSharingSettings(ctx, org.String(), req)
|
||||
},
|
||||
Fetch: func(ctx context.Context, cli *codersdk.Client, org uuid.UUID) (any, error) {
|
||||
return cli.WorkspaceSharingSettings(ctx, org.String())
|
||||
},
|
||||
},
|
||||
}
|
||||
cmd := &serpent.Command{
|
||||
Use: "settings",
|
||||
|
||||
+5
-9
@@ -10,25 +10,21 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/cli/cliutil"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
type pingSummary struct {
|
||||
|
||||
+2
-3
@@ -15,9 +15,8 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
||||
+2
-3
@@ -5,11 +5,10 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) publickey() *serpent.Command {
|
||||
|
||||
+2
-3
@@ -5,11 +5,10 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) rename() *serpent.Command {
|
||||
|
||||
@@ -7,16 +7,15 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/awsiamrds"
|
||||
"github.com/coder/coder/v2/coderd/userpassword"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/userpassword"
|
||||
)
|
||||
|
||||
func (*RootCmd) resetPassword() *serpent.Command {
|
||||
|
||||
+20
-4
@@ -29,10 +29,6 @@ import (
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/cli/config"
|
||||
@@ -41,6 +37,8 @@ import (
|
||||
"github.com/coder/coder/v2/cli/telemetry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -117,6 +115,7 @@ func (r *RootCmd) CoreSubcommands() []*serpent.Command {
|
||||
r.deleteWorkspace(),
|
||||
r.favorite(),
|
||||
r.list(),
|
||||
r.logs(),
|
||||
r.open(),
|
||||
r.ping(),
|
||||
r.rename(),
|
||||
@@ -685,6 +684,7 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
|
||||
func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv *serpent.Invocation) (*http.Client, error) {
|
||||
transport := http.DefaultTransport
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
transport = wrapTransportWithUserAgentHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
// Create a new client without any wrapped transport
|
||||
@@ -1498,6 +1498,22 @@ func wrapTransportWithTelemetryHeader(transport http.RoundTripper, inv *serpent.
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithUserAgentHeader sets a User-Agent header for all CLI requests
|
||||
// that includes the CLI version, os/arch, and the specific command being run.
|
||||
func wrapTransportWithUserAgentHeader(transport http.RoundTripper, inv *serpent.Invocation) http.RoundTripper {
|
||||
var (
|
||||
userAgent string
|
||||
once sync.Once
|
||||
)
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
once.Do(func() {
|
||||
userAgent = fmt.Sprintf("coder-cli/%s (%s/%s; %s)", buildinfo.Version(), runtime.GOOS, runtime.GOARCH, inv.Command.FullName())
|
||||
})
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
return transport.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
|
||||
type roundTripper func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
||||
@@ -380,3 +380,59 @@ func agentClientCommand(clientRef **agentsdk.Client) *serpent.Command {
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func TestWrapTransportWithUserAgentHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
cmdArgs []string
|
||||
cmdEnv map[string]string
|
||||
expectedUserAgentHeader string
|
||||
}{
|
||||
{
|
||||
name: "top-level command",
|
||||
cmdArgs: []string{"login"},
|
||||
expectedUserAgentHeader: fmt.Sprintf("coder-cli/%s (%s/%s; coder login)", buildinfo.Version(), runtime.GOOS, runtime.GOARCH),
|
||||
},
|
||||
{
|
||||
name: "nested commands",
|
||||
cmdArgs: []string{"templates", "list"},
|
||||
expectedUserAgentHeader: fmt.Sprintf("coder-cli/%s (%s/%s; coder templates list)", buildinfo.Version(), runtime.GOOS, runtime.GOARCH),
|
||||
},
|
||||
{
|
||||
name: "does not include positional args, flags, or env",
|
||||
cmdArgs: []string{"templates", "push", "my-template", "-d", "/path/to/template", "--yes", "--var", "myvar=myvalue"},
|
||||
cmdEnv: map[string]string{"SECRET_KEY": "secret_value"},
|
||||
expectedUserAgentHeader: fmt.Sprintf("coder-cli/%s (%s/%s; coder templates push)", buildinfo.Version(), runtime.GOOS, runtime.GOARCH),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan string, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case ch <- r.Header.Get("User-Agent"):
|
||||
default: // already sent
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
args := append([]string{}, tc.cmdArgs...)
|
||||
inv, _ := clitest.New(t, args...)
|
||||
inv.Environ.Set("CODER_URL", srv.URL)
|
||||
for k, v := range tc.cmdEnv {
|
||||
inv.Environ.Set(k, v)
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_ = inv.WithContext(ctx).Run() // Ignore error as we only care about headers.
|
||||
|
||||
actual := testutil.RequireReceive(ctx, t, ch)
|
||||
require.Equal(t, tc.expectedUserAgentHeader, actual, "User-Agent should match expected format exactly")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+58
-26
@@ -54,15 +54,8 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/coderd/pproflabel"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/serpent"
|
||||
"github.com/coder/wgtunnel/tunnelsdk"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/cli/clilog"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
@@ -86,6 +79,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/reports"
|
||||
"github.com/coder/coder/v2/coderd/oauthpki"
|
||||
"github.com/coder/coder/v2/coderd/pproflabel"
|
||||
"github.com/coder/coder/v2/coderd/prometheusmetrics"
|
||||
"github.com/coder/coder/v2/coderd/prometheusmetrics/insights"
|
||||
"github.com/coder/coder/v2/coderd/promoauth"
|
||||
@@ -111,6 +105,11 @@ import (
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/serpent"
|
||||
"github.com/coder/wgtunnel/tunnelsdk"
|
||||
)
|
||||
|
||||
func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.DeploymentValues) (*coderd.OIDCConfig, error) {
|
||||
@@ -748,7 +747,16 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
// "bare" read on this channel.
|
||||
var pubsubWatchdogTimeout <-chan struct{}
|
||||
|
||||
sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
|
||||
maxOpenConns := int(vals.PostgresConnMaxOpen.Value())
|
||||
maxIdleConns, err := codersdk.ComputeMaxIdleConns(maxOpenConns, vals.PostgresConnMaxIdle.Value())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("compute max idle connections: %w", err)
|
||||
}
|
||||
logger.Debug(ctx, "creating database connection pool", slog.F("max_open_conns", maxOpenConns), slog.F("max_idle_conns", maxIdleConns))
|
||||
sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver,
|
||||
WithMaxOpenConns(maxOpenConns),
|
||||
WithMaxIdleConns(maxIdleConns),
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to postgres: %w", err)
|
||||
}
|
||||
@@ -2325,6 +2333,29 @@ func IsLocalhost(host string) bool {
|
||||
return host == "localhost" || host == "127.0.0.1" || host == "::1"
|
||||
}
|
||||
|
||||
// PostgresConnectOptions contains options for connecting to Postgres.
|
||||
type PostgresConnectOptions struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
}
|
||||
|
||||
// PostgresConnectOption is a functional option for ConnectToPostgres.
|
||||
type PostgresConnectOption func(*PostgresConnectOptions)
|
||||
|
||||
// WithMaxOpenConns sets the maximum number of open connections to the database.
|
||||
func WithMaxOpenConns(n int) PostgresConnectOption {
|
||||
return func(o *PostgresConnectOptions) {
|
||||
o.MaxOpenConns = n
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxIdleConns sets the maximum number of idle connections in the pool.
|
||||
func WithMaxIdleConns(n int) PostgresConnectOption {
|
||||
return func(o *PostgresConnectOptions) {
|
||||
o.MaxIdleConns = n
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectToPostgres takes in the migration command to run on the database once
|
||||
// it connects. To avoid running migrations, pass in `nil` or a no-op function.
|
||||
// Regardless of the passed in migration function, if the database is not fully
|
||||
@@ -2332,7 +2363,15 @@ func IsLocalhost(host string) bool {
|
||||
// future or past migration version.
|
||||
//
|
||||
// If no error is returned, the database is fully migrated and up to date.
|
||||
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) {
|
||||
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error, opts ...PostgresConnectOption) (*sql.DB, error) {
|
||||
// Apply defaults.
|
||||
options := PostgresConnectOptions{
|
||||
MaxOpenConns: 10,
|
||||
MaxIdleConns: 3,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
logger.Debug(ctx, "connecting to postgresql")
|
||||
|
||||
var err error
|
||||
@@ -2415,19 +2454,12 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
|
||||
// cannot accept new connections, so we try to limit that here.
|
||||
// Requests will wait for a new connection instead of a hard error
|
||||
// if a limit is set.
|
||||
sqlDB.SetMaxOpenConns(10)
|
||||
// Allow a max of 3 idle connections at a time. Lower values end up
|
||||
// creating a lot of connection churn. Since each connection uses about
|
||||
// 10MB of memory, we're allocating 30MB to Postgres connections per
|
||||
// replica, but is better than causing Postgres to spawn a thread 15-20
|
||||
// times/sec. PGBouncer's transaction pooling is not the greatest so
|
||||
// it's not optimal for us to deploy.
|
||||
//
|
||||
// This was set to 10 before we started doing HA deployments, but 3 was
|
||||
// later determined to be a better middle ground as to not use up all
|
||||
// of PGs default connection limit while simultaneously avoiding a lot
|
||||
// of connection churn.
|
||||
sqlDB.SetMaxIdleConns(3)
|
||||
sqlDB.SetMaxOpenConns(options.MaxOpenConns)
|
||||
// Limit idle connections to reduce connection churn while keeping some
|
||||
// connections ready for reuse. When a connection is returned to the pool
|
||||
// but the idle pool is full, it's closed immediately - which can cause
|
||||
// connection establishment overhead when load fluctuates.
|
||||
sqlDB.SetMaxIdleConns(options.MaxIdleConns)
|
||||
|
||||
dbNeedsClosing = false
|
||||
return sqlDB, nil
|
||||
@@ -2831,7 +2863,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os
|
||||
return inv.SignalNotifyContext(ctx, sig...)
|
||||
}
|
||||
|
||||
func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
|
||||
func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string, opts ...PostgresConnectOption) (*sql.DB, string, error) {
|
||||
dbURL, err := escapePostgresURLUserInfo(postgresURL)
|
||||
if err != nil {
|
||||
return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
|
||||
@@ -2844,7 +2876,7 @@ func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresUR
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up)
|
||||
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up, opts...)
|
||||
if err != nil {
|
||||
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/awsiamrds"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user