Compare commits

..

2 Commits

Author SHA1 Message Date
blink-so[bot] 3ca14f3f12 fix: update TokenSearch design to match comp
- Filter button now inside search box with border separator
- Token badges show key:value format (e.g. group:devops)
- Dropdown appears full-width below search box
- Cleaner styling matching dark theme comp
- Only show dropdown when typing filter values
2026-02-05 13:49:57 +00:00
blink-so[bot] fff27c28f5 feat: add TokenSearch component prototype
Adds a new token-based search component with:
- Dismissible filter badges
- Dropdown filter selector
- Typeahead autocomplete for filter values
- Full keyboard navigation support
- Storybook stories for Members and Workspaces patterns
2026-02-05 13:34:49 +00:00
306 changed files with 8997 additions and 26409 deletions
+4 -4
View File
@@ -1,13 +1,13 @@
apiVersion: cert-manager.io/v1
kind: Certificate
metadata:
name: ${DEPLOY_NAME}-tls
name: pr${PR_NUMBER}-tls
namespace: pr-deployment-certs
spec:
secretName: ${DEPLOY_NAME}-tls
secretName: pr${PR_NUMBER}-tls
issuerRef:
name: letsencrypt
kind: ClusterIssuer
dnsNames:
- "${DEPLOY_HOSTNAME}"
- "*.${DEPLOY_HOSTNAME}"
- "${PR_HOSTNAME}"
- "*.${PR_HOSTNAME}"
+9 -9
View File
@@ -1,15 +1,15 @@
apiVersion: v1
kind: ServiceAccount
metadata:
name: coder-workspace-${DEPLOY_NAME}
namespace: ${DEPLOY_NAME}
name: coder-workspace-pr${PR_NUMBER}
namespace: pr${PR_NUMBER}
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: coder-workspace-${DEPLOY_NAME}
namespace: ${DEPLOY_NAME}
name: coder-workspace-pr${PR_NUMBER}
namespace: pr${PR_NUMBER}
rules:
- apiGroups: ["*"]
resources: ["*"]
@@ -19,13 +19,13 @@ rules:
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: coder-workspace-${DEPLOY_NAME}
namespace: ${DEPLOY_NAME}
name: coder-workspace-pr${PR_NUMBER}
namespace: pr${PR_NUMBER}
subjects:
- kind: ServiceAccount
name: coder-workspace-${DEPLOY_NAME}
namespace: ${DEPLOY_NAME}
name: coder-workspace-pr${PR_NUMBER}
namespace: pr${PR_NUMBER}
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: Role
name: coder-workspace-${DEPLOY_NAME}
name: coder-workspace-pr${PR_NUMBER}
+19 -52
View File
@@ -12,23 +12,9 @@ terraform {
provider "coder" {
}
variable "use_kubeconfig" {
type = bool
description = <<-EOF
Use host kubeconfig? (true/false)
Set this to false if the Coder host is itself running as a Pod on the same
Kubernetes cluster as you are deploying workspaces to.
Set this to true if the Coder host is running outside the Kubernetes cluster
for workspaces. A valid "~/.kube/config" must be present on the Coder host.
EOF
default = false
}
variable "namespace" {
type = string
description = "The Kubernetes namespace to create workspaces in (must exist prior to creating workspaces). If the Coder host is itself running as a Pod on the same Kubernetes cluster as you are deploying workspaces to, set this to the same namespace."
description = "The Kubernetes namespace to create workspaces in (must exist prior to creating workspaces)"
}
data "coder_parameter" "cpu" {
@@ -96,8 +82,7 @@ data "coder_parameter" "home_disk_size" {
}
provider "kubernetes" {
# Authenticate via ~/.kube/config or a Coder-specific ServiceAccount, depending on admin preferences
config_path = var.use_kubeconfig == true ? "~/.kube/config" : null
config_path = null
}
data "coder_workspace" "me" {}
@@ -109,12 +94,10 @@ resource "coder_agent" "main" {
startup_script = <<-EOT
set -e
# Install the latest code-server.
# Append "--version x.x.x" to install a specific version of code-server.
# install and start code-server
curl -fsSL https://code-server.dev/install.sh | sh -s -- --method=standalone --prefix=/tmp/code-server
# Start code-server in the background.
/tmp/code-server/bin/code-server --auth none --port 13337 >/tmp/code-server.log 2>&1 &
EOT
# The following metadata blocks are optional. They are used to display
@@ -191,13 +174,13 @@ resource "coder_app" "code-server" {
}
}
resource "kubernetes_persistent_volume_claim_v1" "home" {
resource "kubernetes_persistent_volume_claim" "home" {
metadata {
name = "coder-${data.coder_workspace.me.id}-home"
name = "coder-${lower(data.coder_workspace_owner.me.name)}-${lower(data.coder_workspace.me.name)}-home"
namespace = var.namespace
labels = {
"app.kubernetes.io/name" = "coder-pvc"
"app.kubernetes.io/instance" = "coder-pvc-${data.coder_workspace.me.id}"
"app.kubernetes.io/instance" = "coder-pvc-${lower(data.coder_workspace_owner.me.name)}-${lower(data.coder_workspace.me.name)}"
"app.kubernetes.io/part-of" = "coder"
//Coder-specific labels.
"com.coder.resource" = "true"
@@ -221,18 +204,18 @@ resource "kubernetes_persistent_volume_claim_v1" "home" {
}
}
resource "kubernetes_deployment_v1" "main" {
resource "kubernetes_deployment" "main" {
count = data.coder_workspace.me.start_count
depends_on = [
kubernetes_persistent_volume_claim_v1.home
kubernetes_persistent_volume_claim.home
]
wait_for_rollout = false
metadata {
name = "coder-${data.coder_workspace.me.id}"
name = "coder-${lower(data.coder_workspace_owner.me.name)}-${lower(data.coder_workspace.me.name)}"
namespace = var.namespace
labels = {
"app.kubernetes.io/name" = "coder-workspace"
"app.kubernetes.io/instance" = "coder-workspace-${data.coder_workspace.me.id}"
"app.kubernetes.io/instance" = "coder-workspace-${lower(data.coder_workspace_owner.me.name)}-${lower(data.coder_workspace.me.name)}"
"app.kubernetes.io/part-of" = "coder"
"com.coder.resource" = "true"
"com.coder.workspace.id" = data.coder_workspace.me.id
@@ -249,14 +232,7 @@ resource "kubernetes_deployment_v1" "main" {
replicas = 1
selector {
match_labels = {
"app.kubernetes.io/name" = "coder-workspace"
"app.kubernetes.io/instance" = "coder-workspace-${data.coder_workspace.me.id}"
"app.kubernetes.io/part-of" = "coder"
"com.coder.resource" = "true"
"com.coder.workspace.id" = data.coder_workspace.me.id
"com.coder.workspace.name" = data.coder_workspace.me.name
"com.coder.user.id" = data.coder_workspace_owner.me.id
"com.coder.user.username" = data.coder_workspace_owner.me.name
"app.kubernetes.io/name" = "coder-workspace"
}
}
strategy {
@@ -266,29 +242,20 @@ resource "kubernetes_deployment_v1" "main" {
template {
metadata {
labels = {
"app.kubernetes.io/name" = "coder-workspace"
"app.kubernetes.io/instance" = "coder-workspace-${data.coder_workspace.me.id}"
"app.kubernetes.io/part-of" = "coder"
"com.coder.resource" = "true"
"com.coder.workspace.id" = data.coder_workspace.me.id
"com.coder.workspace.name" = data.coder_workspace.me.name
"com.coder.user.id" = data.coder_workspace_owner.me.id
"com.coder.user.username" = data.coder_workspace_owner.me.name
"app.kubernetes.io/name" = "coder-workspace"
}
}
spec {
hostname = lower(data.coder_workspace.me.name)
security_context {
run_as_user = 1000
fs_group = 1000
run_as_non_root = true
run_as_user = 1000
fs_group = 1000
}
service_account_name = "coder-workspace-${var.namespace}"
container {
name = "dev"
image = "codercom/enterprise-base:ubuntu"
image_pull_policy = "IfNotPresent"
image = "bencdr/devops-tools"
image_pull_policy = "Always"
command = ["sh", "-c", coder_agent.main.init_script]
security_context {
run_as_user = "1000"
@@ -317,7 +284,7 @@ resource "kubernetes_deployment_v1" "main" {
volume {
name = "home"
persistent_volume_claim {
claim_name = kubernetes_persistent_volume_claim_v1.home.metadata.0.name
claim_name = kubernetes_persistent_volume_claim.home.metadata.0.name
read_only = false
}
}
+7 -9
View File
@@ -1,26 +1,24 @@
coder:
podAnnotations:
deploy-sha: "${GITHUB_SHA}"
image:
repo: "${REPO}"
tag: "${DEPLOY_NAME}"
tag: "pr${PR_NUMBER}"
pullPolicy: Always
service:
type: ClusterIP
ingress:
enable: true
className: traefik
host: "${DEPLOY_HOSTNAME}"
wildcardHost: "*.${DEPLOY_HOSTNAME}"
host: "${PR_HOSTNAME}"
wildcardHost: "*.${PR_HOSTNAME}"
tls:
enable: true
secretName: "${DEPLOY_NAME}-tls"
wildcardSecretName: "${DEPLOY_NAME}-tls"
secretName: "pr${PR_NUMBER}-tls"
wildcardSecretName: "pr${PR_NUMBER}-tls"
env:
- name: "CODER_ACCESS_URL"
value: "https://${DEPLOY_HOSTNAME}"
value: "https://${PR_HOSTNAME}"
- name: "CODER_WILDCARD_ACCESS_URL"
value: "*.${DEPLOY_HOSTNAME}"
value: "*.${PR_HOSTNAME}"
- name: "CODER_EXPERIMENTS"
value: "${EXPERIMENTS}"
- name: CODER_PG_CONNECTION_URL
-408
View File
@@ -1,408 +0,0 @@
name: Deploy Branch
on:
push:
workflow_dispatch:
permissions:
contents: read
concurrency:
group: deploy-${{ github.ref_name }}
cancel-in-progress: true
jobs:
build:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
permissions:
packages: write
env:
CODER_IMAGE_TAG: "ghcr.io/coder/coder-preview:${{ github.ref_name }}"
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup Node
uses: ./.github/actions/setup-node
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Setup sqlc
uses: ./.github/actions/setup-sqlc
- name: GHCR Login
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push Docker image
run: |
set -euo pipefail
go mod download
make gen/mark-fresh
export DOCKER_IMAGE_NO_PREREQUISITES=true
version="$(./scripts/version.sh)"
CODER_IMAGE_BUILD_BASE_TAG="$(CODER_IMAGE_BASE=coder-base ./scripts/image_tag.sh --version "$version")"
export CODER_IMAGE_BUILD_BASE_TAG
make -j build/coder_linux_amd64
./scripts/build_docker.sh \
--arch amd64 \
--target "${CODER_IMAGE_TAG}" \
--version "$version" \
--push \
build/coder_linux_amd64
deploy:
needs: build
runs-on: ubuntu-latest
env:
BRANCH_NAME: ${{ github.ref_name }}
DEPLOY_NAME: "${{ github.ref_name }}"
TEST_DOMAIN_SUFFIX: "${{ startsWith(secrets.PR_DEPLOYMENTS_DOMAIN, 'test.') && secrets.PR_DEPLOYMENTS_DOMAIN || format('test.{0}', secrets.PR_DEPLOYMENTS_DOMAIN) }}"
BRANCH_HOSTNAME: "${{ github.ref_name }}.${{ startsWith(secrets.PR_DEPLOYMENTS_DOMAIN, 'test.') && secrets.PR_DEPLOYMENTS_DOMAIN || format('test.{0}', secrets.PR_DEPLOYMENTS_DOMAIN) }}"
CODER_IMAGE_TAG: "ghcr.io/coder/coder-preview:${{ github.ref_name }}"
REPO: ghcr.io/coder/coder-preview
EXPERIMENTS: "*,oauth2,mcp-server-http"
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Set up kubeconfig
run: |
set -euo pipefail
mkdir -p ~/.kube
echo "${{ secrets.PR_DEPLOYMENTS_KUBECONFIG_BASE64 }}" | base64 --decode > ~/.kube/config
chmod 600 ~/.kube/config
- name: Verify cluster authentication
run: |
set -euo pipefail
kubectl auth can-i get namespaces > /dev/null
- name: Check if deployment exists
id: check
run: |
set -euo pipefail
set +e
helm_status_output="$(helm status "${DEPLOY_NAME}" --namespace "${DEPLOY_NAME}" 2>&1)"
helm_status_code=$?
set -e
if [ "$helm_status_code" -eq 0 ]; then
echo "new=false" >> "$GITHUB_OUTPUT"
elif echo "$helm_status_output" | grep -qi "release: not found"; then
echo "new=true" >> "$GITHUB_OUTPUT"
else
echo "$helm_status_output"
exit "$helm_status_code"
fi
# ---- Every push: ensure routing + TLS ----
- name: Ensure DNS records
run: |
set -euo pipefail
api_base_url="https://api.cloudflare.com/client/v4/zones/${{ secrets.PR_DEPLOYMENTS_ZONE_ID }}/dns_records"
base_name="${BRANCH_HOSTNAME}"
base_target="${TEST_DOMAIN_SUFFIX}"
wildcard_name="*.${BRANCH_HOSTNAME}"
ensure_cname_record() {
local record_name="$1"
local record_content="$2"
echo "Ensuring CNAME ${record_name} -> ${record_content}."
set +e
lookup_raw_response="$(
curl -sS -G "${api_base_url}" \
-H "Authorization: Bearer ${{ secrets.PR_DEPLOYMENTS_CLOUDFLARE_API_TOKEN }}" \
-H "Content-Type:application/json" \
--data-urlencode "name=${record_name}" \
--data-urlencode "per_page=100" \
-w '\n%{http_code}'
)"
lookup_exit_code=$?
set -e
if [ "$lookup_exit_code" -eq 0 ]; then
lookup_response="${lookup_raw_response%$'\n'*}"
lookup_http_code="${lookup_raw_response##*$'\n'}"
if [ "$lookup_http_code" = "200" ] && echo "$lookup_response" | jq -e '.success == true' > /dev/null 2>&1; then
if echo "$lookup_response" | jq -e '.result[]? | select(.type != "CNAME")' > /dev/null 2>&1; then
echo "Conflicting non-CNAME DNS record exists for ${record_name}."
echo "$lookup_response"
return 1
fi
existing_cname_id="$(echo "$lookup_response" | jq -r '.result[]? | select(.type == "CNAME") | .id' | head -n1)"
if [ -n "$existing_cname_id" ]; then
existing_content="$(echo "$lookup_response" | jq -r --arg id "$existing_cname_id" '.result[] | select(.id == $id) | .content')"
if [ "$existing_content" = "$record_content" ]; then
echo "CNAME already set for ${record_name}."
return 0
fi
echo "Updating existing CNAME for ${record_name}."
update_response="$(
curl -sS -X PUT "${api_base_url}/${existing_cname_id}" \
-H "Authorization: Bearer ${{ secrets.PR_DEPLOYMENTS_CLOUDFLARE_API_TOKEN }}" \
-H "Content-Type:application/json" \
--data '{"type":"CNAME","name":"'"${record_name}"'","content":"'"${record_content}"'","ttl":1,"proxied":false}'
)"
if echo "$update_response" | jq -e '.success == true' > /dev/null 2>&1; then
echo "Updated CNAME for ${record_name}."
return 0
fi
echo "Cloudflare API error while updating ${record_name}:"
echo "$update_response"
return 1
fi
fi
else
echo "Could not query DNS record ${record_name}; attempting create."
fi
max_attempts=6
attempt=1
last_response=""
last_http_code=""
while [ "$attempt" -le "$max_attempts" ]; do
echo "Creating DNS record ${record_name} (attempt ${attempt}/${max_attempts})."
set +e
raw_response="$(
curl -sS -X POST "${api_base_url}" \
-H "Authorization: Bearer ${{ secrets.PR_DEPLOYMENTS_CLOUDFLARE_API_TOKEN }}" \
-H "Content-Type:application/json" \
--data '{"type":"CNAME","name":"'"${record_name}"'","content":"'"${record_content}"'","ttl":1,"proxied":false}' \
-w '\n%{http_code}'
)"
curl_exit_code=$?
set -e
curl_failed=false
if [ "$curl_exit_code" -eq 0 ]; then
response="${raw_response%$'\n'*}"
http_code="${raw_response##*$'\n'}"
else
response="curl exited with code ${curl_exit_code}."
http_code="000"
curl_failed=true
fi
last_response="$response"
last_http_code="$http_code"
if echo "$response" | jq -e '.success == true' > /dev/null 2>&1; then
echo "Created DNS record ${record_name}."
return 0
fi
# 81057: identical record exists. 81053: host record conflict.
if echo "$response" | jq -e '.errors[]? | select(.code == 81057 or .code == 81053)' > /dev/null 2>&1; then
echo "DNS record already exists for ${record_name}."
return 0
fi
transient_error=false
if [ "$curl_failed" = true ] || [ "$http_code" = "429" ]; then
transient_error=true
elif [[ "$http_code" =~ ^[0-9]{3}$ ]] && [ "$http_code" -ge 500 ] && [ "$http_code" -lt 600 ]; then
transient_error=true
fi
if echo "$response" | jq -e '.errors[]? | select(.code == 10000 or .code == 10001)' > /dev/null 2>&1; then
transient_error=true
fi
if [ "$transient_error" = true ] && [ "$attempt" -lt "$max_attempts" ]; then
sleep_seconds=$((attempt * 5))
echo "Transient Cloudflare API error (HTTP ${http_code}). Retrying in ${sleep_seconds}s."
sleep "$sleep_seconds"
attempt=$((attempt + 1))
continue
fi
break
done
echo "Cloudflare API error while creating DNS record ${record_name} after ${attempt} attempt(s):"
echo "HTTP status: ${last_http_code}"
echo "$last_response"
return 1
}
ensure_cname_record "${base_name}" "${base_target}"
ensure_cname_record "${wildcard_name}" "${base_name}"
# ---- First deploy only ----
- name: Create namespace
if: steps.check.outputs.new == 'true'
run: |
set -euo pipefail
kubectl delete namespace "${DEPLOY_NAME}" --wait=true || true
# Delete any orphaned PVs that were bound to PVCs in this
# namespace. Without this, the old PV (with stale Postgres
# data) gets reused on reinstall, causing auth failures.
kubectl get pv -o json | \
jq -r '.items[] | select(.spec.claimRef.namespace=='"${DEPLOY_NAME}"') | .metadata.name' | \
xargs -r kubectl delete pv || true
kubectl create namespace "${DEPLOY_NAME}"
# ---- Every push: ensure deployment certificate ----
- name: Ensure certificate
env:
DEPLOY_HOSTNAME: ${{ env.BRANCH_HOSTNAME }}
run: |
set -euo pipefail
cert_secret_name="${DEPLOY_NAME}-tls"
envsubst < ./.github/pr-deployments/certificate.yaml | kubectl apply -f -
if ! kubectl -n pr-deployment-certs wait --for=condition=Ready "certificate/${cert_secret_name}" --timeout=10m; then
echo "Timed out waiting for certificate ${cert_secret_name} to become Ready after 10 minutes."
kubectl -n pr-deployment-certs describe certificate "${cert_secret_name}" || true
kubectl -n pr-deployment-certs get certificaterequest,order,challenge -l "cert-manager.io/certificate-name=${cert_secret_name}" || true
exit 1
fi
kubectl get secret "${cert_secret_name}" -n pr-deployment-certs -o json |
jq 'del(.metadata.namespace,.metadata.creationTimestamp,.metadata.resourceVersion,.metadata.selfLink,.metadata.uid,.metadata.managedFields)' |
kubectl -n "${DEPLOY_NAME}" apply -f -
- name: Set up PostgreSQL
if: steps.check.outputs.new == 'true'
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm install coder-db bitnami/postgresql \
--namespace "${DEPLOY_NAME}" \
--set image.repository=bitnamilegacy/postgresql \
--set auth.username=coder \
--set auth.password=coder \
--set auth.database=coder \
--set persistence.size=10Gi
kubectl create secret generic coder-db-url -n "${DEPLOY_NAME}" \
--from-literal=url="postgres://coder:coder@coder-db-postgresql.${DEPLOY_NAME}.svc.cluster.local:5432/coder?sslmode=disable"
- name: Create RBAC
if: steps.check.outputs.new == 'true'
run: envsubst < ./.github/pr-deployments/rbac.yaml | kubectl apply -f -
# ---- Every push ----
- name: Create values.yaml
env:
DEPLOY_HOSTNAME: ${{ env.BRANCH_HOSTNAME }}
REPO: ${{ env.REPO }}
PR_DEPLOYMENTS_GITHUB_OAUTH_CLIENT_ID: ${{ secrets.PR_DEPLOYMENTS_GITHUB_OAUTH_CLIENT_ID }}
PR_DEPLOYMENTS_GITHUB_OAUTH_CLIENT_SECRET: ${{ secrets.PR_DEPLOYMENTS_GITHUB_OAUTH_CLIENT_SECRET }}
run: envsubst < ./.github/pr-deployments/values.yaml > ./deploy-values.yaml
- name: Install/Upgrade Helm chart
run: |
set -euo pipefail
helm dependency update --skip-refresh ./helm/coder
helm upgrade --install "${DEPLOY_NAME}" ./helm/coder \
--namespace "${DEPLOY_NAME}" \
--values ./deploy-values.yaml \
--force
- name: Install coder-logstream-kube
if: steps.check.outputs.new == 'true'
run: |
helm repo add coder-logstream-kube https://helm.coder.com/logstream-kube
helm upgrade --install coder-logstream-kube coder-logstream-kube/coder-logstream-kube \
--namespace "${DEPLOY_NAME}" \
--set url="https://${BRANCH_HOSTNAME}" \
--set "namespaces[0]=${DEPLOY_NAME}"
- name: Create first user and template
if: steps.check.outputs.new == 'true'
env:
PR_DEPLOYMENTS_ADMIN_PASSWORD: ${{ secrets.PR_DEPLOYMENTS_ADMIN_PASSWORD }}
run: |
set -euo pipefail
URL="https://${BRANCH_HOSTNAME}/bin/coder-linux-amd64"
COUNT=0
until curl --output /dev/null --silent --head --fail "$URL"; do
sleep 5
COUNT=$((COUNT+1))
if [ "$COUNT" -ge 60 ]; then echo "Timed out"; exit 1; fi
done
curl -fsSL "$URL" -o /tmp/coder && chmod +x /tmp/coder
password="${PR_DEPLOYMENTS_ADMIN_PASSWORD}"
if [ -z "$password" ]; then
echo "Missing PR_DEPLOYMENTS_ADMIN_PASSWORD repository secret."
exit 1
fi
echo "::add-mask::$password"
admin_username="${BRANCH_NAME}-admin"
admin_email="${BRANCH_NAME}@coder.com"
coder_url="https://${BRANCH_HOSTNAME}"
first_user_status="$(curl -sS -o /dev/null -w '%{http_code}' "${coder_url}/api/v2/users/first")"
if [ "$first_user_status" = "404" ]; then
/tmp/coder login \
--first-user-username "$admin_username" \
--first-user-email "$admin_email" \
--first-user-password "$password" \
--first-user-trial=false \
--use-token-as-session \
"$coder_url"
elif [ "$first_user_status" = "200" ]; then
login_payload="$(jq -n --arg email "$admin_email" --arg password "$password" '{email: $email, password: $password}')"
login_response="$(
curl -sS -X POST "${coder_url}/api/v2/users/login" \
-H "Content-Type: application/json" \
--data "$login_payload" \
-w '\n%{http_code}'
)"
login_body="${login_response%$'\n'*}"
login_status="${login_response##*$'\n'}"
if [ "$login_status" != "201" ]; then
echo "Password login failed for existing deployment (HTTP ${login_status})."
echo "$login_body"
exit 1
fi
session_token="$(echo "$login_body" | jq -r '.session_token // empty')"
if [ -z "$session_token" ]; then
echo "Password login response is missing session_token."
exit 1
fi
echo "::add-mask::$session_token"
/tmp/coder login \
--token "$session_token" \
--use-token-as-session \
"$coder_url"
else
echo "Unexpected status from /api/v2/users/first: ${first_user_status}."
exit 1
fi
cd .github/pr-deployments/template
/tmp/coder templates push -y --directory . --variable "namespace=${DEPLOY_NAME}" kubernetes
/tmp/coder create --template="kubernetes" kube \
--parameter cpu=2 --parameter memory=4 --parameter home_disk_size=2 -y
/tmp/coder stop kube -y
+14 -24
View File
@@ -160,41 +160,34 @@ jobs:
# Build context based on trigger type
case "${TRIGGER_TYPE}" in
new_pr)
CONTEXT="This is a NEW PR. Perform initial documentation review."
CONTEXT="This is a NEW PR. Perform a thorough documentation review."
;;
pr_updated)
CONTEXT="This PR was UPDATED with new commits. Check if previous feedback was addressed or if new doc needs arose."
CONTEXT="This PR was UPDATED with new commits. Only comment if the changes affect documentation needs or address previous feedback."
;;
label_requested)
CONTEXT="A documentation review was REQUESTED via label. Perform a thorough review."
CONTEXT="A documentation review was REQUESTED via label. Perform a thorough documentation review."
;;
ready_for_review)
CONTEXT="This PR was marked READY FOR REVIEW. Perform a thorough review."
CONTEXT="This PR was marked READY FOR REVIEW (converted from draft). Perform a thorough documentation review."
;;
manual)
CONTEXT="This is a MANUAL review request. Perform a thorough review."
CONTEXT="This is a MANUAL review request. Perform a thorough documentation review."
;;
*)
CONTEXT="Perform a documentation review."
CONTEXT="Perform a thorough documentation review."
;;
esac
# Build task prompt with sticky comment logic
# Build task prompt with PR-specific context
TASK_PROMPT="Use the doc-check skill to review PR #${PR_NUMBER} in coder/coder.
${CONTEXT}
Use \`gh\` to get PR details, diff, and all comments. Look for an existing doc-check comment containing \`<!-- doc-check-sticky -->\` - if one exists, you'll update it instead of creating a new one.
Use \`gh\` to get PR details, diff, and all comments. Check for previous doc-check comments (from coder-doc-check) and only post a new comment if it adds value.
**Do not comment if no documentation changes are needed.**
If a sticky comment already exists, compare your current findings against it:
- Check off \`[x]\` items that are now addressed
- Strikethrough items no longer needed (e.g., code was reverted)
- Add new unchecked \`[ ]\` items for newly discovered needs
- If an item is checked but you can't verify the docs were added, add a warning note below it
- If nothing meaningful changed, don't update the comment at all
## Comment format
Use this structure (only include relevant sections):
@@ -202,21 +195,18 @@ jobs:
\`\`\`
## Documentation Check
### Previous Feedback
[For re-reviews only: Addressed | Partially addressed | Not yet addressed]
### Updates Needed
- [ ] \`docs/path/file.md\` - What needs to change
- [x] \`docs/other/file.md\` - This was addressed
- ~~\`docs/removed.md\` - No longer needed~~ *(reverted in abc123)*
- [ ] \`docs/path/file.md\` - [what needs to change]
### New Documentation Needed
- [ ] \`docs/suggested/path.md\` - What should be documented
> ⚠️ *Checked but no corresponding documentation changes found in this PR*
- [ ] \`docs/suggested/path.md\` - [what should be documented]
---
*Automated review via [Coder Tasks](https://coder.com/docs/ai-coder/tasks)*
<!-- doc-check-sticky -->
\`\`\`
The \`<!-- doc-check-sticky -->\` marker must be at the end so future runs can find and update this comment."
\`\`\`"
# Output the prompt
{
+1 -3
View File
@@ -285,8 +285,6 @@ jobs:
PR_NUMBER: ${{ needs.get_info.outputs.PR_NUMBER }}
PR_TITLE: ${{ needs.get_info.outputs.PR_TITLE }}
PR_URL: ${{ needs.get_info.outputs.PR_URL }}
DEPLOY_NAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}"
DEPLOY_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
@@ -523,7 +521,7 @@ jobs:
run: |
set -euo pipefail
cd .github/pr-deployments/template
coder templates push -y --directory . --variable "namespace=pr${PR_NUMBER}" kubernetes
coder templates push -y --variable "namespace=pr${PR_NUMBER}" kubernetes
# Create workspace
coder create --template="kubernetes" kube --parameter cpu=2 --parameter memory=4 --parameter home_disk_size=2 -y
-1
View File
@@ -938,7 +938,6 @@ coderd/apidoc/.gen: \
coderd/rbac/object_gen.go \
.swaggo \
scripts/apidocgen/generate.sh \
scripts/apidocgen/swaginit/main.go \
$(wildcard scripts/apidocgen/postprocess/*) \
$(wildcard scripts/apidocgen/markdown-template/*)
./scripts/apidocgen/generate.sh
+19 -90
View File
@@ -12,7 +12,6 @@ import (
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/user"
"path/filepath"
@@ -882,7 +881,7 @@ const (
reportConnectionBufferLimit = 2048
)
func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string, options ...func(*proto.Connection)) (disconnected func(code int, reason string)) {
func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) {
// A blank IP can unfortunately happen if the connection is broken in a data race before we get to introspect it. We
// still report it, and the recipient can handle a blank IP.
if ip != "" {
@@ -913,20 +912,16 @@ func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_T
slog.F("ip", ip),
)
} else {
connectMsg := &proto.Connection{
Id: id[:],
Action: proto.Connection_CONNECT,
Type: connectionType,
Timestamp: timestamppb.New(time.Now()),
Ip: ip,
StatusCode: 0,
Reason: nil,
}
for _, opt := range options {
opt(connectMsg)
}
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
Connection: connectMsg,
Connection: &proto.Connection{
Id: id[:],
Action: proto.Connection_CONNECT,
Type: connectionType,
Timestamp: timestamppb.New(time.Now()),
Ip: ip,
StatusCode: 0,
Reason: nil,
},
})
select {
case a.reportConnectionsUpdate <- struct{}{}:
@@ -947,20 +942,16 @@ func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_T
return
}
disconnMsg := &proto.Connection{
Id: id[:],
Action: proto.Connection_DISCONNECT,
Type: connectionType,
Timestamp: timestamppb.New(time.Now()),
Ip: ip,
StatusCode: int32(code), //nolint:gosec
Reason: &reason,
}
for _, opt := range options {
opt(disconnMsg)
}
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
Connection: disconnMsg,
Connection: &proto.Connection{
Id: id[:],
Action: proto.Connection_DISCONNECT,
Type: connectionType,
Timestamp: timestamppb.New(time.Now()),
Ip: ip,
StatusCode: int32(code), //nolint:gosec
Reason: &reason,
},
})
select {
case a.reportConnectionsUpdate <- struct{}{}:
@@ -1386,8 +1377,6 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
manifest.DERPForceWebSockets,
manifest.DisableDirectConnections,
keySeed,
manifest.WorkspaceName,
manifest.Apps,
)
if err != nil {
return xerrors.Errorf("create tailnet: %w", err)
@@ -1536,39 +1525,12 @@ func (a *agent) trackGoroutine(fn func()) error {
return nil
}
// appPortFromURL extracts the port from a workspace app URL,
// defaulting to 80/443 by scheme.
func appPortFromURL(rawURL string) uint16 {
u, err := url.Parse(rawURL)
if err != nil {
return 0
}
p := u.Port()
if p == "" {
switch u.Scheme {
case "http":
return 80
case "https":
return 443
default:
return 0
}
}
port, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return 0
}
return uint16(port)
}
func (a *agent) createTailnet(
ctx context.Context,
agentID uuid.UUID,
derpMap *tailcfg.DERPMap,
derpForceWebSockets, disableDirectConnections bool,
keySeed int64,
workspaceName string,
apps []codersdk.WorkspaceApp,
) (_ *tailnet.Conn, err error) {
// Inject `CODER_AGENT_HEADER` into the DERP header.
var header http.Header
@@ -1577,18 +1539,6 @@ func (a *agent) createTailnet(
header = headerTransport.Header
}
}
// Build port-to-app mapping for workspace app connection tracking
// via the tailnet callback.
portToApp := make(map[uint16]codersdk.WorkspaceApp)
for _, app := range apps {
port := appPortFromURL(app.URL)
if port == 0 || app.External {
continue
}
portToApp[port] = app
}
network, err := tailnet.NewConn(&tailnet.Options{
ID: agentID,
Addresses: a.wireguardAddresses(agentID),
@@ -1598,27 +1548,6 @@ func (a *agent) createTailnet(
Logger: a.logger.Named("net.tailnet"),
ListenPort: a.tailnetListenPort,
BlockEndpoints: disableDirectConnections,
ShortDescription: "Workspace Agent",
Hostname: workspaceName,
TCPConnCallback: func(src, dst netip.AddrPort) (disconnected func(int, string)) {
app, ok := portToApp[dst.Port()]
connType := proto.Connection_PORT_FORWARDING
slugOrPort := strconv.Itoa(int(dst.Port()))
if ok {
connType = proto.Connection_WORKSPACE_APP
if app.Slug != "" {
slugOrPort = app.Slug
}
}
return a.reportConnection(
uuid.New(),
connType,
src.String(),
func(c *proto.Connection) {
c.SlugOrPort = &slugOrPort
},
)
},
})
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)
-96
View File
@@ -2843,102 +2843,6 @@ func TestAgent_Dial(t *testing.T) {
}
}
// TestAgent_PortForwardConnectionType verifies connection
// type classification for forwarded TCP connections.
func TestAgent_PortForwardConnectionType(t *testing.T) {
t.Parallel()
// Start a TCP echo server for the "app" port.
appListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() { _ = appListener.Close() })
appPort := appListener.Addr().(*net.TCPAddr).Port
// Start a TCP echo server for a non-app port.
nonAppListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() { _ = nonAppListener.Close() })
nonAppPort := nonAppListener.Addr().(*net.TCPAddr).Port
echoOnce := func(l net.Listener) <-chan struct{} {
done := make(chan struct{})
go func() {
defer close(done)
c, err := l.Accept()
if err != nil {
return
}
defer c.Close()
_, _ = io.Copy(c, c)
}()
return done
}
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:dogsled
agentConn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{
Apps: []codersdk.WorkspaceApp{
{
ID: uuid.New(),
Slug: "myapp",
URL: fmt.Sprintf("http://localhost:%d", appPort),
SharingLevel: codersdk.WorkspaceAppSharingLevelOwner,
Health: codersdk.WorkspaceAppHealthDisabled,
},
},
}, 0)
require.True(t, agentConn.AwaitReachable(ctx))
// Phase 1: Connect to the app port, expect WORKSPACE_APP.
appDone := echoOnce(appListener)
conn, err := agentConn.DialContext(ctx, "tcp", appListener.Addr().String())
require.NoError(t, err)
testDial(ctx, t, conn)
_ = conn.Close()
<-appDone
var reports []*proto.ReportConnectionRequest
require.Eventually(t, func() bool {
reports = agentClient.GetConnectionReports()
return len(reports) >= 2
}, testutil.WaitMedium, testutil.IntervalFast,
"waiting for 2 connection reports for workspace app",
)
require.Equal(t, proto.Connection_CONNECT, reports[0].GetConnection().GetAction())
require.Equal(t, proto.Connection_WORKSPACE_APP, reports[0].GetConnection().GetType())
require.Equal(t, "myapp", reports[0].GetConnection().GetSlugOrPort())
require.Equal(t, proto.Connection_DISCONNECT, reports[1].GetConnection().GetAction())
require.Equal(t, proto.Connection_WORKSPACE_APP, reports[1].GetConnection().GetType())
require.Equal(t, "myapp", reports[1].GetConnection().GetSlugOrPort())
// Phase 2: Connect to the non-app port, expect PORT_FORWARDING.
nonAppDone := echoOnce(nonAppListener)
conn, err = agentConn.DialContext(ctx, "tcp", nonAppListener.Addr().String())
require.NoError(t, err)
testDial(ctx, t, conn)
_ = conn.Close()
<-nonAppDone
nonAppPortStr := strconv.Itoa(nonAppPort)
require.Eventually(t, func() bool {
reports = agentClient.GetConnectionReports()
return len(reports) >= 4
}, testutil.WaitMedium, testutil.IntervalFast,
"waiting for 4 connection reports total",
)
require.Equal(t, proto.Connection_CONNECT, reports[2].GetConnection().GetAction())
require.Equal(t, proto.Connection_PORT_FORWARDING, reports[2].GetConnection().GetType())
require.Equal(t, nonAppPortStr, reports[2].GetConnection().GetSlugOrPort())
require.Equal(t, proto.Connection_DISCONNECT, reports[3].GetConnection().GetAction())
require.Equal(t, proto.Connection_PORT_FORWARDING, reports[3].GetConnection().GetType())
require.Equal(t, nonAppPortStr, reports[3].GetConnection().GetSlugOrPort())
}
// TestAgent_UpdatedDERP checks that agents can handle their DERP map being
// updated, and that clients can also handle it.
func TestAgent_UpdatedDERP(t *testing.T) {
+2 -71
View File
@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: .. (interfaces: ContainerCLI,DevcontainerCLI,SubAgentClient)
// Source: .. (interfaces: ContainerCLI,DevcontainerCLI)
//
// Generated by this command:
//
// mockgen -destination ./acmock.go -package acmock .. ContainerCLI,DevcontainerCLI,SubAgentClient
// mockgen -destination ./acmock.go -package acmock .. ContainerCLI,DevcontainerCLI
//
// Package acmock is a generated GoMock package.
@@ -15,7 +15,6 @@ import (
agentcontainers "github.com/coder/coder/v2/agent/agentcontainers"
codersdk "github.com/coder/coder/v2/codersdk"
uuid "github.com/google/uuid"
gomock "go.uber.org/mock/gomock"
)
@@ -217,71 +216,3 @@ func (mr *MockDevcontainerCLIMockRecorder) Up(ctx, workspaceFolder, configPath a
varargs := append([]any{ctx, workspaceFolder, configPath}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Up", reflect.TypeOf((*MockDevcontainerCLI)(nil).Up), varargs...)
}
// MockSubAgentClient is a mock of SubAgentClient interface.
type MockSubAgentClient struct {
ctrl *gomock.Controller
recorder *MockSubAgentClientMockRecorder
isgomock struct{}
}
// MockSubAgentClientMockRecorder is the mock recorder for MockSubAgentClient.
type MockSubAgentClientMockRecorder struct {
mock *MockSubAgentClient
}
// NewMockSubAgentClient creates a new mock instance.
func NewMockSubAgentClient(ctrl *gomock.Controller) *MockSubAgentClient {
mock := &MockSubAgentClient{ctrl: ctrl}
mock.recorder = &MockSubAgentClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSubAgentClient) EXPECT() *MockSubAgentClientMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockSubAgentClient) Create(ctx context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", ctx, agent)
ret0, _ := ret[0].(agentcontainers.SubAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Create indicates an expected call of Create.
func (mr *MockSubAgentClientMockRecorder) Create(ctx, agent any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockSubAgentClient)(nil).Create), ctx, agent)
}
// Delete mocks base method.
func (m *MockSubAgentClient) Delete(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockSubAgentClientMockRecorder) Delete(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSubAgentClient)(nil).Delete), ctx, id)
}
// List mocks base method.
func (m *MockSubAgentClient) List(ctx context.Context) ([]agentcontainers.SubAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", ctx)
ret0, _ := ret[0].([]agentcontainers.SubAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockSubAgentClientMockRecorder) List(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockSubAgentClient)(nil).List), ctx)
}
+1 -1
View File
@@ -1,4 +1,4 @@
// Package acmock contains a mock implementation of agentcontainers.Lister for use in tests.
package acmock
//go:generate mockgen -destination ./acmock.go -package acmock .. ContainerCLI,DevcontainerCLI,SubAgentClient
//go:generate mockgen -destination ./acmock.go -package acmock .. ContainerCLI,DevcontainerCLI
+15 -47
View File
@@ -562,9 +562,12 @@ func (api *API) discoverDevcontainersInProject(projectPath string) error {
api.broadcastUpdatesLocked()
if dc.Status == codersdk.WorkspaceAgentDevcontainerStatusStarting {
api.asyncWg.Go(func() {
api.asyncWg.Add(1)
go func() {
defer api.asyncWg.Done()
_ = api.CreateDevcontainer(dc.WorkspaceFolder, dc.ConfigPath)
})
}()
}
}
api.mu.Unlock()
@@ -1624,25 +1627,16 @@ func (api *API) cleanupSubAgents(ctx context.Context) error {
api.mu.Lock()
defer api.mu.Unlock()
// Collect all subagent IDs that should be kept:
// 1. Subagents currently tracked by injectedSubAgentProcs
// 2. Subagents referenced by known devcontainers from the manifest
var keep []uuid.UUID
injected := make(map[uuid.UUID]bool, len(api.injectedSubAgentProcs))
for _, proc := range api.injectedSubAgentProcs {
keep = append(keep, proc.agent.ID)
}
for _, dc := range api.knownDevcontainers {
if dc.SubagentID.Valid {
keep = append(keep, dc.SubagentID.UUID)
}
injected[proc.agent.ID] = true
}
ctx, cancel := context.WithTimeout(ctx, defaultOperationTimeout)
defer cancel()
var errs []error
for _, agent := range agents {
if slices.Contains(keep, agent.ID) {
if injected[agent.ID] {
continue
}
client := *api.subAgentClient.Load()
@@ -1653,11 +1647,10 @@ func (api *API) cleanupSubAgents(ctx context.Context) error {
slog.F("agent_id", agent.ID),
slog.F("agent_name", agent.Name),
)
errs = append(errs, xerrors.Errorf("delete agent %s (%s): %w", agent.Name, agent.ID, err))
}
}
return errors.Join(errs...)
return nil
}
// maybeInjectSubAgentIntoContainerLocked injects a subagent into a dev
@@ -2008,20 +2001,7 @@ func (api *API) maybeInjectSubAgentIntoContainerLocked(ctx context.Context, dc c
// logger.Warn(ctx, "set CAP_NET_ADMIN on agent binary failed", slog.Error(err))
// }
// Only delete and recreate subagents that were dynamically created
// (ID == uuid.Nil). Terraform-defined subagents (subAgentConfig.ID !=
// uuid.Nil) must not be deleted because they have attached resources
// managed by terraform.
isTerraformManaged := subAgentConfig.ID != uuid.Nil
configHasChanged := !proc.agent.EqualConfig(subAgentConfig)
logger.Debug(ctx, "checking if sub agent should be deleted",
slog.F("is_terraform_managed", isTerraformManaged),
slog.F("maybe_recreate_sub_agent", maybeRecreateSubAgent),
slog.F("config_has_changed", configHasChanged),
)
deleteSubAgent := !isTerraformManaged && maybeRecreateSubAgent && configHasChanged
deleteSubAgent := proc.agent.ID != uuid.Nil && maybeRecreateSubAgent && !proc.agent.EqualConfig(subAgentConfig)
if deleteSubAgent {
logger.Debug(ctx, "deleting existing subagent for recreation", slog.F("agent_id", proc.agent.ID))
client := *api.subAgentClient.Load()
@@ -2032,23 +2012,11 @@ func (api *API) maybeInjectSubAgentIntoContainerLocked(ctx context.Context, dc c
proc.agent = SubAgent{} // Clear agent to signal that we need to create a new one.
}
// Re-create (upsert) terraform-managed subagents when the config
// changes so that display apps and other settings are updated
// without deleting the agent.
recreateTerraformSubAgent := isTerraformManaged && maybeRecreateSubAgent && configHasChanged
if proc.agent.ID == uuid.Nil || recreateTerraformSubAgent {
if recreateTerraformSubAgent {
logger.Debug(ctx, "updating existing subagent",
slog.F("directory", subAgentConfig.Directory),
slog.F("display_apps", subAgentConfig.DisplayApps),
)
} else {
logger.Debug(ctx, "creating new subagent",
slog.F("directory", subAgentConfig.Directory),
slog.F("display_apps", subAgentConfig.DisplayApps),
)
}
if proc.agent.ID == uuid.Nil {
logger.Debug(ctx, "creating new subagent",
slog.F("directory", subAgentConfig.Directory),
slog.F("display_apps", subAgentConfig.DisplayApps),
)
// Create new subagent record in the database to receive the auth token.
// If we get a unique constraint violation, try with expanded names that
+9 -369
View File
@@ -437,11 +437,7 @@ func (m *fakeSubAgentClient) Create(ctx context.Context, agent agentcontainers.S
}
}
// Only generate a new ID if one wasn't provided. Terraform-defined
// subagents have pre-existing IDs that should be preserved.
if agent.ID == uuid.Nil {
agent.ID = uuid.New()
}
agent.ID = uuid.New()
agent.AuthToken = uuid.New()
if m.agents == nil {
m.agents = make(map[uuid.UUID]agentcontainers.SubAgent)
@@ -1039,30 +1035,6 @@ func TestAPI(t *testing.T) {
wantStatus: []int{http.StatusAccepted, http.StatusConflict},
wantBody: []string{"Devcontainer recreation initiated", "is currently starting and cannot be restarted"},
},
{
name: "Terraform-defined devcontainer can be rebuilt",
devcontainerID: devcontainerID1.String(),
setupDevcontainers: []codersdk.WorkspaceAgentDevcontainer{
{
ID: devcontainerID1,
Name: "test-devcontainer-terraform",
WorkspaceFolder: workspaceFolder1,
ConfigPath: configPath1,
Status: codersdk.WorkspaceAgentDevcontainerStatusRunning,
Container: &devContainer1,
SubagentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
},
},
lister: &fakeContainerCLI{
containers: codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{devContainer1},
},
arch: "<none>",
},
devcontainerCLI: &fakeDevcontainerCLI{},
wantStatus: []int{http.StatusAccepted, http.StatusConflict},
wantBody: []string{"Devcontainer recreation initiated", "is currently starting and cannot be restarted"},
},
}
for _, tt := range tests {
@@ -1477,6 +1449,14 @@ func TestAPI(t *testing.T) {
)
}
api := agentcontainers.NewAPI(logger, apiOpts...)
api.Start()
defer api.Close()
r := chi.NewRouter()
r.Mount("/", api.Routes())
var (
agentRunningCh chan struct{}
stopAgentCh chan struct{}
@@ -1493,14 +1473,6 @@ func TestAPI(t *testing.T) {
}
}
api := agentcontainers.NewAPI(logger, apiOpts...)
api.Start()
defer api.Close()
r := chi.NewRouter()
r.Mount("/", api.Routes())
tickerTrap.MustWait(ctx).MustRelease(ctx)
tickerTrap.Close()
@@ -2518,338 +2490,6 @@ func TestAPI(t *testing.T) {
assert.Empty(t, fakeSAC.agents)
})
t.Run("SubAgentCleanupPreservesTerraformDefined", func(t *testing.T) {
t.Parallel()
var (
// Given: A terraform-defined agent and devcontainer that should be preserved
terraformAgentID = uuid.New()
terraformAgentToken = uuid.New()
terraformAgent = agentcontainers.SubAgent{
ID: terraformAgentID,
Name: "terraform-defined-agent",
Directory: "/workspace",
AuthToken: terraformAgentToken,
}
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
ID: uuid.New(),
Name: "terraform-devcontainer",
WorkspaceFolder: "/workspace/project",
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
}
// Given: An orphaned agent that should be cleaned up
orphanedAgentID = uuid.New()
orphanedAgentToken = uuid.New()
orphanedAgent = agentcontainers.SubAgent{
ID: orphanedAgentID,
Name: "orphaned-agent",
Directory: "/tmp",
AuthToken: orphanedAgentToken,
}
ctx = testutil.Context(t, testutil.WaitMedium)
logger = slog.Make()
mClock = quartz.NewMock(t)
mCCLI = acmock.NewMockContainerCLI(gomock.NewController(t))
fakeSAC = &fakeSubAgentClient{
logger: logger.Named("fakeSubAgentClient"),
agents: map[uuid.UUID]agentcontainers.SubAgent{
terraformAgentID: terraformAgent,
orphanedAgentID: orphanedAgent,
},
}
)
mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{},
}, nil).AnyTimes()
mClock.Set(time.Now()).MustWait(ctx)
tickerTrap := mClock.Trap().TickerFunc("updaterLoop")
api := agentcontainers.NewAPI(logger,
agentcontainers.WithClock(mClock),
agentcontainers.WithContainerCLI(mCCLI),
agentcontainers.WithSubAgentClient(fakeSAC),
agentcontainers.WithDevcontainerCLI(&fakeDevcontainerCLI{}),
agentcontainers.WithDevcontainers([]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer}, nil),
)
api.Start()
defer api.Close()
tickerTrap.MustWait(ctx).MustRelease(ctx)
tickerTrap.Close()
// When: We advance the clock, allowing cleanup to occur
_, aw := mClock.AdvanceNext()
aw.MustWait(ctx)
// Then: The orphaned agent should be deleted
assert.Contains(t, fakeSAC.deleted, orphanedAgentID, "orphaned agent should be deleted")
// And: The terraform-defined agent should not be deleted
assert.NotContains(t, fakeSAC.deleted, terraformAgentID, "terraform-defined agent should be preserved")
assert.Len(t, fakeSAC.agents, 1, "only terraform agent should remain")
assert.Contains(t, fakeSAC.agents, terraformAgentID, "terraform agent should still exist")
})
t.Run("TerraformDefinedSubAgentNotRecreatedOnConfigChange", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("Dev Container tests are not supported on Windows (this test uses mocks but fails due to Windows paths)")
}
var (
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
mCtrl = gomock.NewController(t)
// Given: A terraform-defined devcontainer with a pre-assigned subagent ID.
terraformAgentID = uuid.New()
terraformContainer = codersdk.WorkspaceAgentContainer{
ID: "test-container-id",
FriendlyName: "test-container",
Image: "test-image",
Running: true,
CreatedAt: time.Now(),
Labels: map[string]string{
agentcontainers.DevcontainerLocalFolderLabel: "/workspace/project",
agentcontainers.DevcontainerConfigFileLabel: "/workspace/project/.devcontainer/devcontainer.json",
},
}
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
ID: uuid.New(),
Name: "terraform-devcontainer",
WorkspaceFolder: "/workspace/project",
ConfigPath: "/workspace/project/.devcontainer/devcontainer.json",
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
}
fCCLI = &fakeContainerCLI{
containers: codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{terraformContainer},
},
arch: runtime.GOARCH,
}
fDCCLI = &fakeDevcontainerCLI{
upID: terraformContainer.ID,
readConfig: agentcontainers.DevcontainerConfig{
MergedConfiguration: agentcontainers.DevcontainerMergedConfiguration{
Customizations: agentcontainers.DevcontainerMergedCustomizations{
Coder: []agentcontainers.CoderCustomization{{
Apps: []agentcontainers.SubAgentApp{{Slug: "app1"}},
}},
},
},
},
}
mSAC = acmock.NewMockSubAgentClient(mCtrl)
closed bool
)
mSAC.EXPECT().List(gomock.Any()).Return([]agentcontainers.SubAgent{}, nil).AnyTimes()
// EXPECT: Create is called twice with the terraform-defined ID:
// once for the initial creation and once after the rebuild with
// config changes (upsert).
mSAC.EXPECT().Create(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
assert.Equal(t, terraformAgentID, agent.ID, "agent should have terraform-defined ID")
agent.AuthToken = uuid.New()
return agent, nil
},
).Times(2)
// EXPECT: Delete may be called during Close, but not before.
mSAC.EXPECT().Delete(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
assert.True(t, closed, "Delete should only be called after Close, not during recreation")
return nil
}).AnyTimes()
api := agentcontainers.NewAPI(logger,
agentcontainers.WithContainerCLI(fCCLI),
agentcontainers.WithDevcontainerCLI(fDCCLI),
agentcontainers.WithDevcontainers(
[]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer},
[]codersdk.WorkspaceAgentScript{{ID: terraformDevcontainer.ID, LogSourceID: uuid.New()}},
),
agentcontainers.WithSubAgentClient(mSAC),
agentcontainers.WithSubAgentURL("test-subagent-url"),
agentcontainers.WithWatcher(watcher.NewNoop()),
)
api.Start()
// Given: We create the devcontainer for the first time.
err := api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath)
require.NoError(t, err)
// When: The container is recreated (new container ID) with config changes.
terraformContainer.ID = "new-container-id"
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fDCCLI.upID = terraformContainer.ID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
Apps: []agentcontainers.SubAgentApp{{Slug: "app2"}}, // Changed app triggers recreation logic.
}}
err = api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath, agentcontainers.WithRemoveExistingContainer())
require.NoError(t, err)
// Then: Mock expectations verify that Create was called once and Delete was not called during recreation.
closed = true
api.Close()
})
// Verify that rebuilding a terraform-defined devcontainer via the
// HTTP API does not delete the sub agent. The sub agent should be
// preserved (Create called again with the same terraform ID) and
// display app changes should be picked up.
t.Run("TerraformDefinedSubAgentRebuildViaHTTP", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("Dev Container tests are not supported on Windows (this test uses mocks but fails due to Windows paths)")
}
var (
ctx = testutil.Context(t, testutil.WaitMedium)
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
mCtrl = gomock.NewController(t)
terraformAgentID = uuid.New()
containerID = "test-container-id"
terraformContainer = codersdk.WorkspaceAgentContainer{
ID: containerID,
FriendlyName: "test-container",
Image: "test-image",
Running: true,
CreatedAt: time.Now(),
Labels: map[string]string{
agentcontainers.DevcontainerLocalFolderLabel: "/workspace/project",
agentcontainers.DevcontainerConfigFileLabel: "/workspace/project/.devcontainer/devcontainer.json",
},
}
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
ID: uuid.New(),
Name: "terraform-devcontainer",
WorkspaceFolder: "/workspace/project",
ConfigPath: "/workspace/project/.devcontainer/devcontainer.json",
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
}
fCCLI = &fakeContainerCLI{
containers: codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{terraformContainer},
},
arch: runtime.GOARCH,
}
fDCCLI = &fakeDevcontainerCLI{
upID: containerID,
readConfig: agentcontainers.DevcontainerConfig{
MergedConfiguration: agentcontainers.DevcontainerMergedConfiguration{
Customizations: agentcontainers.DevcontainerMergedCustomizations{
Coder: []agentcontainers.CoderCustomization{{
DisplayApps: map[codersdk.DisplayApp]bool{
codersdk.DisplayAppSSH: true,
codersdk.DisplayAppWebTerminal: true,
},
}},
},
},
},
}
mSAC = acmock.NewMockSubAgentClient(mCtrl)
closed bool
createCalled = make(chan agentcontainers.SubAgent, 2)
)
mSAC.EXPECT().List(gomock.Any()).Return([]agentcontainers.SubAgent{}, nil).AnyTimes()
// Create should be called twice: once for the initial injection
// and once after the rebuild picks up the new container.
mSAC.EXPECT().Create(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
assert.Equal(t, terraformAgentID, agent.ID, "agent should always use terraform-defined ID")
agent.AuthToken = uuid.New()
createCalled <- agent
return agent, nil
},
).Times(2)
// Delete must only be called during Close, never during rebuild.
mSAC.EXPECT().Delete(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
assert.True(t, closed, "Delete should only be called after Close, not during rebuild")
return nil
}).AnyTimes()
api := agentcontainers.NewAPI(logger,
agentcontainers.WithContainerCLI(fCCLI),
agentcontainers.WithDevcontainerCLI(fDCCLI),
agentcontainers.WithDevcontainers(
[]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer},
[]codersdk.WorkspaceAgentScript{{ID: terraformDevcontainer.ID, LogSourceID: uuid.New()}},
),
agentcontainers.WithSubAgentClient(mSAC),
agentcontainers.WithSubAgentURL("test-subagent-url"),
agentcontainers.WithWatcher(watcher.NewNoop()),
)
api.Start()
defer func() {
closed = true
api.Close()
}()
r := chi.NewRouter()
r.Mount("/", api.Routes())
// Perform the initial devcontainer creation directly to set up
// the subagent (mirrors the TerraformDefinedSubAgentNotRecreatedOnConfigChange
// test pattern).
err := api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath)
require.NoError(t, err)
initialAgent := testutil.RequireReceive(ctx, t, createCalled)
assert.Equal(t, terraformAgentID, initialAgent.ID)
// Simulate container rebuild: new container ID, changed display apps.
newContainerID := "new-container-id"
terraformContainer.ID = newContainerID
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fDCCLI.upID = newContainerID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
DisplayApps: map[codersdk.DisplayApp]bool{
codersdk.DisplayAppSSH: true,
codersdk.DisplayAppWebTerminal: true,
codersdk.DisplayAppVSCodeDesktop: true,
codersdk.DisplayAppVSCodeInsiders: true,
},
}}
// Issue the rebuild request via the HTTP API.
req := httptest.NewRequest(http.MethodPost, "/devcontainers/"+terraformDevcontainer.ID.String()+"/recreate", nil).
WithContext(ctx)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusAccepted, rec.Code)
// Wait for the post-rebuild injection to complete.
rebuiltAgent := testutil.RequireReceive(ctx, t, createCalled)
assert.Equal(t, terraformAgentID, rebuiltAgent.ID, "rebuilt agent should preserve terraform ID")
// Verify that the display apps were updated.
assert.Contains(t, rebuiltAgent.DisplayApps, codersdk.DisplayAppVSCodeDesktop,
"rebuilt agent should include updated display apps")
assert.Contains(t, rebuiltAgent.DisplayApps, codersdk.DisplayAppVSCodeInsiders,
"rebuilt agent should include updated display apps")
})
t.Run("Error", func(t *testing.T) {
t.Parallel()
+2 -10
View File
@@ -24,12 +24,10 @@ type SubAgent struct {
DisplayApps []codersdk.DisplayApp
}
// CloneConfig makes a copy of SubAgent using configuration from the
// devcontainer. The ID is inherited from dc.SubagentID if present, and
// the name is inherited from the devcontainer. AuthToken is not copied.
// CloneConfig makes a copy of SubAgent without ID and AuthToken. The
// name is inherited from the devcontainer.
func (s SubAgent) CloneConfig(dc codersdk.WorkspaceAgentDevcontainer) SubAgent {
return SubAgent{
ID: dc.SubagentID.UUID,
Name: dc.Name,
Directory: s.Directory,
Architecture: s.Architecture,
@@ -192,11 +190,6 @@ func (a *subAgentAPIClient) List(ctx context.Context) ([]SubAgent, error) {
func (a *subAgentAPIClient) Create(ctx context.Context, agent SubAgent) (_ SubAgent, err error) {
a.logger.Debug(ctx, "creating sub agent", slog.F("name", agent.Name), slog.F("directory", agent.Directory))
var id []byte
if agent.ID != uuid.Nil {
id = agent.ID[:]
}
displayApps := make([]agentproto.CreateSubAgentRequest_DisplayApp, 0, len(agent.DisplayApps))
for _, displayApp := range agent.DisplayApps {
var app agentproto.CreateSubAgentRequest_DisplayApp
@@ -235,7 +228,6 @@ func (a *subAgentAPIClient) Create(ctx context.Context, agent SubAgent) (_ SubAg
OperatingSystem: agent.OperatingSystem,
DisplayApps: displayApps,
Apps: apps,
Id: id,
})
if err != nil {
return SubAgent{}, err
-125
View File
@@ -306,128 +306,3 @@ func TestSubAgentClient_CreateWithDisplayApps(t *testing.T) {
}
})
}
func TestSubAgent_CloneConfig(t *testing.T) {
t.Parallel()
t.Run("CopiesIDFromDevcontainer", func(t *testing.T) {
t.Parallel()
subAgent := agentcontainers.SubAgent{
ID: uuid.New(),
Name: "original-name",
Directory: "/workspace",
Architecture: "amd64",
OperatingSystem: "linux",
DisplayApps: []codersdk.DisplayApp{codersdk.DisplayAppVSCodeDesktop},
Apps: []agentcontainers.SubAgentApp{{Slug: "app1"}},
}
expectedID := uuid.MustParse("550e8400-e29b-41d4-a716-446655440000")
dc := codersdk.WorkspaceAgentDevcontainer{
Name: "devcontainer-name",
SubagentID: uuid.NullUUID{UUID: expectedID, Valid: true},
}
cloned := subAgent.CloneConfig(dc)
assert.Equal(t, expectedID, cloned.ID)
assert.Equal(t, dc.Name, cloned.Name)
assert.Equal(t, subAgent.Directory, cloned.Directory)
assert.Zero(t, cloned.AuthToken, "AuthToken should not be copied")
})
t.Run("HandlesNilSubagentID", func(t *testing.T) {
t.Parallel()
subAgent := agentcontainers.SubAgent{
ID: uuid.New(),
Name: "original-name",
Directory: "/workspace",
Architecture: "amd64",
OperatingSystem: "linux",
}
dc := codersdk.WorkspaceAgentDevcontainer{
Name: "devcontainer-name",
SubagentID: uuid.NullUUID{Valid: false},
}
cloned := subAgent.CloneConfig(dc)
assert.Equal(t, uuid.Nil, cloned.ID)
})
}
func TestSubAgent_EqualConfig(t *testing.T) {
t.Parallel()
base := agentcontainers.SubAgent{
ID: uuid.New(),
Name: "test-agent",
Directory: "/workspace",
Architecture: "amd64",
OperatingSystem: "linux",
DisplayApps: []codersdk.DisplayApp{codersdk.DisplayAppVSCodeDesktop},
Apps: []agentcontainers.SubAgentApp{
{Slug: "test-app", DisplayName: "Test App"},
},
}
tests := []struct {
name string
modify func(*agentcontainers.SubAgent)
wantEqual bool
}{
{
name: "identical",
modify: func(s *agentcontainers.SubAgent) {},
wantEqual: true,
},
{
name: "different ID",
modify: func(s *agentcontainers.SubAgent) { s.ID = uuid.New() },
wantEqual: true,
},
{
name: "different Name",
modify: func(s *agentcontainers.SubAgent) { s.Name = "different-name" },
wantEqual: false,
},
{
name: "different Directory",
modify: func(s *agentcontainers.SubAgent) { s.Directory = "/different/path" },
wantEqual: false,
},
{
name: "different Architecture",
modify: func(s *agentcontainers.SubAgent) { s.Architecture = "arm64" },
wantEqual: false,
},
{
name: "different OperatingSystem",
modify: func(s *agentcontainers.SubAgent) { s.OperatingSystem = "windows" },
wantEqual: false,
},
{
name: "different DisplayApps",
modify: func(s *agentcontainers.SubAgent) { s.DisplayApps = []codersdk.DisplayApp{codersdk.DisplayAppSSH} },
wantEqual: false,
},
{
name: "different Apps",
modify: func(s *agentcontainers.SubAgent) {
s.Apps = []agentcontainers.SubAgentApp{{Slug: "different-app", DisplayName: "Different App"}}
},
wantEqual: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
modified := base
tt.modify(&modified)
assert.Equal(t, tt.wantEqual, base.EqualConfig(modified))
})
}
}
+383 -142
View File
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.9
// protoc v6.33.1
// protoc-gen-go v1.30.0
// protoc v4.23.4
// source: agent/agentsocket/proto/agentsocket.proto
package proto
@@ -11,7 +11,6 @@ import (
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
@@ -22,16 +21,18 @@ const (
)
type PingRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *PingRequest) Reset() {
*x = PingRequest{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PingRequest) String() string {
@@ -42,7 +43,7 @@ func (*PingRequest) ProtoMessage() {}
func (x *PingRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[0]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -58,16 +59,18 @@ func (*PingRequest) Descriptor() ([]byte, []int) {
}
type PingResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *PingResponse) Reset() {
*x = PingResponse{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PingResponse) String() string {
@@ -78,7 +81,7 @@ func (*PingResponse) ProtoMessage() {}
func (x *PingResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[1]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -94,17 +97,20 @@ func (*PingResponse) Descriptor() ([]byte, []int) {
}
type SyncStartRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
}
func (x *SyncStartRequest) Reset() {
*x = SyncStartRequest{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncStartRequest) String() string {
@@ -115,7 +121,7 @@ func (*SyncStartRequest) ProtoMessage() {}
func (x *SyncStartRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[2]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -138,16 +144,18 @@ func (x *SyncStartRequest) GetUnit() string {
}
type SyncStartResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SyncStartResponse) Reset() {
*x = SyncStartResponse{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncStartResponse) String() string {
@@ -158,7 +166,7 @@ func (*SyncStartResponse) ProtoMessage() {}
func (x *SyncStartResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[3]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -174,18 +182,21 @@ func (*SyncStartResponse) Descriptor() ([]byte, []int) {
}
type SyncWantRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
DependsOn string `protobuf:"bytes,2,opt,name=depends_on,json=dependsOn,proto3" json:"depends_on,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
DependsOn string `protobuf:"bytes,2,opt,name=depends_on,json=dependsOn,proto3" json:"depends_on,omitempty"`
}
func (x *SyncWantRequest) Reset() {
*x = SyncWantRequest{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncWantRequest) String() string {
@@ -196,7 +207,7 @@ func (*SyncWantRequest) ProtoMessage() {}
func (x *SyncWantRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[4]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -226,16 +237,18 @@ func (x *SyncWantRequest) GetDependsOn() string {
}
type SyncWantResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SyncWantResponse) Reset() {
*x = SyncWantResponse{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncWantResponse) String() string {
@@ -246,7 +259,7 @@ func (*SyncWantResponse) ProtoMessage() {}
func (x *SyncWantResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[5]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -262,17 +275,20 @@ func (*SyncWantResponse) Descriptor() ([]byte, []int) {
}
type SyncCompleteRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
}
func (x *SyncCompleteRequest) Reset() {
*x = SyncCompleteRequest{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncCompleteRequest) String() string {
@@ -283,7 +299,7 @@ func (*SyncCompleteRequest) ProtoMessage() {}
func (x *SyncCompleteRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[6]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -306,16 +322,18 @@ func (x *SyncCompleteRequest) GetUnit() string {
}
type SyncCompleteResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SyncCompleteResponse) Reset() {
*x = SyncCompleteResponse{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncCompleteResponse) String() string {
@@ -326,7 +344,7 @@ func (*SyncCompleteResponse) ProtoMessage() {}
func (x *SyncCompleteResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[7]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -342,17 +360,20 @@ func (*SyncCompleteResponse) Descriptor() ([]byte, []int) {
}
type SyncReadyRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
}
func (x *SyncReadyRequest) Reset() {
*x = SyncReadyRequest{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncReadyRequest) String() string {
@@ -363,7 +384,7 @@ func (*SyncReadyRequest) ProtoMessage() {}
func (x *SyncReadyRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[8]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -386,17 +407,20 @@ func (x *SyncReadyRequest) GetUnit() string {
}
type SyncReadyResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Ready bool `protobuf:"varint,1,opt,name=ready,proto3" json:"ready,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Ready bool `protobuf:"varint,1,opt,name=ready,proto3" json:"ready,omitempty"`
}
func (x *SyncReadyResponse) Reset() {
*x = SyncReadyResponse{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncReadyResponse) String() string {
@@ -407,7 +431,7 @@ func (*SyncReadyResponse) ProtoMessage() {}
func (x *SyncReadyResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[9]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -430,17 +454,20 @@ func (x *SyncReadyResponse) GetReady() bool {
}
type SyncStatusRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
}
func (x *SyncStatusRequest) Reset() {
*x = SyncStatusRequest{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[10]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[10]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncStatusRequest) String() string {
@@ -451,7 +478,7 @@ func (*SyncStatusRequest) ProtoMessage() {}
func (x *SyncStatusRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[10]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -474,21 +501,24 @@ func (x *SyncStatusRequest) GetUnit() string {
}
type DependencyInfo struct {
state protoimpl.MessageState `protogen:"open.v1"`
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
DependsOn string `protobuf:"bytes,2,opt,name=depends_on,json=dependsOn,proto3" json:"depends_on,omitempty"`
RequiredStatus string `protobuf:"bytes,3,opt,name=required_status,json=requiredStatus,proto3" json:"required_status,omitempty"`
CurrentStatus string `protobuf:"bytes,4,opt,name=current_status,json=currentStatus,proto3" json:"current_status,omitempty"`
IsSatisfied bool `protobuf:"varint,5,opt,name=is_satisfied,json=isSatisfied,proto3" json:"is_satisfied,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"`
DependsOn string `protobuf:"bytes,2,opt,name=depends_on,json=dependsOn,proto3" json:"depends_on,omitempty"`
RequiredStatus string `protobuf:"bytes,3,opt,name=required_status,json=requiredStatus,proto3" json:"required_status,omitempty"`
CurrentStatus string `protobuf:"bytes,4,opt,name=current_status,json=currentStatus,proto3" json:"current_status,omitempty"`
IsSatisfied bool `protobuf:"varint,5,opt,name=is_satisfied,json=isSatisfied,proto3" json:"is_satisfied,omitempty"`
}
func (x *DependencyInfo) Reset() {
*x = DependencyInfo{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[11]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[11]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DependencyInfo) String() string {
@@ -499,7 +529,7 @@ func (*DependencyInfo) ProtoMessage() {}
func (x *DependencyInfo) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[11]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -550,19 +580,22 @@ func (x *DependencyInfo) GetIsSatisfied() bool {
}
type SyncStatusResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Status string `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"`
IsReady bool `protobuf:"varint,2,opt,name=is_ready,json=isReady,proto3" json:"is_ready,omitempty"`
Dependencies []*DependencyInfo `protobuf:"bytes,3,rep,name=dependencies,proto3" json:"dependencies,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Status string `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"`
IsReady bool `protobuf:"varint,2,opt,name=is_ready,json=isReady,proto3" json:"is_ready,omitempty"`
Dependencies []*DependencyInfo `protobuf:"bytes,3,rep,name=dependencies,proto3" json:"dependencies,omitempty"`
}
func (x *SyncStatusResponse) Reset() {
*x = SyncStatusResponse{}
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[12]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[12]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncStatusResponse) String() string {
@@ -573,7 +606,7 @@ func (*SyncStatusResponse) ProtoMessage() {}
func (x *SyncStatusResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[12]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -611,62 +644,111 @@ func (x *SyncStatusResponse) GetDependencies() []*DependencyInfo {
var File_agent_agentsocket_proto_agentsocket_proto protoreflect.FileDescriptor
const file_agent_agentsocket_proto_agentsocket_proto_rawDesc = "" +
"\n" +
")agent/agentsocket/proto/agentsocket.proto\x12\x14coder.agentsocket.v1\"\r\n" +
"\vPingRequest\"\x0e\n" +
"\fPingResponse\"&\n" +
"\x10SyncStartRequest\x12\x12\n" +
"\x04unit\x18\x01 \x01(\tR\x04unit\"\x13\n" +
"\x11SyncStartResponse\"D\n" +
"\x0fSyncWantRequest\x12\x12\n" +
"\x04unit\x18\x01 \x01(\tR\x04unit\x12\x1d\n" +
"\n" +
"depends_on\x18\x02 \x01(\tR\tdependsOn\"\x12\n" +
"\x10SyncWantResponse\")\n" +
"\x13SyncCompleteRequest\x12\x12\n" +
"\x04unit\x18\x01 \x01(\tR\x04unit\"\x16\n" +
"\x14SyncCompleteResponse\"&\n" +
"\x10SyncReadyRequest\x12\x12\n" +
"\x04unit\x18\x01 \x01(\tR\x04unit\")\n" +
"\x11SyncReadyResponse\x12\x14\n" +
"\x05ready\x18\x01 \x01(\bR\x05ready\"'\n" +
"\x11SyncStatusRequest\x12\x12\n" +
"\x04unit\x18\x01 \x01(\tR\x04unit\"\xb6\x01\n" +
"\x0eDependencyInfo\x12\x12\n" +
"\x04unit\x18\x01 \x01(\tR\x04unit\x12\x1d\n" +
"\n" +
"depends_on\x18\x02 \x01(\tR\tdependsOn\x12'\n" +
"\x0frequired_status\x18\x03 \x01(\tR\x0erequiredStatus\x12%\n" +
"\x0ecurrent_status\x18\x04 \x01(\tR\rcurrentStatus\x12!\n" +
"\fis_satisfied\x18\x05 \x01(\bR\visSatisfied\"\x91\x01\n" +
"\x12SyncStatusResponse\x12\x16\n" +
"\x06status\x18\x01 \x01(\tR\x06status\x12\x19\n" +
"\bis_ready\x18\x02 \x01(\bR\aisReady\x12H\n" +
"\fdependencies\x18\x03 \x03(\v2$.coder.agentsocket.v1.DependencyInfoR\fdependencies2\xbb\x04\n" +
"\vAgentSocket\x12M\n" +
"\x04Ping\x12!.coder.agentsocket.v1.PingRequest\x1a\".coder.agentsocket.v1.PingResponse\x12\\\n" +
"\tSyncStart\x12&.coder.agentsocket.v1.SyncStartRequest\x1a'.coder.agentsocket.v1.SyncStartResponse\x12Y\n" +
"\bSyncWant\x12%.coder.agentsocket.v1.SyncWantRequest\x1a&.coder.agentsocket.v1.SyncWantResponse\x12e\n" +
"\fSyncComplete\x12).coder.agentsocket.v1.SyncCompleteRequest\x1a*.coder.agentsocket.v1.SyncCompleteResponse\x12\\\n" +
"\tSyncReady\x12&.coder.agentsocket.v1.SyncReadyRequest\x1a'.coder.agentsocket.v1.SyncReadyResponse\x12_\n" +
"\n" +
"SyncStatus\x12'.coder.agentsocket.v1.SyncStatusRequest\x1a(.coder.agentsocket.v1.SyncStatusResponseB3Z1github.com/coder/coder/v2/agent/agentsocket/protob\x06proto3"
var file_agent_agentsocket_proto_agentsocket_proto_rawDesc = []byte{
0x0a, 0x29, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63,
0x6b, 0x65, 0x74, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73,
0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x63, 0x6f, 0x64,
0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76,
0x31, 0x22, 0x0d, 0x0a, 0x0b, 0x50, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x22, 0x0e, 0x0a, 0x0c, 0x50, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x26, 0x0a, 0x10, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x22, 0x13, 0x0a, 0x11, 0x53, 0x79, 0x6e, 0x63,
0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x44, 0x0a,
0x0f, 0x53, 0x79, 0x6e, 0x63, 0x57, 0x61, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x12, 0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04,
0x75, 0x6e, 0x69, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x73, 0x5f,
0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64,
0x73, 0x4f, 0x6e, 0x22, 0x12, 0x0a, 0x10, 0x53, 0x79, 0x6e, 0x63, 0x57, 0x61, 0x6e, 0x74, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x29, 0x0a, 0x13, 0x53, 0x79, 0x6e, 0x63, 0x43,
0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12,
0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x6e,
0x69, 0x74, 0x22, 0x16, 0x0a, 0x14, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65,
0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x26, 0x0a, 0x10, 0x53, 0x79,
0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12,
0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x6e,
0x69, 0x74, 0x22, 0x29, 0x0a, 0x11, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x65, 0x61, 0x64, 0x79,
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x72, 0x65, 0x61, 0x64, 0x79, 0x22, 0x27, 0x0a,
0x11, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x52, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x22, 0xb6, 0x01, 0x0a, 0x0e, 0x44, 0x65, 0x70, 0x65, 0x6e,
0x64, 0x65, 0x6e, 0x63, 0x79, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x6e, 0x69,
0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x12, 0x1d, 0x0a,
0x0a, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x73, 0x5f, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28,
0x09, 0x52, 0x09, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x73, 0x4f, 0x6e, 0x12, 0x27, 0x0a, 0x0f,
0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18,
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x53,
0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74,
0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63,
0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x21, 0x0a, 0x0c,
0x69, 0x73, 0x5f, 0x73, 0x61, 0x74, 0x69, 0x73, 0x66, 0x69, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01,
0x28, 0x08, 0x52, 0x0b, 0x69, 0x73, 0x53, 0x61, 0x74, 0x69, 0x73, 0x66, 0x69, 0x65, 0x64, 0x22,
0x91, 0x01, 0x0a, 0x12, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x19,
0x0a, 0x08, 0x69, 0x73, 0x5f, 0x72, 0x65, 0x61, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08,
0x52, 0x07, 0x69, 0x73, 0x52, 0x65, 0x61, 0x64, 0x79, 0x12, 0x48, 0x0a, 0x0c, 0x64, 0x65, 0x70,
0x65, 0x6e, 0x64, 0x65, 0x6e, 0x63, 0x69, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32,
0x24, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63,
0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x65, 0x6e, 0x63,
0x79, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0c, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x65, 0x6e, 0x63,
0x69, 0x65, 0x73, 0x32, 0xbb, 0x04, 0x0a, 0x0b, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x6f, 0x63,
0x6b, 0x65, 0x74, 0x12, 0x4d, 0x0a, 0x04, 0x50, 0x69, 0x6e, 0x67, 0x12, 0x21, 0x2e, 0x63, 0x6f,
0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e,
0x76, 0x31, 0x2e, 0x50, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22,
0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b,
0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x12, 0x5c, 0x0a, 0x09, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x12,
0x26, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63,
0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x27, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e,
0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53,
0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x12, 0x59, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x57, 0x61, 0x6e, 0x74, 0x12, 0x25, 0x2e, 0x63,
0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74,
0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x57, 0x61, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e,
0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x57,
0x61, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x65, 0x0a, 0x0c, 0x53,
0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x12, 0x29, 0x2e, 0x63, 0x6f,
0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e,
0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2a, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61,
0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79,
0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x12, 0x5c, 0x0a, 0x09, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, 0x12,
0x26, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63,
0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x27, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e,
0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53,
0x79, 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x12, 0x5f, 0x0a, 0x0a, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x27,
0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b,
0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e,
0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53,
0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x42, 0x33, 0x5a, 0x31, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x76, 0x32, 0x2f, 0x61,
0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74,
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_agent_agentsocket_proto_agentsocket_proto_rawDescOnce sync.Once
file_agent_agentsocket_proto_agentsocket_proto_rawDescData []byte
file_agent_agentsocket_proto_agentsocket_proto_rawDescData = file_agent_agentsocket_proto_agentsocket_proto_rawDesc
)
func file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP() []byte {
file_agent_agentsocket_proto_agentsocket_proto_rawDescOnce.Do(func() {
file_agent_agentsocket_proto_agentsocket_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_agentsocket_proto_agentsocket_proto_rawDesc), len(file_agent_agentsocket_proto_agentsocket_proto_rawDesc)))
file_agent_agentsocket_proto_agentsocket_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_agentsocket_proto_agentsocket_proto_rawDescData)
})
return file_agent_agentsocket_proto_agentsocket_proto_rawDescData
}
var file_agent_agentsocket_proto_agentsocket_proto_msgTypes = make([]protoimpl.MessageInfo, 13)
var file_agent_agentsocket_proto_agentsocket_proto_goTypes = []any{
var file_agent_agentsocket_proto_agentsocket_proto_goTypes = []interface{}{
(*PingRequest)(nil), // 0: coder.agentsocket.v1.PingRequest
(*PingResponse)(nil), // 1: coder.agentsocket.v1.PingResponse
(*SyncStartRequest)(nil), // 2: coder.agentsocket.v1.SyncStartRequest
@@ -707,11 +789,169 @@ func file_agent_agentsocket_proto_agentsocket_proto_init() {
if File_agent_agentsocket_proto_agentsocket_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PingRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PingResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncStartRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncStartResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncWantRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncWantResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncCompleteRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncCompleteResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncReadyRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncReadyResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncStatusRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DependencyInfo); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agentsocket_proto_agentsocket_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncStatusResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_agentsocket_proto_agentsocket_proto_rawDesc), len(file_agent_agentsocket_proto_agentsocket_proto_rawDesc)),
RawDescriptor: file_agent_agentsocket_proto_agentsocket_proto_rawDesc,
NumEnums: 0,
NumMessages: 13,
NumExtensions: 0,
@@ -722,6 +962,7 @@ func file_agent_agentsocket_proto_agentsocket_proto_init() {
MessageInfos: file_agent_agentsocket_proto_agentsocket_proto_msgTypes,
}.Build()
File_agent_agentsocket_proto_agentsocket_proto = out.File
file_agent_agentsocket_proto_agentsocket_proto_rawDesc = nil
file_agent_agentsocket_proto_agentsocket_proto_goTypes = nil
file_agent_agentsocket_proto_agentsocket_proto_depIdxs = nil
}
+1 -13
View File
@@ -359,17 +359,6 @@ func (s *sessionCloseTracker) Close() error {
return s.Session.Close()
}
func fallbackDisconnectReason(code int, reason string) string {
if reason != "" || code == 0 {
return reason
}
return fmt.Sprintf(
"connection ended unexpectedly: session closed without explicit reason (exit code: %d)",
code,
)
}
func extractContainerInfo(env []string) (container, containerUser string, filteredEnv []string) {
for _, kv := range env {
if strings.HasPrefix(kv, ContainerEnvironmentVariable+"=") {
@@ -450,8 +439,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
disconnected := s.config.ReportConnection(id, magicType, remoteAddrString)
defer func() {
code := scr.exitCode()
disconnected(code, fallbackDisconnectReason(code, reason))
disconnected(scr.exitCode(), reason)
}()
}
-27
View File
@@ -7,7 +7,6 @@ import (
"context"
"io"
"net"
"strings"
"testing"
gliderssh "github.com/gliderlabs/ssh"
@@ -103,32 +102,6 @@ func waitForChan(ctx context.Context, t *testing.T, c <-chan struct{}, msg strin
}
}
func TestFallbackDisconnectReason(t *testing.T) {
t.Parallel()
t.Run("KeepProvidedReason", func(t *testing.T) {
t.Parallel()
reason := fallbackDisconnectReason(255, "network path changed")
assert.Equal(t, "network path changed", reason)
})
t.Run("KeepEmptyReasonForCleanExit", func(t *testing.T) {
t.Parallel()
reason := fallbackDisconnectReason(0, "")
assert.Equal(t, "", reason)
})
t.Run("FallbackReasonForUnexpectedExit", func(t *testing.T) {
t.Parallel()
reason := fallbackDisconnectReason(1, "")
assert.True(t, strings.Contains(reason, "ended unexpectedly"))
assert.True(t, strings.Contains(reason, "exit code: 1"))
})
}
type testSession struct {
ctx testSSHContext
-1
View File
@@ -131,7 +131,6 @@ func TestServer_X11(t *testing.T) {
func TestServer_X11_EvictionLRU(t *testing.T) {
t.Parallel()
t.Skip("Flaky test, times out in CI")
if runtime.GOOS != "linux" {
t.Skip("X11 forwarding is only supported on Linux")
}
+13 -33
View File
@@ -576,8 +576,6 @@ const (
Connection_VSCODE Connection_Type = 2
Connection_JETBRAINS Connection_Type = 3
Connection_RECONNECTING_PTY Connection_Type = 4
Connection_WORKSPACE_APP Connection_Type = 5
Connection_PORT_FORWARDING Connection_Type = 6
)
// Enum value maps for Connection_Type.
@@ -588,8 +586,6 @@ var (
2: "VSCODE",
3: "JETBRAINS",
4: "RECONNECTING_PTY",
5: "WORKSPACE_APP",
6: "PORT_FORWARDING",
}
Connection_Type_value = map[string]int32{
"TYPE_UNSPECIFIED": 0,
@@ -597,8 +593,6 @@ var (
"VSCODE": 2,
"JETBRAINS": 3,
"RECONNECTING_PTY": 4,
"WORKSPACE_APP": 5,
"PORT_FORWARDING": 6,
}
)
@@ -2864,7 +2858,6 @@ type Connection struct {
Ip string `protobuf:"bytes,5,opt,name=ip,proto3" json:"ip,omitempty"`
StatusCode int32 `protobuf:"varint,6,opt,name=status_code,json=statusCode,proto3" json:"status_code,omitempty"`
Reason *string `protobuf:"bytes,7,opt,name=reason,proto3,oneof" json:"reason,omitempty"`
SlugOrPort *string `protobuf:"bytes,8,opt,name=slug_or_port,json=slugOrPort,proto3,oneof" json:"slug_or_port,omitempty"`
}
func (x *Connection) Reset() {
@@ -2948,13 +2941,6 @@ func (x *Connection) GetReason() string {
return ""
}
func (x *Connection) GetSlugOrPort() string {
if x != nil && x.SlugOrPort != nil {
return *x.SlugOrPort
}
return ""
}
type ReportConnectionRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -5119,7 +5105,7 @@ var file_agent_proto_agent_proto_rawDesc = []byte{
0x74, 0x6f, 0x74, 0x61, 0x6c, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79,
0x22, 0x26, 0x0a, 0x24, 0x50, 0x75, 0x73, 0x68, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65,
0x73, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x69, 0x6e, 0x67, 0x55, 0x73, 0x61, 0x67, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x96, 0x04, 0x0a, 0x0a, 0x43, 0x6f, 0x6e,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb6, 0x03, 0x0a, 0x0a, 0x43, 0x6f, 0x6e,
0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x02, 0x69, 0x64, 0x12, 0x39, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f,
0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x21, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e,
@@ -5136,24 +5122,18 @@ var file_agent_proto_agent_proto_rawDesc = []byte{
0x70, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x5f, 0x63, 0x6f, 0x64, 0x65,
0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x43, 0x6f,
0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01,
0x28, 0x09, 0x48, 0x00, 0x52, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x88, 0x01, 0x01, 0x12,
0x25, 0x0a, 0x0c, 0x73, 0x6c, 0x75, 0x67, 0x5f, 0x6f, 0x72, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18,
0x08, 0x20, 0x01, 0x28, 0x09, 0x48, 0x01, 0x52, 0x0a, 0x73, 0x6c, 0x75, 0x67, 0x4f, 0x72, 0x50,
0x6f, 0x72, 0x74, 0x88, 0x01, 0x01, 0x22, 0x3d, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e,
0x12, 0x16, 0x0a, 0x12, 0x41, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45,
0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x43, 0x4f, 0x4e, 0x4e,
0x45, 0x43, 0x54, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x44, 0x49, 0x53, 0x43, 0x4f, 0x4e, 0x4e,
0x45, 0x43, 0x54, 0x10, 0x02, 0x22, 0x7e, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a,
0x10, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45,
0x44, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x53, 0x53, 0x48, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06,
0x56, 0x53, 0x43, 0x4f, 0x44, 0x45, 0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x4a, 0x45, 0x54, 0x42,
0x52, 0x41, 0x49, 0x4e, 0x53, 0x10, 0x03, 0x12, 0x14, 0x0a, 0x10, 0x52, 0x45, 0x43, 0x4f, 0x4e,
0x4e, 0x45, 0x43, 0x54, 0x49, 0x4e, 0x47, 0x5f, 0x50, 0x54, 0x59, 0x10, 0x04, 0x12, 0x11, 0x0a,
0x0d, 0x57, 0x4f, 0x52, 0x4b, 0x53, 0x50, 0x41, 0x43, 0x45, 0x5f, 0x41, 0x50, 0x50, 0x10, 0x05,
0x12, 0x13, 0x0a, 0x0f, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x46, 0x4f, 0x52, 0x57, 0x41, 0x52, 0x44,
0x49, 0x4e, 0x47, 0x10, 0x06, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e,
0x42, 0x0f, 0x0a, 0x0d, 0x5f, 0x73, 0x6c, 0x75, 0x67, 0x5f, 0x6f, 0x72, 0x5f, 0x70, 0x6f, 0x72,
0x74, 0x22, 0x55, 0x0a, 0x17, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x6f, 0x6e, 0x6e, 0x65,
0x28, 0x09, 0x48, 0x00, 0x52, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x88, 0x01, 0x01, 0x22,
0x3d, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x12, 0x41, 0x43, 0x54,
0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10,
0x00, 0x12, 0x0b, 0x0a, 0x07, 0x43, 0x4f, 0x4e, 0x4e, 0x45, 0x43, 0x54, 0x10, 0x01, 0x12, 0x0e,
0x0a, 0x0a, 0x44, 0x49, 0x53, 0x43, 0x4f, 0x4e, 0x4e, 0x45, 0x43, 0x54, 0x10, 0x02, 0x22, 0x56,
0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x10, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55,
0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03,
0x53, 0x53, 0x48, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x56, 0x53, 0x43, 0x4f, 0x44, 0x45, 0x10,
0x02, 0x12, 0x0d, 0x0a, 0x09, 0x4a, 0x45, 0x54, 0x42, 0x52, 0x41, 0x49, 0x4e, 0x53, 0x10, 0x03,
0x12, 0x14, 0x0a, 0x10, 0x52, 0x45, 0x43, 0x4f, 0x4e, 0x4e, 0x45, 0x43, 0x54, 0x49, 0x4e, 0x47,
0x5f, 0x50, 0x54, 0x59, 0x10, 0x04, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f,
0x6e, 0x22, 0x55, 0x0a, 0x17, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x6f, 0x6e, 0x6e, 0x65,
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x3a, 0x0a, 0x0a,
0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b,
0x32, 0x1a, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x76,
-3
View File
@@ -364,8 +364,6 @@ message Connection {
VSCODE = 2;
JETBRAINS = 3;
RECONNECTING_PTY = 4;
WORKSPACE_APP = 5;
PORT_FORWARDING = 6;
}
bytes id = 1;
@@ -375,7 +373,6 @@ message Connection {
string ip = 5;
int32 status_code = 6;
optional string reason = 7;
optional string slug_or_port = 8;
}
message ReportConnectionRequest {
-9
View File
@@ -4,8 +4,6 @@ import (
"os"
"github.com/hashicorp/go-reap"
"cdr.dev/slog/v3"
)
type Option func(o *options)
@@ -36,15 +34,8 @@ func WithCatchSignals(sigs ...os.Signal) Option {
}
}
func WithLogger(logger slog.Logger) Option {
return func(o *options) {
o.Logger = logger
}
}
type options struct {
ExecArgs []string
PIDs reap.PidCh
CatchSignals []os.Signal
Logger slog.Logger
}
+2 -14
View File
@@ -3,15 +3,12 @@
package reaper
import (
"context"
"os"
"os/signal"
"syscall"
"github.com/hashicorp/go-reap"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
// IsInitProcess returns true if the current process's PID is 1.
@@ -19,7 +16,7 @@ func IsInitProcess() bool {
return os.Getpid() == 1
}
func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
func catchSignals(pid int, sigs []os.Signal) {
if len(sigs) == 0 {
return
}
@@ -28,19 +25,10 @@ func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
signal.Notify(sc, sigs...)
defer signal.Stop(sc)
logger.Info(context.Background(), "reaper catching signals",
slog.F("signals", sigs),
slog.F("child_pid", pid),
)
for {
s := <-sc
sig, ok := s.(syscall.Signal)
if ok {
logger.Info(context.Background(), "reaper caught signal, killing child process",
slog.F("signal", sig.String()),
slog.F("child_pid", pid),
)
_ = syscall.Kill(pid, sig)
}
}
@@ -90,7 +78,7 @@ func ForkReap(opt ...Option) (int, error) {
return 1, xerrors.Errorf("fork exec: %w", err)
}
go catchSignals(opts.Logger, pid, opts.CatchSignals)
go catchSignals(pid, opts.CatchSignals)
var wstatus syscall.WaitStatus
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
+16 -44
View File
@@ -9,7 +9,6 @@ import (
"net/http/pprof"
"net/url"
"os"
"os/signal"
"path/filepath"
"runtime"
"slices"
@@ -131,7 +130,6 @@ func workspaceAgent() *serpent.Command {
sinks = append(sinks, sloghuman.Sink(logWriter))
logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug)
logger = logger.Named("reaper")
logger.Info(ctx, "spawning reaper process")
// Do not start a reaper on the child process. It's important
@@ -141,19 +139,31 @@ func workspaceAgent() *serpent.Command {
exitCode, err := reaper.ForkReap(
reaper.WithExecArgs(args...),
reaper.WithCatchSignals(StopSignals...),
reaper.WithLogger(logger),
)
if err != nil {
logger.Error(ctx, "agent process reaper unable to fork", slog.Error(err))
return xerrors.Errorf("fork reap: %w", err)
}
logger.Info(ctx, "child process exited, propagating exit code",
slog.F("exit_code", exitCode),
)
logger.Info(ctx, "reaper child process exited", slog.F("exit_code", exitCode))
return ExitError(exitCode, nil)
}
// Handle interrupt signals to allow for graceful shutdown,
// note that calling stopNotify disables the signal handler
// and the next interrupt will terminate the program (you
// probably want cancel instead).
//
// Note that we don't want to handle these signals in the
// process that runs as PID 1, that's why we do this after
// the reaper forked.
ctx, stopNotify := inv.SignalNotifyContext(ctx, StopSignals...)
defer stopNotify()
// DumpHandler does signal handling, so we call it after the
// reaper.
go DumpHandler(ctx, "agent")
logWriter := &clilog.LumberjackWriteCloseFixer{Writer: &lumberjack.Logger{
Filename: filepath.Join(logDir, "coder-agent.log"),
MaxSize: 5, // MB
@@ -166,21 +176,6 @@ func workspaceAgent() *serpent.Command {
sinks = append(sinks, sloghuman.Sink(logWriter))
logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug)
// Handle interrupt signals to allow for graceful shutdown,
// note that calling stopNotify disables the signal handler
// and the next interrupt will terminate the program (you
// probably want cancel instead).
//
// Note that we also handle these signals in the
// process that runs as PID 1, mainly to forward it to the agent child
// so that it can shutdown gracefully.
ctx, stopNotify := logSignalNotifyContext(ctx, logger, StopSignals...)
defer stopNotify()
// DumpHandler does signal handling, so we call it after the
// reaper.
go DumpHandler(ctx, "agent")
version := buildinfo.Version()
logger.Info(ctx, "agent is starting now",
slog.F("url", agentAuth.agentURL),
@@ -570,26 +565,3 @@ func urlPort(u string) (int, error) {
}
return -1, xerrors.Errorf("invalid port: %s", u)
}
// logSignalNotifyContext is like signal.NotifyContext but logs the received
// signal before canceling the context.
func logSignalNotifyContext(parent context.Context, logger slog.Logger, signals ...os.Signal) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancelCause(parent)
c := make(chan os.Signal, 1)
signal.Notify(c, signals...)
go func() {
select {
case sig := <-c:
logger.Info(ctx, "agent received signal", slog.F("signal", sig.String()))
cancel(xerrors.Errorf("signal: %s", sig.String()))
case <-ctx.Done():
logger.Info(ctx, "ctx canceled, stopping signal handler")
}
}()
return ctx, func() {
cancel(context.Canceled)
signal.Stop(c)
}
}
@@ -1,4 +1,4 @@
package hostname
package cliutil
import (
"os"
-1
View File
@@ -23,7 +23,6 @@ func (r *RootCmd) organizations() *serpent.Command {
},
Children: []*serpent.Command{
r.showOrganization(orgContext),
r.listOrganizations(),
r.createOrganization(),
r.deleteOrganization(orgContext),
r.organizationMembers(orgContext),
-43
View File
@@ -1,7 +1,6 @@
package cli_test
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -59,48 +58,6 @@ func TestCurrentOrganization(t *testing.T) {
})
}
func TestOrganizationList(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
orgID := uuid.New()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/organizations":
_ = json.NewEncoder(w).Encode([]codersdk.Organization{
{
MinimalOrganization: codersdk.MinimalOrganization{
ID: orgID,
Name: "my-org",
DisplayName: "My Org",
},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
})
default:
t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
client := codersdk.New(must(url.Parse(server.URL)))
inv, root := clitest.New(t, "organizations", "list")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
require.NoError(t, inv.Run())
require.Contains(t, buf.String(), "my-org")
require.Contains(t, buf.String(), "My Org")
require.Contains(t, buf.String(), orgID.String())
})
}
func TestOrganizationDelete(t *testing.T) {
t.Parallel()
-53
View File
@@ -1,53 +0,0 @@
package cli
import (
"fmt"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) listOrganizations() *serpent.Command {
formatter := cliui.NewOutputFormatter(
cliui.TableFormat([]codersdk.Organization{}, []string{"name", "display name", "id", "default"}),
cliui.JSONFormat(),
)
cmd := &serpent.Command{
Use: "list",
Short: "List all organizations",
Long: "List all organizations. Requires a role which grants ResourceOrganization: read.",
Aliases: []string{"ls"},
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
organizations, err := client.Organizations(inv.Context())
if err != nil {
return err
}
out, err := formatter.Format(inv.Context(), organizations)
if err != nil {
return err
}
if out == "" {
cliui.Infof(inv.Stderr, "No organizations found.")
return nil
}
_, err = fmt.Fprintln(inv.Stdout, out)
return err
},
}
formatter.AttachOptions(&cmd.Options)
return cmd
}
+1 -3
View File
@@ -123,9 +123,7 @@ func (r *RootCmd) ping() *serpent.Command {
spin.Start()
}
opts := &workspacesdk.DialAgentOptions{
ShortDescription: "CLI ping",
}
opts := &workspacesdk.DialAgentOptions{}
if r.verbose {
opts.Logger = inv.Logger.AppendSinks(sloghuman.Sink(inv.Stdout)).Leveled(slog.LevelDebug)
+1 -3
View File
@@ -107,9 +107,7 @@ func (r *RootCmd) portForward() *serpent.Command {
return xerrors.Errorf("await agent: %w", err)
}
opts := &workspacesdk.DialAgentOptions{
ShortDescription: "CLI port-forward",
}
opts := &workspacesdk.DialAgentOptions{}
logger := inv.Logger
if r.verbose {
+3 -3
View File
@@ -59,7 +59,7 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/clilog"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/cli/cliutil/hostname"
"github.com/coder/coder/v2/cli/cliutil"
"github.com/coder/coder/v2/cli/config"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/autobuild"
@@ -1029,7 +1029,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
suffix := fmt.Sprintf("%d", i)
// The suffix is added to the hostname, so we may need to trim to fit into
// the 64 character limit.
hostname := stringutil.Truncate(hostname.Hostname(), 63-len(suffix))
hostname := stringutil.Truncate(cliutil.Hostname(), 63-len(suffix))
name := fmt.Sprintf("%s-%s", hostname, suffix)
daemonCacheDir := filepath.Join(cacheDir, fmt.Sprintf("provisioner-%d", i))
daemon, err := newProvisionerDaemon(
@@ -2174,7 +2174,7 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg
// existing database
retryPortDiscovery := errors.Is(err, os.ErrNotExist) && testing.Testing()
if retryPortDiscovery {
maxAttempts = 10
maxAttempts = 3
}
var startErr error
+1 -3
View File
@@ -97,9 +97,7 @@ func (r *RootCmd) speedtest() *serpent.Command {
return xerrors.Errorf("await agent: %w", err)
}
opts := &workspacesdk.DialAgentOptions{
ShortDescription: "CLI speedtest",
}
opts := &workspacesdk.DialAgentOptions{}
if r.verbose {
opts.Logger = inv.Logger.AppendSinks(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelDebug)
}
+3 -8
View File
@@ -365,10 +365,6 @@ func (r *RootCmd) ssh() *serpent.Command {
}
return err
}
shortDescription := "CLI ssh"
if stdio {
shortDescription = "CLI ssh (stdio)"
}
// If we're in stdio mode, check to see if we can use Coder Connect.
// We don't support Coder Connect over non-stdio coder ssh yet.
@@ -409,10 +405,9 @@ func (r *RootCmd) ssh() *serpent.Command {
}
conn, err := wsClient.
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
Logger: logger,
BlockEndpoints: r.disableDirect,
EnableTelemetry: !r.disableNetworkTelemetry,
ShortDescription: shortDescription,
Logger: logger,
BlockEndpoints: r.disableDirect,
EnableTelemetry: !r.disableNetworkTelemetry,
})
if err != nil {
return xerrors.Errorf("dial agent: %w", err)
-1
View File
@@ -418,7 +418,6 @@ func writeBundle(src *support.Bundle, dest *zip.Writer) error {
"workspace/template_version.json": src.Workspace.TemplateVersion,
"workspace/parameters.json": src.Workspace.Parameters,
"workspace/workspace.json": src.Workspace.Workspace,
"workspace/workspace_sessions.json": src.Workspace.WorkspaceSessions,
} {
f, err := dest.Create(k)
if err != nil {
-1
View File
@@ -10,7 +10,6 @@ USAGE:
SUBCOMMANDS:
create Create a new organization.
delete Delete an organization
list List all organizations
members Manage organization members
roles Manage organization roles.
settings Manage organization settings.
-21
View File
@@ -1,21 +0,0 @@
coder v0.0.0-devel
USAGE:
coder organizations list [flags]
List all organizations
Aliases: ls
List all organizations. Requires a role which grants ResourceOrganization:
read.
OPTIONS:
-c, --column [id|name|display name|icon|description|created at|updated at|default] (default: name,display name,id,default)
Columns to display in table output.
-o, --output table|json (default: table)
Output format.
———
Run `coder --help` for a list of global options.
+2 -3
View File
@@ -166,9 +166,8 @@ func (r *RootCmd) vscodeSSH() *serpent.Command {
}
agentConn, err := workspacesdk.New(client).
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
Logger: logger,
BlockEndpoints: r.disableDirect,
ShortDescription: "VSCode SSH",
Logger: logger,
BlockEndpoints: r.disableDirect,
})
if err != nil {
return xerrors.Errorf("dial workspace agent: %w", err)
+5 -9
View File
@@ -89,7 +89,6 @@ type Options struct {
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
NetworkTelemetryHandler func(batch []*tailnetproto.TelemetryEvent)
BoundaryUsageTracker *boundaryusage.Tracker
LifecycleMetrics *LifecycleMetrics
AccessURL *url.URL
AppHostname string
@@ -171,7 +170,6 @@ func New(opts Options, workspace database.Workspace) *API {
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Metrics: opts.LifecycleMetrics,
}
api.AppsAPI = &AppsAPI{
@@ -202,13 +200,11 @@ func New(opts Options, workspace database.Workspace) *API {
}
api.ConnLogAPI = &ConnLogAPI{
AgentFn: api.agent,
ConnectionLogger: opts.ConnectionLogger,
TailnetCoordinator: opts.TailnetCoordinator,
Database: opts.Database,
Workspace: api.cachedWorkspaceFields,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
AgentFn: api.agent,
ConnectionLogger: opts.ConnectionLogger,
Database: opts.Database,
Workspace: api.cachedWorkspaceFields,
Log: opts.Log,
}
api.DRPCService = &tailnet.DRPCService{
+8 -132
View File
@@ -3,8 +3,6 @@ package agentapi
import (
"context"
"database/sql"
"fmt"
"net/netip"
"sync/atomic"
"github.com/google/uuid"
@@ -17,18 +15,14 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/tailnet"
)
type ConnLogAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
TailnetCoordinator *atomic.Pointer[tailnet.Coordinator]
Workspace *CachedWorkspaceFields
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
AgentFn func(context.Context) (database.WorkspaceAgent, error)
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
Workspace *CachedWorkspaceFields
Database database.Store
Log slog.Logger
}
func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) {
@@ -94,35 +88,6 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
}
logIP := database.ParseIP(logIPRaw) // will return null if invalid
// At connect time, look up the tailnet peer to capture the
// client hostname and description for session grouping later.
var clientHostname, shortDescription sql.NullString
if action == database.ConnectionStatusConnected && a.TailnetCoordinator != nil {
if coord := a.TailnetCoordinator.Load(); coord != nil {
for _, peer := range (*coord).TunnelPeers(workspaceAgent.ID) {
if peer.Node != nil {
// Match peer by checking if any of its addresses
// match the connection IP.
for _, addr := range peer.Node.Addresses {
prefix, err := netip.ParsePrefix(addr)
if err != nil {
continue
}
if logIP.Valid && prefix.Addr().String() == logIP.IPNet.IP.String() {
if peer.Node.Hostname != "" {
clientHostname = sql.NullString{String: peer.Node.Hostname, Valid: true}
}
if peer.Node.ShortDescription != "" {
shortDescription = sql.NullString{String: peer.Node.ShortDescription, Valid: true}
}
break
}
}
}
}
}
}
reason := req.GetConnection().GetReason()
connLogger := *a.ConnectionLogger.Load()
err = connLogger.Upsert(ctx, database.UpsertConnectionLogParams{
@@ -133,7 +98,6 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: workspaceAgent.Name,
AgentID: uuid.NullUUID{UUID: workspaceAgent.ID, Valid: true},
Type: connectionType,
Code: code,
Ip: logIP,
@@ -145,7 +109,6 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
String: reason,
Valid: reason != "",
},
SessionID: uuid.NullUUID{},
// We supply the action:
// - So the DB can handle duplicate connections or disconnections properly.
// - To make it clear whether this is a connection or disconnection
@@ -158,100 +121,13 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
Valid: false,
},
// N/A
UserAgent: sql.NullString{},
ClientHostname: clientHostname,
ShortDescription: shortDescription,
SlugOrPort: sql.NullString{
String: req.GetConnection().GetSlugOrPort(),
Valid: req.GetConnection().GetSlugOrPort() != "",
},
UserAgent: sql.NullString{},
// N/A
SlugOrPort: sql.NullString{},
})
if err != nil {
return nil, xerrors.Errorf("export connection log: %w", err)
}
// At disconnect time, find or create a session for this connection.
// This groups related connection logs into workspace sessions.
if action == database.ConnectionStatusDisconnected {
a.assignSessionForDisconnect(ctx, connectionID, ws, workspaceAgent, req)
}
if a.PublishWorkspaceUpdateFn != nil {
if err := a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindConnectionLogUpdate); err != nil {
a.Log.Warn(ctx, "failed to publish connection log update", slog.Error(err))
}
}
return &emptypb.Empty{}, nil
}
// assignSessionForDisconnect looks up the existing connection log for this
// connection ID and finds or creates a session to group it with.
func (a *ConnLogAPI) assignSessionForDisconnect(
ctx context.Context,
connectionID uuid.UUID,
ws database.WorkspaceIdentity,
workspaceAgent database.WorkspaceAgent,
req *agentproto.ReportConnectionRequest,
) {
//nolint:gocritic // The agent context doesn't have connection_log
// permissions. Session creation is authorized by the workspace
// access already validated in ReportConnection.
ctx = dbauthz.AsConnectionLogger(ctx)
existingLog, err := a.Database.GetConnectionLogByConnectionID(ctx, database.GetConnectionLogByConnectionIDParams{
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
WorkspaceID: ws.ID,
AgentName: workspaceAgent.Name,
})
if err != nil {
a.Log.Warn(ctx, "failed to look up connection log for session assignment",
slog.Error(err),
slog.F("connection_id", connectionID),
)
return
}
sessionIDRaw, err := a.Database.FindOrCreateSessionForDisconnect(ctx, database.FindOrCreateSessionForDisconnectParams{
WorkspaceID: ws.ID.String(),
Ip: existingLog.Ip,
ClientHostname: existingLog.ClientHostname,
ShortDescription: existingLog.ShortDescription,
ConnectTime: existingLog.ConnectTime,
DisconnectTime: req.GetConnection().GetTimestamp().AsTime(),
AgentID: uuid.NullUUID{UUID: workspaceAgent.ID, Valid: true},
})
if err != nil {
a.Log.Warn(ctx, "failed to find or create session for disconnect",
slog.Error(err),
slog.F("connection_id", connectionID),
)
return
}
// The query uses COALESCE which returns a generic type. The
// database/sql driver may return the UUID as a string, []byte,
// or [16]byte rather than uuid.UUID, so we parse it.
sessionID, parseErr := uuid.Parse(fmt.Sprintf("%s", sessionIDRaw))
if parseErr != nil {
a.Log.Warn(ctx, "failed to parse session ID from FindOrCreateSessionForDisconnect",
slog.Error(parseErr),
slog.F("connection_id", connectionID),
slog.F("session_id_raw", sessionIDRaw),
slog.F("session_id_type", fmt.Sprintf("%T", sessionIDRaw)),
)
return
}
// Link the connection log to its session so that
// CloseConnectionLogsAndCreateSessions skips it.
if err := a.Database.UpdateConnectionLogSessionID(ctx, database.UpdateConnectionLogSessionIDParams{
ID: existingLog.ID,
SessionID: uuid.NullUUID{UUID: sessionID, Valid: true},
}); err != nil {
a.Log.Warn(ctx, "failed to update connection log session ID",
slog.Error(err),
slog.F("connection_id", connectionID),
)
}
}
+8 -105
View File
@@ -19,7 +19,6 @@ import (
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/wspubsub"
)
func TestConnectionLog(t *testing.T) {
@@ -42,15 +41,14 @@ func TestConnectionLog(t *testing.T) {
)
tests := []struct {
name string
id uuid.UUID
action *agentproto.Connection_Action
typ *agentproto.Connection_Type
time time.Time
ip string
status int32
reason string
slugOrPort string
name string
id uuid.UUID
action *agentproto.Connection_Action
typ *agentproto.Connection_Type
time time.Time
ip string
status int32
reason string
}{
{
name: "SSH Connect",
@@ -86,34 +84,6 @@ func TestConnectionLog(t *testing.T) {
typ: agentproto.Connection_RECONNECTING_PTY.Enum(),
time: dbtime.Now(),
},
{
name: "Port Forwarding Connect",
id: uuid.New(),
action: agentproto.Connection_CONNECT.Enum(),
typ: agentproto.Connection_PORT_FORWARDING.Enum(),
time: dbtime.Now(),
ip: "192.168.1.1",
slugOrPort: "8080",
},
{
name: "Port Forwarding Disconnect",
id: uuid.New(),
action: agentproto.Connection_DISCONNECT.Enum(),
typ: agentproto.Connection_PORT_FORWARDING.Enum(),
time: dbtime.Now(),
ip: "192.168.1.1",
status: 200,
slugOrPort: "8080",
},
{
name: "Workspace App Connect",
id: uuid.New(),
action: agentproto.Connection_CONNECT.Enum(),
typ: agentproto.Connection_WORKSPACE_APP.Enum(),
time: dbtime.Now(),
ip: "10.0.0.1",
slugOrPort: "my-app",
},
{
name: "SSH Disconnect",
id: uuid.New(),
@@ -140,10 +110,6 @@ func TestConnectionLog(t *testing.T) {
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
// Disconnect actions trigger session assignment which calls
// GetConnectionLogByConnectionID and FindOrCreateSessionForDisconnect.
mDB.EXPECT().GetConnectionLogByConnectionID(gomock.Any(), gomock.Any()).Return(database.ConnectionLog{}, nil).AnyTimes()
mDB.EXPECT().FindOrCreateSessionForDisconnect(gomock.Any(), gomock.Any()).Return(database.WorkspaceSession{}, nil).AnyTimes()
api := &agentapi.ConnLogAPI{
ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger),
@@ -162,7 +128,6 @@ func TestConnectionLog(t *testing.T) {
Ip: tt.ip,
StatusCode: tt.status,
Reason: &tt.reason,
SlugOrPort: &tt.slugOrPort,
},
})
@@ -179,7 +144,6 @@ func TestConnectionLog(t *testing.T) {
WorkspaceID: workspace.ID,
WorkspaceName: workspace.Name,
AgentName: agent.Name,
AgentID: uuid.NullUUID{UUID: agent.ID, Valid: true},
UserID: uuid.NullUUID{
UUID: uuid.Nil,
Valid: false,
@@ -200,72 +164,11 @@ func TestConnectionLog(t *testing.T) {
UUID: tt.id,
Valid: tt.id != uuid.Nil,
},
SlugOrPort: sql.NullString{
String: tt.slugOrPort,
Valid: tt.slugOrPort != "",
},
}))
})
}
}
func TestConnectionLogPublishesWorkspaceUpdate(t *testing.T) {
t.Parallel()
var (
owner = database.User{ID: uuid.New(), Username: "cool-user"}
workspace = database.Workspace{
ID: uuid.New(),
OrganizationID: uuid.New(),
OwnerID: owner.ID,
Name: "cool-workspace",
}
agent = database.WorkspaceAgent{ID: uuid.New()}
)
connLogger := connectionlog.NewFake()
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
var (
called int
gotKind wspubsub.WorkspaceEventKind
gotAgent uuid.UUID
)
api := &agentapi.ConnLogAPI{
ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger),
Database: mDB,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{},
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
called++
gotKind = kind
gotAgent = agent.ID
return nil
},
}
id := uuid.New()
_, err := api.ReportConnection(context.Background(), &agentproto.ReportConnectionRequest{
Connection: &agentproto.Connection{
Id: id[:],
Action: agentproto.Connection_CONNECT,
Type: agentproto.Connection_SSH,
Timestamp: timestamppb.New(dbtime.Now()),
Ip: "127.0.0.1",
},
})
require.NoError(t, err)
require.Equal(t, 1, called)
require.Equal(t, wspubsub.WorkspaceEventKindConnectionLogUpdate, gotKind)
require.Equal(t, agent.ID, gotAgent)
}
func agentProtoConnectionTypeToConnectionLog(t *testing.T, typ agentproto.Connection_Type) database.ConnectionType {
a, err := db2sdk.ConnectionLogConnectionTypeFromAgentProtoConnectionType(typ)
require.NoError(t, err)
+1 -15
View File
@@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"slices"
"sync"
"time"
"github.com/google/uuid"
@@ -32,9 +31,7 @@ type LifecycleAPI struct {
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
TimeNowFn func() time.Time // defaults to dbtime.Now()
Metrics *LifecycleMetrics
emitMetricsOnce sync.Once
TimeNowFn func() time.Time // defaults to dbtime.Now()
}
func (a *LifecycleAPI) now() time.Time {
@@ -128,17 +125,6 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
}
}
// Emit build duration metric when agent transitions to a terminal startup state.
// We only emit once per agent connection to avoid duplicate metrics.
switch lifecycleState {
case database.WorkspaceAgentLifecycleStateReady,
database.WorkspaceAgentLifecycleStateStartTimeout,
database.WorkspaceAgentLifecycleStateStartError:
a.emitMetricsOnce.Do(func() {
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
})
}
return req.Lifecycle, nil
}
-260
View File
@@ -9,14 +9,12 @@ import (
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/types/known/timestamppb"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/coderdtest/promhelp"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
@@ -24,10 +22,6 @@ import (
"github.com/coder/coder/v2/testutil"
)
// fullMetricName is the fully-qualified Prometheus metric name
// (namespace + name) used for gathering in tests.
const fullMetricName = "coderd_" + agentapi.BuildDurationMetricName
func TestUpdateLifecycle(t *testing.T) {
t.Parallel()
@@ -36,12 +30,6 @@ func TestUpdateLifecycle(t *testing.T) {
someTime = dbtime.Time(someTime)
now := dbtime.Now()
// Fixed times for build duration metric assertions.
// The expected duration is exactly 90 seconds.
buildCreatedAt := dbtime.Time(time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC))
agentReadyAt := dbtime.Time(time.Date(2025, 1, 1, 0, 1, 30, 0, time.UTC))
expectedDuration := agentReadyAt.Sub(buildCreatedAt).Seconds() // 90.0
var (
workspaceID = uuid.New()
agentCreated = database.WorkspaceAgent{
@@ -117,19 +105,6 @@ func TestUpdateLifecycle(t *testing.T) {
Valid: true,
},
}).Return(nil)
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), agentStarting.ResourceID).Return(database.GetWorkspaceBuildMetricsByResourceIDRow{
CreatedAt: buildCreatedAt,
Transition: database.WorkspaceTransitionStart,
TemplateName: "test-template",
OrganizationName: "test-org",
IsPrebuild: false,
AllAgentsReady: true,
LastAgentReadyAt: agentReadyAt,
WorstStatus: "success",
}, nil)
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
@@ -138,7 +113,6 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
// Test that nil publish fn works.
PublishWorkspaceUpdateFn: nil,
}
@@ -148,16 +122,6 @@ func TestUpdateLifecycle(t *testing.T) {
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
got := promhelp.HistogramValue(t, reg, fullMetricName, prometheus.Labels{
"template_name": "test-template",
"organization_name": "test-org",
"transition": "start",
"status": "success",
"is_prebuild": "false",
})
require.Equal(t, uint64(1), got.GetSampleCount())
require.Equal(t, expectedDuration, got.GetSampleSum())
})
// This test jumps from CREATING to READY, skipping STARTED. Both the
@@ -183,21 +147,8 @@ func TestUpdateLifecycle(t *testing.T) {
Valid: true,
},
}).Return(nil)
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), agentCreated.ResourceID).Return(database.GetWorkspaceBuildMetricsByResourceIDRow{
CreatedAt: buildCreatedAt,
Transition: database.WorkspaceTransitionStart,
TemplateName: "test-template",
OrganizationName: "test-org",
IsPrebuild: false,
AllAgentsReady: true,
LastAgentReadyAt: agentReadyAt,
WorstStatus: "success",
}, nil)
publishCalled := false
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentCreated, nil
@@ -205,7 +156,6 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
@@ -218,16 +168,6 @@ func TestUpdateLifecycle(t *testing.T) {
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
require.True(t, publishCalled)
got := promhelp.HistogramValue(t, reg, fullMetricName, prometheus.Labels{
"template_name": "test-template",
"organization_name": "test-org",
"transition": "start",
"status": "success",
"is_prebuild": "false",
})
require.Equal(t, uint64(1), got.GetSampleCount())
require.Equal(t, expectedDuration, got.GetSampleSum())
})
t.Run("NoTimeSpecified", func(t *testing.T) {
@@ -254,19 +194,6 @@ func TestUpdateLifecycle(t *testing.T) {
Valid: true,
},
})
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), agentCreated.ResourceID).Return(database.GetWorkspaceBuildMetricsByResourceIDRow{
CreatedAt: buildCreatedAt,
Transition: database.WorkspaceTransitionStart,
TemplateName: "test-template",
OrganizationName: "test-org",
IsPrebuild: false,
AllAgentsReady: true,
LastAgentReadyAt: agentReadyAt,
WorstStatus: "success",
}, nil)
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
@@ -275,7 +202,6 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: nil,
TimeNowFn: func() time.Time {
return now
@@ -287,16 +213,6 @@ func TestUpdateLifecycle(t *testing.T) {
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
got := promhelp.HistogramValue(t, reg, fullMetricName, prometheus.Labels{
"template_name": "test-template",
"organization_name": "test-org",
"transition": "start",
"status": "success",
"is_prebuild": "false",
})
require.Equal(t, uint64(1), got.GetSampleCount())
require.Equal(t, expectedDuration, got.GetSampleSum())
})
t.Run("AllStates", func(t *testing.T) {
@@ -312,9 +228,6 @@ func TestUpdateLifecycle(t *testing.T) {
dbM := dbmock.NewMockStore(gomock.NewController(t))
var publishCalled int64
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
@@ -322,7 +235,6 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
atomic.AddInt64(&publishCalled, 1)
return nil
@@ -365,20 +277,6 @@ func TestUpdateLifecycle(t *testing.T) {
ReadyAt: expectedReadyAt,
}).Times(1).Return(nil)
// The first ready state triggers the build duration metric query.
if state == agentproto.Lifecycle_READY || state == agentproto.Lifecycle_START_TIMEOUT || state == agentproto.Lifecycle_START_ERROR {
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), agent.ResourceID).Return(database.GetWorkspaceBuildMetricsByResourceIDRow{
CreatedAt: someTime,
Transition: database.WorkspaceTransitionStart,
TemplateName: "test-template",
OrganizationName: "test-org",
IsPrebuild: false,
AllAgentsReady: true,
LastAgentReadyAt: stateNow,
WorstStatus: "success",
}, nil).MaxTimes(1)
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
@@ -424,164 +322,6 @@ func TestUpdateLifecycle(t *testing.T) {
require.Nil(t, resp)
require.False(t, publishCalled)
})
// Test that metric is NOT emitted when not all agents are ready (multi-agent case).
t.Run("MetricNotEmittedWhenNotAllAgentsReady", func(t *testing.T) {
t.Parallel()
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), gomock.Any()).Return(nil)
// Return AllAgentsReady = false to simulate multi-agent case where not all are ready.
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), agentStarting.ResourceID).Return(database.GetWorkspaceBuildMetricsByResourceIDRow{
CreatedAt: someTime,
Transition: database.WorkspaceTransitionStart,
TemplateName: "test-template",
OrganizationName: "test-org",
IsPrebuild: false,
AllAgentsReady: false, // Not all agents ready yet
LastAgentReadyAt: time.Time{}, // No ready time yet
WorstStatus: "success",
}, nil)
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentStarting, nil
},
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: nil,
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
require.Nil(t, promhelp.MetricValue(t, reg, fullMetricName, prometheus.Labels{
"template_name": "test-template",
"organization_name": "test-org",
"transition": "start",
"status": "success",
"is_prebuild": "false",
}), "metric should not be emitted when not all agents are ready")
})
// Test that prebuild label is "true" when owner is prebuild system user.
t.Run("PrebuildLabelTrue", func(t *testing.T) {
t.Parallel()
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), gomock.Any()).Return(nil)
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), agentStarting.ResourceID).Return(database.GetWorkspaceBuildMetricsByResourceIDRow{
CreatedAt: buildCreatedAt,
Transition: database.WorkspaceTransitionStart,
TemplateName: "test-template",
OrganizationName: "test-org",
IsPrebuild: true, // Prebuild workspace
AllAgentsReady: true,
LastAgentReadyAt: agentReadyAt,
WorstStatus: "success",
}, nil)
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentStarting, nil
},
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: nil,
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
got := promhelp.HistogramValue(t, reg, fullMetricName, prometheus.Labels{
"template_name": "test-template",
"organization_name": "test-org",
"transition": "start",
"status": "success",
"is_prebuild": "true",
})
require.Equal(t, uint64(1), got.GetSampleCount())
require.Equal(t, expectedDuration, got.GetSampleSum())
})
// Test worst status is used when one agent has an error.
t.Run("WorstStatusError", func(t *testing.T) {
t.Parallel()
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), gomock.Any()).Return(nil)
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), agentStarting.ResourceID).Return(database.GetWorkspaceBuildMetricsByResourceIDRow{
CreatedAt: buildCreatedAt,
Transition: database.WorkspaceTransitionStart,
TemplateName: "test-template",
OrganizationName: "test-org",
IsPrebuild: false,
AllAgentsReady: true,
LastAgentReadyAt: agentReadyAt,
WorstStatus: "error", // One agent had an error
}, nil)
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentStarting, nil
},
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: nil,
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
got := promhelp.HistogramValue(t, reg, fullMetricName, prometheus.Labels{
"template_name": "test-template",
"organization_name": "test-org",
"transition": "start",
"status": "error",
"is_prebuild": "false",
})
require.Equal(t, uint64(1), got.GetSampleCount())
require.Equal(t, expectedDuration, got.GetSampleSum())
})
}
func TestUpdateStartup(t *testing.T) {
-97
View File
@@ -1,97 +0,0 @@
package agentapi
import (
"context"
"strconv"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"cdr.dev/slog/v3"
)
// BuildDurationMetricName is the short name for the end-to-end
// workspace build duration histogram. The full metric name is
// prefixed with the namespace "coderd_".
const BuildDurationMetricName = "template_workspace_build_duration_seconds"
// LifecycleMetrics contains Prometheus metrics for the lifecycle API.
type LifecycleMetrics struct {
BuildDuration *prometheus.HistogramVec
}
// NewLifecycleMetrics creates and registers all lifecycle-related
// Prometheus metrics.
//
// The build duration histogram tracks the end-to-end duration from
// workspace build creation to agent ready, by template. It is
// recorded by the coderd replica handling the agent's connection
// when the last agent reports ready. In multi-replica deployments,
// each replica only has observations for agents it handles.
//
// The "is_prebuild" label distinguishes prebuild creation (background,
// no user waiting) from user-initiated builds (regular workspace
// creation or prebuild claims).
func NewLifecycleMetrics(reg prometheus.Registerer) *LifecycleMetrics {
m := &LifecycleMetrics{
BuildDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "coderd",
Name: BuildDurationMetricName,
Help: "Duration from workspace build creation to agent ready, by template.",
Buckets: []float64{
1, // 1s
10,
30,
60, // 1min
60 * 5,
60 * 10,
60 * 30, // 30min
60 * 60, // 1hr
},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"template_name", "organization_name", "transition", "status", "is_prebuild"}),
}
reg.MustRegister(m.BuildDuration)
return m
}
// emitBuildDurationMetric records the end-to-end workspace build
// duration from build creation to when all agents are ready.
func (a *LifecycleAPI) emitBuildDurationMetric(ctx context.Context, resourceID uuid.UUID) {
if a.Metrics == nil {
return
}
buildInfo, err := a.Database.GetWorkspaceBuildMetricsByResourceID(ctx, resourceID)
if err != nil {
a.Log.Warn(ctx, "failed to get build info for metrics", slog.Error(err))
return
}
// Wait until all agents have reached a terminal startup state.
if !buildInfo.AllAgentsReady {
return
}
// LastAgentReadyAt is the MAX(ready_at) across all agents. Since
// we only get here when AllAgentsReady is true, this should always
// be valid.
if buildInfo.LastAgentReadyAt.IsZero() {
a.Log.Warn(ctx, "last_agent_ready_at is unexpectedly zero",
slog.F("last_agent_ready_at", buildInfo.LastAgentReadyAt))
return
}
duration := buildInfo.LastAgentReadyAt.Sub(buildInfo.CreatedAt).Seconds()
a.Metrics.BuildDuration.WithLabelValues(
buildInfo.TemplateName,
buildInfo.OrganizationName,
string(buildInfo.Transition),
buildInfo.WorstStatus,
strconv.FormatBool(buildInfo.IsPrebuild),
).Observe(duration)
}
+4 -21
View File
@@ -977,27 +977,10 @@ func (api *API) authAndDoWithTaskAppClient(
ctx := r.Context()
if task.Status != database.TaskStatusActive {
// Return 409 Conflict for valid requests blocked by current state
// (pending/initializing are transitional, paused requires resume).
// Return 400 Bad Request for error/unknown states.
switch task.Status {
case database.TaskStatusPending, database.TaskStatusInitializing:
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
Message: fmt.Sprintf("Task is %s.", task.Status),
Detail: "The task is resuming. Wait for the task to become active before sending messages.",
})
case database.TaskStatusPaused:
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
Message: "Task is paused.",
Detail: "Resume the task to send messages.",
})
default:
// Default handler for error and unknown status.
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task must be active.",
Detail: fmt.Sprintf("Task status is %q, it must be %q to interact with the task.", task.Status, codersdk.TaskStatusActive),
})
}
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task status must be active.",
Detail: fmt.Sprintf("Task status is %q, it must be %q to interact with the task.", task.Status, codersdk.TaskStatusActive),
})
}
if !task.WorkspaceID.Valid {
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
+62 -295
View File
@@ -30,7 +30,6 @@ import (
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
@@ -40,66 +39,6 @@ import (
"github.com/coder/quartz"
)
// createTaskInState is a helper to create a task in the desired state.
// It returns a function that takes context, test, and status, and returns the task ID.
// The caller is responsible for setting up the database, owner, and user.
func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID, userID uuid.UUID) func(context.Context, *testing.T, database.TaskStatus) uuid.UUID {
return func(ctx context.Context, t *testing.T, status database.TaskStatus) uuid.UUID {
ctx = dbauthz.As(ctx, ownerSubject)
builder := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: ownerOrgID,
OwnerID: userID,
}).
WithTask(database.TaskTable{
OrganizationID: ownerOrgID,
OwnerID: userID,
}, nil)
switch status {
case database.TaskStatusPending:
builder = builder.Pending()
case database.TaskStatusInitializing:
builder = builder.Starting()
case database.TaskStatusPaused:
builder = builder.Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
})
case database.TaskStatusError:
// For error state, create a completed build then manipulate app health.
default:
require.Fail(t, "unsupported task status in test helper", "status: %s", status)
}
resp := builder.Do()
taskID := resp.Task.ID
// Post-process by manipulating agent and app state.
if status == database.TaskStatusError {
// First, set agent to ready state so agent_status returns 'active'.
// This ensures the cascade reaches app_status.
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: resp.Agents[0].ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
require.NoError(t, err)
// Then set workspace app health to unhealthy to trigger error state.
apps, err := db.GetWorkspaceAppsByAgentID(ctx, resp.Agents[0].ID)
require.NoError(t, err)
require.Len(t, apps, 1, "expected exactly one app for task")
err = db.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{
ID: apps[0].ID,
Health: database.WorkspaceAppHealthUnhealthy,
})
require.NoError(t, err)
}
return taskID
}
}
func TestTasks(t *testing.T) {
t.Parallel()
@@ -459,144 +398,6 @@ func TestTasks(t *testing.T) {
require.NoError(t, err, "should be possible to delete a task with no workspace")
})
t.Run("SnapshotCleanupOnDeletion", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
template := createAITemplate(t, client, user)
ctx := testutil.Context(t, testutil.WaitLong)
userObj, err := client.User(ctx, user.UserID.String())
require.NoError(t, err)
userSubject := coderdtest.AuthzUserSubject(userObj)
task, err := client.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "delete me with snapshot",
})
require.NoError(t, err)
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
// Create a snapshot for the task.
snapshotJSON := `{"format":"agentapi","data":{"messages":[{"role":"user","content":"test"}]}}`
err = db.UpsertTaskSnapshot(dbauthz.As(ctx, userSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
LogSnapshot: json.RawMessage(snapshotJSON),
LogSnapshotCreatedAt: dbtime.Now(),
})
require.NoError(t, err)
// Verify snapshot exists.
_, err = db.GetTaskSnapshot(dbauthz.As(ctx, userSubject), task.ID)
require.NoError(t, err)
// Delete the task.
err = client.DeleteTask(ctx, "me", task.ID)
require.NoError(t, err, "delete task request should be accepted")
// Verify snapshot no longer exists.
_, err = db.GetTaskSnapshot(dbauthz.As(ctx, userSubject), task.ID)
require.ErrorIs(t, err, sql.ErrNoRows, "snapshot should be deleted with task")
})
t.Run("DeletionWithoutSnapshot", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
template := createAITemplate(t, client, user)
ctx := testutil.Context(t, testutil.WaitLong)
userObj, err := client.User(ctx, user.UserID.String())
require.NoError(t, err)
userSubject := coderdtest.AuthzUserSubject(userObj)
task, err := client.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "delete me without snapshot",
})
require.NoError(t, err)
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
// Verify no snapshot exists.
_, err = db.GetTaskSnapshot(dbauthz.As(ctx, userSubject), task.ID)
require.ErrorIs(t, err, sql.ErrNoRows, "snapshot should not exist initially")
// Delete the task (should succeed even without snapshot).
err = client.DeleteTask(ctx, "me", task.ID)
require.NoError(t, err, "delete task should succeed even without snapshot")
})
t.Run("PreservesOtherTaskSnapshots", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
template := createAITemplate(t, client, user)
ctx := testutil.Context(t, testutil.WaitLong)
userObj, err := client.User(ctx, user.UserID.String())
require.NoError(t, err)
userSubject := coderdtest.AuthzUserSubject(userObj)
// Create task A.
taskA, err := client.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "task A",
})
require.NoError(t, err)
wsA, err := client.Workspace(ctx, taskA.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wsA.LatestBuild.ID)
// Create task B.
taskB, err := client.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "task B",
})
require.NoError(t, err)
wsB, err := client.Workspace(ctx, taskB.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wsB.LatestBuild.ID)
// Create snapshots for both tasks.
snapshotJSONA := `{"format":"agentapi","data":{"messages":[{"role":"user","content":"task A"}]}}`
err = db.UpsertTaskSnapshot(dbauthz.As(ctx, userSubject), database.UpsertTaskSnapshotParams{
TaskID: taskA.ID,
LogSnapshot: json.RawMessage(snapshotJSONA),
LogSnapshotCreatedAt: dbtime.Now(),
})
require.NoError(t, err)
snapshotJSONB := `{"format":"agentapi","data":{"messages":[{"role":"user","content":"task B"}]}}`
err = db.UpsertTaskSnapshot(dbauthz.As(ctx, userSubject), database.UpsertTaskSnapshotParams{
TaskID: taskB.ID,
LogSnapshot: json.RawMessage(snapshotJSONB),
LogSnapshotCreatedAt: dbtime.Now(),
})
require.NoError(t, err)
// Delete task A.
err = client.DeleteTask(ctx, "me", taskA.ID)
require.NoError(t, err, "delete task A should succeed")
// Verify task A's snapshot is removed.
_, err = db.GetTaskSnapshot(dbauthz.As(ctx, userSubject), taskA.ID)
require.ErrorIs(t, err, sql.ErrNoRows, "task A snapshot should be deleted")
// Verify task B's snapshot still exists.
_, err = db.GetTaskSnapshot(dbauthz.As(ctx, userSubject), taskB.ID)
require.NoError(t, err, "task B snapshot should still exist")
})
t.Run("DeletingTaskWorkspaceDeletesTask", func(t *testing.T) {
t.Parallel()
@@ -790,94 +591,6 @@ func TestTasks(t *testing.T) {
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
t.Run("SendToNonActiveStates", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitMedium)
ownerUser, err := client.User(ctx, owner.UserID.String())
require.NoError(t, err)
ownerSubject := coderdtest.AuthzUserSubject(ownerUser)
// Create a regular user for task ownership.
_, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
createTask := createTaskInState(db, ownerSubject, owner.OrganizationID, user.ID)
t.Run("Paused", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPaused)
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "paused")
require.Contains(t, sdkErr.Detail, "Resume")
})
t.Run("Initializing", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusInitializing)
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "initializing")
require.Contains(t, sdkErr.Detail, "resuming")
})
t.Run("Pending", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPending)
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "pending")
require.Contains(t, sdkErr.Detail, "resuming")
})
t.Run("Error", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusError)
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "must be active")
})
})
})
t.Run("Logs", func(t *testing.T) {
@@ -1024,7 +737,61 @@ func TestTasks(t *testing.T) {
// Create a regular user to test snapshot access.
client, user := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)
createTask := createTaskInState(db, ownerSubject, owner.OrganizationID, user.ID)
// Helper to create a task in the desired state.
createTaskInState := func(ctx context.Context, t *testing.T, status database.TaskStatus) uuid.UUID {
ctx = dbauthz.As(ctx, ownerSubject)
builder := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: user.ID,
}).
WithTask(database.TaskTable{
OrganizationID: owner.OrganizationID,
OwnerID: user.ID,
}, nil)
switch status {
case database.TaskStatusPending:
builder = builder.Pending()
case database.TaskStatusInitializing:
builder = builder.Starting()
case database.TaskStatusPaused:
builder = builder.Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
})
case database.TaskStatusError:
// For error state, create a completed build then manipulate app health.
default:
require.Fail(t, "unsupported task status in test helper", "status: %s", status)
}
resp := builder.Do()
taskID := resp.Task.ID
// Post-process by manipulating agent and app state.
if status == database.TaskStatusError {
// First, set agent to ready state so agent_status returns 'active'.
// This ensures the cascade reaches app_status.
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: resp.Agents[0].ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
require.NoError(t, err)
// Then set workspace app health to unhealthy to trigger error state.
apps, err := db.GetWorkspaceAppsByAgentID(ctx, resp.Agents[0].ID)
require.NoError(t, err)
require.Len(t, apps, 1, "expected exactly one app for task")
err = db.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{
ID: apps[0].ID,
Health: database.WorkspaceAppHealthUnhealthy,
})
require.NoError(t, err)
}
return taskID
}
// Prepare snapshot data used across tests.
snapshotMessages := []agentapisdk.Message{
@@ -1086,7 +853,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPending)
taskID := createTaskInState(ctx, t, database.TaskStatusPending)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: taskID,
@@ -1104,7 +871,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusInitializing)
taskID := createTaskInState(ctx, t, database.TaskStatusInitializing)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: taskID,
@@ -1122,7 +889,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPaused)
taskID := createTaskInState(ctx, t, database.TaskStatusPaused)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: taskID,
@@ -1140,7 +907,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPending)
taskID := createTaskInState(ctx, t, database.TaskStatusPending)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err)
@@ -1154,7 +921,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPending)
taskID := createTaskInState(ctx, t, database.TaskStatusPending)
invalidEnvelope := coderd.TaskLogSnapshotEnvelope{
Format: "unknown-format",
@@ -1183,7 +950,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPending)
taskID := createTaskInState(ctx, t, database.TaskStatusPending)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: taskID,
@@ -1204,7 +971,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusError)
taskID := createTaskInState(ctx, t, database.TaskStatusError)
_, err := client.TaskLogs(ctx, "me", taskID)
require.Error(t, err)
+10 -931
View File
File diff suppressed because it is too large Load Diff
+10 -909
View File
File diff suppressed because it is too large Load Diff
+9 -20
View File
@@ -95,26 +95,15 @@ func (t *Tracker) FlushToDB(ctx context.Context, db database.Store, replicaID uu
t.mu.Unlock()
//nolint:gocritic // This is the actual package doing boundary usage tracking.
authCtx := dbauthz.AsBoundaryUsageTracker(ctx)
err := db.InTx(func(tx database.Store) error {
// The advisory lock ensures a clean period cutover by preventing
// this upsert from racing with the aggregate+delete in
// GetAndResetBoundaryUsageSummary. Without it, upserted data
// could be lost or miscounted across periods.
if err := tx.AcquireLock(authCtx, database.LockIDBoundaryUsageStats); err != nil {
return err
}
_, err := tx.UpsertBoundaryUsageStats(authCtx, database.UpsertBoundaryUsageStatsParams{
ReplicaID: replicaID,
UniqueWorkspacesCount: workspaceCount, // cumulative, for UPDATE
UniqueUsersCount: userCount, // cumulative, for UPDATE
UniqueWorkspacesDelta: workspaceDelta, // delta, for INSERT
UniqueUsersDelta: userDelta, // delta, for INSERT
AllowedRequests: allowed,
DeniedRequests: denied,
})
return err
}, nil)
_, err := db.UpsertBoundaryUsageStats(dbauthz.AsBoundaryUsageTracker(ctx), database.UpsertBoundaryUsageStatsParams{
ReplicaID: replicaID,
UniqueWorkspacesCount: workspaceCount, // cumulative, for UPDATE
UniqueUsersCount: userCount, // cumulative, for UPDATE
UniqueWorkspacesDelta: workspaceDelta, // delta, for INSERT
UniqueUsersDelta: userDelta, // delta, for INSERT
AllowedRequests: allowed,
DeniedRequests: denied,
})
// Always reset cumulative counts to prevent unbounded memory growth (e.g.
// if the DB is unreachable). Copy delta maps to preserve any Track() calls
+87 -42
View File
@@ -45,7 +45,7 @@ func TestTracker_Track_Single(t *testing.T) {
// Verify the data was written correctly.
boundaryCtx := dbauthz.AsBoundaryUsageTracker(ctx)
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(1), summary.UniqueWorkspaces)
require.Equal(t, int64(1), summary.UniqueUsers)
@@ -73,7 +73,7 @@ func TestTracker_Track_DuplicateWorkspaceUser(t *testing.T) {
require.NoError(t, err)
boundaryCtx := dbauthz.AsBoundaryUsageTracker(ctx)
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(1), summary.UniqueWorkspaces, "should be 1 unique workspace")
require.Equal(t, int64(1), summary.UniqueUsers, "should be 1 unique user")
@@ -102,7 +102,7 @@ func TestTracker_Track_MultipleWorkspacesUsers(t *testing.T) {
require.NoError(t, err)
boundaryCtx := dbauthz.AsBoundaryUsageTracker(ctx)
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(3), summary.UniqueWorkspaces)
require.Equal(t, int64(2), summary.UniqueUsers)
@@ -140,7 +140,7 @@ func TestTracker_Track_Concurrent(t *testing.T) {
require.NoError(t, err)
boundaryCtx := dbauthz.AsBoundaryUsageTracker(ctx)
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(numGoroutines), summary.UniqueWorkspaces)
require.Equal(t, int64(numGoroutines), summary.UniqueUsers)
@@ -175,7 +175,7 @@ func TestTracker_FlushToDB_Accumulates(t *testing.T) {
require.NoError(t, err)
boundaryCtx := dbauthz.AsBoundaryUsageTracker(ctx)
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(1), summary.UniqueWorkspaces)
require.Equal(t, int64(1), summary.UniqueUsers)
@@ -202,7 +202,7 @@ func TestTracker_FlushToDB_NewPeriod(t *testing.T) {
require.NoError(t, err)
// Simulate telemetry reset (new period).
_, err = db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
err = db.ResetBoundaryUsageStats(boundaryCtx)
require.NoError(t, err)
// Track new data.
@@ -215,7 +215,7 @@ func TestTracker_FlushToDB_NewPeriod(t *testing.T) {
require.NoError(t, err)
// The summary should only contain the new data after reset.
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(1), summary.UniqueWorkspaces, "should only count new workspace")
require.Equal(t, int64(1), summary.UniqueUsers, "should only count new user")
@@ -237,7 +237,7 @@ func TestTracker_FlushToDB_NoActivity(t *testing.T) {
// Verify nothing was written to DB.
boundaryCtx := dbauthz.AsBoundaryUsageTracker(ctx)
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(0), summary.UniqueWorkspaces)
require.Equal(t, int64(0), summary.AllowedRequests)
@@ -265,7 +265,7 @@ func TestUpsertBoundaryUsageStats_Insert(t *testing.T) {
require.True(t, newPeriod, "should return true for insert")
// Verify INSERT used the delta values, not cumulative.
summary, err := db.GetAndResetBoundaryUsageSummary(ctx, 60000)
summary, err := db.GetBoundaryUsageSummary(ctx, 60000)
require.NoError(t, err)
require.Equal(t, int64(5), summary.UniqueWorkspaces)
require.Equal(t, int64(3), summary.UniqueUsers)
@@ -301,7 +301,7 @@ func TestUpsertBoundaryUsageStats_Update(t *testing.T) {
require.False(t, newPeriod, "should return false for update")
// Verify UPDATE used cumulative values.
summary, err := db.GetAndResetBoundaryUsageSummary(ctx, 60000)
summary, err := db.GetBoundaryUsageSummary(ctx, 60000)
require.NoError(t, err)
require.Equal(t, int64(8), summary.UniqueWorkspaces)
require.Equal(t, int64(5), summary.UniqueUsers)
@@ -309,7 +309,7 @@ func TestUpsertBoundaryUsageStats_Update(t *testing.T) {
require.Equal(t, int64(10+20), summary.DeniedRequests)
}
func TestGetAndResetBoundaryUsageSummary_MultipleReplicas(t *testing.T) {
func TestGetBoundaryUsageSummary_MultipleReplicas(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
@@ -347,7 +347,7 @@ func TestGetAndResetBoundaryUsageSummary_MultipleReplicas(t *testing.T) {
})
require.NoError(t, err)
summary, err := db.GetAndResetBoundaryUsageSummary(ctx, 60000)
summary, err := db.GetBoundaryUsageSummary(ctx, 60000)
require.NoError(t, err)
// Verify aggregation (SUM of all replicas).
@@ -357,13 +357,13 @@ func TestGetAndResetBoundaryUsageSummary_MultipleReplicas(t *testing.T) {
require.Equal(t, int64(45), summary.DeniedRequests) // 10 + 15 + 20
}
func TestGetAndResetBoundaryUsageSummary_Empty(t *testing.T) {
func TestGetBoundaryUsageSummary_Empty(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := dbauthz.AsBoundaryUsageTracker(context.Background())
summary, err := db.GetAndResetBoundaryUsageSummary(ctx, 60000)
summary, err := db.GetBoundaryUsageSummary(ctx, 60000)
require.NoError(t, err)
// COALESCE should return 0 for all columns.
@@ -373,7 +373,7 @@ func TestGetAndResetBoundaryUsageSummary_Empty(t *testing.T) {
require.Equal(t, int64(0), summary.DeniedRequests)
}
func TestGetAndResetBoundaryUsageSummary_DeletesData(t *testing.T) {
func TestResetBoundaryUsageStats(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
@@ -391,19 +391,61 @@ func TestGetAndResetBoundaryUsageSummary_DeletesData(t *testing.T) {
require.NoError(t, err)
}
// Should return the summary AND delete all data.
summary, err := db.GetAndResetBoundaryUsageSummary(ctx, 60000)
// Verify data exists.
summary, err := db.GetBoundaryUsageSummary(ctx, 60000)
require.NoError(t, err)
require.Greater(t, summary.AllowedRequests, int64(0))
// Reset.
err = db.ResetBoundaryUsageStats(ctx)
require.NoError(t, err)
require.Equal(t, int64(1+2+3+4+5), summary.UniqueWorkspaces)
require.Equal(t, int64(10+20+30+40+50), summary.AllowedRequests)
// Verify all data is gone.
summary, err = db.GetAndResetBoundaryUsageSummary(ctx, 60000)
summary, err = db.GetBoundaryUsageSummary(ctx, 60000)
require.NoError(t, err)
require.Equal(t, int64(0), summary.UniqueWorkspaces)
require.Equal(t, int64(0), summary.AllowedRequests)
}
func TestDeleteBoundaryUsageStatsByReplicaID(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := dbauthz.AsBoundaryUsageTracker(context.Background())
replica1 := uuid.New()
replica2 := uuid.New()
// Insert stats for 2 replicas. Delta fields are used for INSERT.
_, err := db.UpsertBoundaryUsageStats(ctx, database.UpsertBoundaryUsageStatsParams{
ReplicaID: replica1,
UniqueWorkspacesDelta: 10,
UniqueUsersDelta: 5,
AllowedRequests: 100,
DeniedRequests: 10,
})
require.NoError(t, err)
_, err = db.UpsertBoundaryUsageStats(ctx, database.UpsertBoundaryUsageStatsParams{
ReplicaID: replica2,
UniqueWorkspacesDelta: 20,
UniqueUsersDelta: 10,
AllowedRequests: 200,
DeniedRequests: 20,
})
require.NoError(t, err)
// Delete replica1's stats.
err = db.DeleteBoundaryUsageStatsByReplicaID(ctx, replica1)
require.NoError(t, err)
// Verify only replica2's stats remain.
summary, err := db.GetBoundaryUsageSummary(ctx, 60000)
require.NoError(t, err)
require.Equal(t, int64(20), summary.UniqueWorkspaces)
require.Equal(t, int64(200), summary.AllowedRequests)
}
func TestTracker_TelemetryCycle(t *testing.T) {
t.Parallel()
@@ -435,8 +477,8 @@ func TestTracker_TelemetryCycle(t *testing.T) {
require.NoError(t, tracker2.FlushToDB(ctx, db, replica2))
require.NoError(t, tracker3.FlushToDB(ctx, db, replica3))
// Telemetry aggregates and resets (simulating telemetry report sent).
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
// Telemetry aggregates.
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
// Verify aggregation.
@@ -445,12 +487,15 @@ func TestTracker_TelemetryCycle(t *testing.T) {
require.Equal(t, int64(105), summary.AllowedRequests) // 25 + 75 + 5
require.Equal(t, int64(15), summary.DeniedRequests) // 3 + 12 + 0
// Telemetry resets stats (simulating telemetry report sent).
require.NoError(t, db.ResetBoundaryUsageStats(boundaryCtx))
// Next flush from trackers should detect new period.
tracker1.Track(uuid.New(), uuid.New(), 1, 0)
require.NoError(t, tracker1.FlushToDB(ctx, db, replica1))
// Verify trackers reset their in-memory state.
summary, err = db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err = db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(1), summary.UniqueWorkspaces)
require.Equal(t, int64(1), summary.AllowedRequests)
@@ -468,24 +513,30 @@ func TestTracker_FlushToDB_NoStaleDataAfterReset(t *testing.T) {
workspaceID := uuid.New()
ownerID := uuid.New()
// Track some data and flush.
// Track some data, flush, and verify.
tracker.Track(workspaceID, ownerID, 10, 5)
err := tracker.FlushToDB(ctx, db, replicaID)
require.NoError(t, err)
// Simulate telemetry reset (new period) - this also verifies the data.
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(1), summary.UniqueWorkspaces)
require.Equal(t, int64(10), summary.AllowedRequests)
// Simulate telemetry reset (new period).
err = db.ResetBoundaryUsageStats(boundaryCtx)
require.NoError(t, err)
summary, err = db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(0), summary.AllowedRequests)
// Flush again without any new Track() calls. This should not write stale
// data back to the DB.
err = tracker.FlushToDB(ctx, db, replicaID)
require.NoError(t, err)
// Summary should be empty (no stale data written).
summary, err = db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err = db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(0), summary.UniqueWorkspaces)
require.Equal(t, int64(0), summary.UniqueUsers)
@@ -531,7 +582,7 @@ func TestTracker_ConcurrentFlushAndTrack(t *testing.T) {
// Verify stats are non-negative.
boundaryCtx := dbauthz.AsBoundaryUsageTracker(ctx)
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.GreaterOrEqual(t, summary.AllowedRequests, int64(0))
require.GreaterOrEqual(t, summary.DeniedRequests, int64(0))
@@ -546,17 +597,6 @@ type trackDuringUpsertDB struct {
userID uuid.UUID
}
func (s *trackDuringUpsertDB) InTx(fn func(database.Store) error, opts *database.TxOptions) error {
return s.Store.InTx(func(tx database.Store) error {
return fn(&trackDuringUpsertDB{
Store: tx,
tracker: s.tracker,
workspaceID: s.workspaceID,
userID: s.userID,
})
}, opts)
}
func (s *trackDuringUpsertDB) UpsertBoundaryUsageStats(ctx context.Context, arg database.UpsertBoundaryUsageStatsParams) (bool, error) {
s.tracker.Track(s.workspaceID, s.userID, 20, 10)
return s.Store.UpsertBoundaryUsageStats(ctx, arg)
@@ -586,12 +626,17 @@ func TestTracker_TrackDuringFlush(t *testing.T) {
err := tracker.FlushToDB(ctx, trackingDB, replicaID)
require.NoError(t, err)
// Second flush captures the Track() that happened during the first flush.
// Verify first flush only wrote the initial data.
summary, err := db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(10), summary.AllowedRequests)
// The second flush should include the Track() call that happened during the
// first flush's DB operation.
err = tracker.FlushToDB(ctx, db, replicaID)
require.NoError(t, err)
// Verify both flushes are in the summary.
summary, err := db.GetAndResetBoundaryUsageSummary(boundaryCtx, 60000)
summary, err = db.GetBoundaryUsageSummary(boundaryCtx, 60000)
require.NoError(t, err)
require.Equal(t, int64(10+20), summary.AllowedRequests)
require.Equal(t, int64(5+10), summary.DeniedRequests)
-20
View File
@@ -1,20 +0,0 @@
Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc.
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-440
View File
@@ -1,440 +0,0 @@
// Package cachecompress creates a compressed cache of static files based on an http.FS. It is modified from
// https://github.com/go-chi/chi Compressor middleware. See the LICENSE file in this directory for copyright
// information.
package cachecompress
import (
"compress/flate"
"compress/gzip"
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
type cacheKey struct {
encoding string
urlPath string
}
func (c cacheKey) filePath(cacheDir string) string {
// URLs can have slashes or other characters we don't want the file system interpreting. So we just encode the path
// to a flat base64 filename.
filename := base64.URLEncoding.EncodeToString([]byte(c.urlPath))
return filepath.Join(cacheDir, c.encoding, filename)
}
func getCacheKey(encoding string, r *http.Request) cacheKey {
return cacheKey{
encoding: encoding,
urlPath: r.URL.Path,
}
}
type ref struct {
key cacheKey
done chan struct{}
err chan error
}
// Compressor represents a set of encoding configurations.
type Compressor struct {
logger slog.Logger
// The mapping of encoder names to encoder functions.
encoders map[string]EncoderFunc
// The mapping of pooled encoders to pools.
pooledEncoders map[string]*sync.Pool
// The list of encoders in order of decreasing precedence.
encodingPrecedence []string
level int // The compression level.
cacheDir string
orig http.FileSystem
mu sync.Mutex
cache map[cacheKey]ref
}
// NewCompressor creates a new Compressor that will handle encoding responses.
//
// The level should be one of the ones defined in the flate package.
// The types are the content types that are allowed to be compressed.
func NewCompressor(logger slog.Logger, level int, cacheDir string, orig http.FileSystem) *Compressor {
c := &Compressor{
logger: logger.Named("cachecompress"),
level: level,
encoders: make(map[string]EncoderFunc),
pooledEncoders: make(map[string]*sync.Pool),
cacheDir: cacheDir,
orig: orig,
cache: make(map[cacheKey]ref),
}
// Set the default encoders. The precedence order uses the reverse
// ordering that the encoders were added. This means adding new encoders
// will move them to the front of the order.
//
// TODO:
// lzma: Opera.
// sdch: Chrome, Android. Gzip output + dictionary header.
// br: Brotli, see https://github.com/go-chi/chi/pull/326
// HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951)
// wrapped with zlib (RFC 1950). The zlib wrapper uses Adler-32
// checksum compared to CRC-32 used in "gzip" and thus is faster.
//
// But.. some old browsers (MSIE, Safari 5.1) incorrectly expect
// raw DEFLATE data only, without the mentioned zlib wrapper.
// Because of this major confusion, most modern browsers try it
// both ways, first looking for zlib headers.
// Quote by Mark Adler: http://stackoverflow.com/a/9186091/385548
//
// The list of browsers having problems is quite big, see:
// http://zoompf.com/blog/2012/02/lose-the-wait-http-compression
// https://web.archive.org/web/20120321182910/http://www.vervestudios.co/projects/compression-tests/results
//
// That's why we prefer gzip over deflate. It's just more reliable
// and not significantly slower than deflate.
c.SetEncoder("deflate", encoderDeflate)
// TODO: Exception for old MSIE browsers that can't handle non-HTML?
// https://zoompf.com/blog/2012/02/lose-the-wait-http-compression
c.SetEncoder("gzip", encoderGzip)
// NOTE: Not implemented, intentionally:
// case "compress": // LZW. Deprecated.
// case "bzip2": // Too slow on-the-fly.
// case "zopfli": // Too slow on-the-fly.
// case "xz": // Too slow on-the-fly.
return c
}
// SetEncoder can be used to set the implementation of a compression algorithm.
//
// The encoding should be a standardized identifier. See:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
//
// For example, add the Brotli algorithm:
//
// import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc"
//
// compressor := middleware.NewCompressor(5, "text/html")
// compressor.SetEncoder("br", func(w io.Writer, level int) io.Writer {
// params := brotli_enc.NewBrotliParams()
// params.SetQuality(level)
// return brotli_enc.NewBrotliWriter(params, w)
// })
func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
encoding = strings.ToLower(encoding)
if encoding == "" {
panic("the encoding can not be empty")
}
if fn == nil {
panic("attempted to set a nil encoder function")
}
// If we are adding a new encoder that is already registered, we have to
// clear that one out first.
delete(c.pooledEncoders, encoding)
delete(c.encoders, encoding)
// If the encoder supports Resetting (IoReseterWriter), then it can be pooled.
encoder := fn(io.Discard, c.level)
if _, ok := encoder.(ioResetterWriter); ok {
pool := &sync.Pool{
New: func() interface{} {
return fn(io.Discard, c.level)
},
}
c.pooledEncoders[encoding] = pool
}
// If the encoder is not in the pooledEncoders, add it to the normal encoders.
if _, ok := c.pooledEncoders[encoding]; !ok {
c.encoders[encoding] = fn
}
for i, v := range c.encodingPrecedence {
if v == encoding {
c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...)
}
}
c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
}
// ServeHTTP returns the response from the orig file system, compressed if possible.
func (c *Compressor) ServeHTTP(w http.ResponseWriter, r *http.Request) {
encoding := c.selectEncoder(r.Header)
// we can only serve a cached response if all the following:
// 1. they requested an encoding we support
// 2. they are requesting the whole file, not a range
// 3. the method is GET
if encoding == "" || r.Header.Get("Range") != "" || r.Method != "GET" {
http.FileServer(c.orig).ServeHTTP(w, r)
return
}
// Whether we should serve a cached response also depends in a fairly complex way on the path and request
// headers. In particular, we don't need a cached response for non-existing files/directories, and should not serve
// a cached response if the correct Etag for the file is provided. This logic is all handled by the http.FileServer,
// and we don't want to reimplement it here. So, what we'll do is send a HEAD request to the http.FileServer to see
// what it would do.
headReq := r.Clone(r.Context())
headReq.Method = http.MethodHead
headRW := &compressResponseWriter{
w: io.Discard,
headers: make(http.Header),
}
// deep-copy the headers already set on the response. This includes things like ETags.
for key, values := range w.Header() {
for _, value := range values {
headRW.headers.Add(key, value)
}
}
http.FileServer(c.orig).ServeHTTP(headRW, headReq)
if headRW.code != http.StatusOK {
// again, fall back to the file server. This is often a 404 Not Found, or a 304 Not Modified if they provided
// the correct ETag.
http.FileServer(c.orig).ServeHTTP(w, r)
return
}
cref := c.getRef(encoding, r)
c.serveRef(w, r, headRW.headers, cref)
}
func (c *Compressor) serveRef(w http.ResponseWriter, r *http.Request, headers http.Header, cref ref) {
select {
case <-r.Context().Done():
w.WriteHeader(http.StatusServiceUnavailable)
return
case <-cref.done:
cachePath := cref.key.filePath(c.cacheDir)
cacheFile, err := os.Open(cachePath)
if err != nil {
c.logger.Error(context.Background(), "failed to open compressed cache file",
slog.F("cache_path", cachePath), slog.F("url_path", cref.key.urlPath), slog.Error(err))
// fall back to uncompressed
http.FileServer(c.orig).ServeHTTP(w, r)
}
defer cacheFile.Close()
// we need to remove or modify the Content-Length, if any, set by the FileServer because it will be for
// uncompressed data and wrong.
info, err := cacheFile.Stat()
if err != nil {
c.logger.Error(context.Background(), "failed to stat compressed cache file",
slog.F("cache_path", cachePath), slog.F("url_path", cref.key.urlPath), slog.Error(err))
headers.Del("Content-Length")
} else {
headers.Set("Content-Length", fmt.Sprintf("%d", info.Size()))
}
for key, values := range headers {
for _, value := range values {
w.Header().Add(key, value)
}
}
w.Header().Set("Content-Encoding", cref.key.encoding)
w.Header().Add("Vary", "Accept-Encoding")
w.WriteHeader(http.StatusOK)
_, err = io.Copy(w, cacheFile)
if err != nil {
// most commonly, the writer will hang up before we are done.
c.logger.Debug(context.Background(), "failed to write compressed cache file", slog.Error(err))
}
return
case <-cref.err:
// fall back to uncompressed
http.FileServer(c.orig).ServeHTTP(w, r)
return
}
}
func (c *Compressor) getRef(encoding string, r *http.Request) ref {
ck := getCacheKey(encoding, r)
c.mu.Lock()
defer c.mu.Unlock()
cref, ok := c.cache[ck]
if ok {
return cref
}
// we are the first to encode
cref = ref{
key: ck,
done: make(chan struct{}),
err: make(chan error),
}
c.cache[ck] = cref
go c.compress(context.Background(), encoding, cref, r)
return cref
}
func (c *Compressor) compress(ctx context.Context, encoding string, cref ref, r *http.Request) {
cachePath := cref.key.filePath(c.cacheDir)
var err error
// we want to handle closing either cref.done or cref.err in a defer at the bottom of the stack so that the encoder
// and cache file are both closed first (higher in the defer stack). This prevents data races where waiting HTTP
// handlers start reading the file before all the data has been flushed.
defer func() {
if err != nil {
if rErr := os.Remove(cachePath); rErr != nil {
// nolint: gocritic // best effort, just debug log any errors
c.logger.Debug(ctx, "failed to remove cache file",
slog.F("main_err", err), slog.F("remove_err", rErr), slog.F("cache_path", cachePath))
}
c.mu.Lock()
delete(c.cache, cref.key)
c.mu.Unlock()
close(cref.err)
return
}
close(cref.done)
}()
cacheDir := filepath.Dir(cachePath)
err = os.MkdirAll(cacheDir, 0o700)
if err != nil {
c.logger.Error(ctx, "failed to create cache directory", slog.F("cache_dir", cacheDir))
return
}
// We will truncate and overwrite any existing files. This is important in the case that we get restarted
// with the same cache dir, possibly with different source files.
cacheFile, err := os.OpenFile(cachePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
c.logger.Error(ctx, "failed to open compression cache file",
slog.F("path", cachePath), slog.Error(err))
return
}
defer cacheFile.Close()
encoder, cleanup := c.getEncoder(encoding, cacheFile)
if encoder == nil {
// can only hit this if there is a programming error
c.logger.Critical(ctx, "got nil encoder", slog.F("encoding", encoding))
err = xerrors.New("nil encoder")
return
}
defer cleanup()
defer encoder.Close() // ensures we flush, needs to be called before cleanup(), so we defer after it.
cw := &compressResponseWriter{
w: encoder,
headers: make(http.Header), // ignored
}
http.FileServer(c.orig).ServeHTTP(cw, r)
if cw.code != http.StatusOK {
// log at debug because this is likely just a 404
c.logger.Debug(ctx, "file server failed to serve",
slog.F("encoding", encoding), slog.F("url_path", cref.key.urlPath), slog.F("http_code", cw.code))
// mark the error so that we clean up correctly
err = xerrors.New("file server failed to serve")
return
}
// success!
}
// selectEncoder returns the name of the encoder
func (c *Compressor) selectEncoder(h http.Header) string {
header := h.Get("Accept-Encoding")
// Parse the names of all accepted algorithms from the header.
accepted := strings.Split(strings.ToLower(header), ",")
// Find supported encoder by accepted list by precedence
for _, name := range c.encodingPrecedence {
if matchAcceptEncoding(accepted, name) {
return name
}
}
// No encoder found to match the accepted encoding
return ""
}
// getEncoder returns a writer that encodes and writes to the provided writer, and a cleanup func.
func (c *Compressor) getEncoder(name string, w io.Writer) (io.WriteCloser, func()) {
if pool, ok := c.pooledEncoders[name]; ok {
encoder, typeOK := pool.Get().(ioResetterWriter)
if !typeOK {
return nil, nil
}
cleanup := func() {
pool.Put(encoder)
}
encoder.Reset(w)
return encoder, cleanup
}
if fn, ok := c.encoders[name]; ok {
return fn(w, c.level), func() {}
}
return nil, nil
}
func matchAcceptEncoding(accepted []string, encoding string) bool {
for _, v := range accepted {
if strings.Contains(v, encoding) {
return true
}
}
return false
}
// An EncoderFunc is a function that wraps the provided io.Writer with a
// streaming compression algorithm and returns it.
//
// In case of failure, the function should return nil.
type EncoderFunc func(w io.Writer, level int) io.WriteCloser
// Interface for types that allow resetting io.Writers.
type ioResetterWriter interface {
io.WriteCloser
Reset(w io.Writer)
}
func encoderGzip(w io.Writer, level int) io.WriteCloser {
gw, err := gzip.NewWriterLevel(w, level)
if err != nil {
return nil
}
return gw
}
func encoderDeflate(w io.Writer, level int) io.WriteCloser {
dw, err := flate.NewWriter(w, level)
if err != nil {
return nil
}
return dw
}
type compressResponseWriter struct {
w io.Writer
headers http.Header
code int
}
func (cw *compressResponseWriter) Header() http.Header {
return cw.headers
}
func (cw *compressResponseWriter) WriteHeader(code int) {
cw.code = code
}
func (cw *compressResponseWriter) Write(p []byte) (int, error) {
if cw.code == 0 {
cw.code = http.StatusOK
}
return cw.w.Write(p)
}
@@ -1,227 +0,0 @@
package cachecompress
import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/testutil"
)
func TestCompressorEncodings(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
expectedEncoding string
acceptedEncodings []string
}{
{
name: "no expected encodings due to no accepted encodings",
path: "/file.html",
acceptedEncodings: nil,
expectedEncoding: "",
},
{
name: "gzip is only encoding",
path: "/file.html",
acceptedEncodings: []string{"gzip"},
expectedEncoding: "gzip",
},
{
name: "gzip is preferred over deflate",
path: "/file.html",
acceptedEncodings: []string{"gzip", "deflate"},
expectedEncoding: "gzip",
},
{
name: "deflate is used",
path: "/file.html",
acceptedEncodings: []string{"deflate"},
expectedEncoding: "deflate",
},
{
name: "nop is preferred",
path: "/file.html",
acceptedEncodings: []string{"nop, gzip, deflate"},
expectedEncoding: "nop",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
logger := testutil.Logger(t)
tempDir := t.TempDir()
cacheDir := filepath.Join(tempDir, "cache")
err := os.MkdirAll(cacheDir, 0o700)
require.NoError(t, err)
srcDir := filepath.Join(tempDir, "src")
err = os.MkdirAll(srcDir, 0o700)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcDir, "file.html"), []byte("textstring"), 0o600)
require.NoError(t, err)
compressor := NewCompressor(logger, 5, cacheDir, http.FS(os.DirFS(srcDir)))
if len(compressor.encoders) != 0 || len(compressor.pooledEncoders) != 2 {
t.Errorf("gzip and deflate should be pooled")
}
logger.Debug(context.Background(), "started compressor")
compressor.SetEncoder("nop", func(w io.Writer, _ int) io.WriteCloser {
return nopEncoder{w}
})
if len(compressor.encoders) != 1 {
t.Errorf("nop encoder should be stored in the encoders map")
}
ts := httptest.NewServer(compressor)
defer ts.Close()
// ctx := testutil.Context(t, testutil.WaitShort)
ctx := context.Background()
header, respString := testRequestWithAcceptedEncodings(ctx, t, ts, "GET", tc.path, tc.acceptedEncodings...)
if respString != "textstring" {
t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString)
}
if got := header.Get("Content-Encoding"); got != tc.expectedEncoding {
t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got)
}
})
}
}
func testRequestWithAcceptedEncodings(ctx context.Context, t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (http.Header, string) {
req, err := http.NewRequestWithContext(ctx, method, ts.URL+path, nil)
if err != nil {
t.Fatal(err)
return nil, ""
}
if len(encodings) > 0 {
encodingsString := strings.Join(encodings, ",")
req.Header.Set("Accept-Encoding", encodingsString)
}
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.DisableCompression = true // prevent automatically setting gzip
resp, err := (&http.Client{Transport: transport}).Do(req)
require.NoError(t, err)
respBody := decodeResponseBody(t, resp)
defer resp.Body.Close()
return resp.Header, respBody
}
func decodeResponseBody(t *testing.T, resp *http.Response) string {
var reader io.ReadCloser
t.Logf("encoding: '%s'", resp.Header.Get("Content-Encoding"))
rawBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
t.Logf("raw body: %x", rawBody)
switch resp.Header.Get("Content-Encoding") {
case "gzip":
var err error
reader, err = gzip.NewReader(bytes.NewReader(rawBody))
require.NoError(t, err)
case "deflate":
reader = flate.NewReader(bytes.NewReader(rawBody))
default:
return string(rawBody)
}
respBody, err := io.ReadAll(reader)
require.NoError(t, err, "failed to read response body: %T %+v", err, err)
err = reader.Close()
require.NoError(t, err)
return string(respBody)
}
type nopEncoder struct {
io.Writer
}
func (nopEncoder) Close() error { return nil }
// nolint: tparallel // we want to assert the state of the cache, so run synchronously
func TestCompressorHeadings(t *testing.T) {
t.Parallel()
logger := testutil.Logger(t)
tempDir := t.TempDir()
cacheDir := filepath.Join(tempDir, "cache")
err := os.MkdirAll(cacheDir, 0o700)
require.NoError(t, err)
srcDir := filepath.Join(tempDir, "src")
err = os.MkdirAll(srcDir, 0o700)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcDir, "file.html"), []byte("textstring"), 0o600)
require.NoError(t, err)
compressor := NewCompressor(logger, 5, cacheDir, http.FS(os.DirFS(srcDir)))
ts := httptest.NewServer(compressor)
defer ts.Close()
tests := []struct {
name string
path string
}{
{
name: "exists",
path: "/file.html",
},
{
name: "not found",
path: "/missing.html",
},
{
name: "not found directory",
path: "/a_directory/",
},
}
// nolint: paralleltest // we want to assert the state of the cache, so run synchronously
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
req := httptest.NewRequestWithContext(ctx, "GET", tc.path, nil)
// request directly from http.FileServer as our baseline response
respROrig := httptest.NewRecorder()
http.FileServer(http.Dir(srcDir)).ServeHTTP(respROrig, req)
respOrig := respROrig.Result()
req.Header.Add("Accept-Encoding", "gzip")
// serve twice so that we go thru cache hit and cache miss code
for range 2 {
respRec := httptest.NewRecorder()
compressor.ServeHTTP(respRec, req)
respComp := respRec.Result()
require.Equal(t, respOrig.StatusCode, respComp.StatusCode)
for key, values := range respOrig.Header {
if key == "Content-Length" {
continue // we don't get length on compressed responses
}
require.Equal(t, values, respComp.Header[key])
}
}
})
}
// only the cache hit should leave a file around
files, err := os.ReadDir(srcDir)
require.NoError(t, err)
require.Len(t, files, 1)
}
+30 -67
View File
@@ -21,9 +21,11 @@ import (
"sync/atomic"
"time"
"github.com/andybalholm/brotli"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
"github.com/klauspost/compress/zstd"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -42,7 +44,6 @@ import (
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
"github.com/coder/coder/v2/coderd/appearance"
@@ -90,7 +91,6 @@ import (
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
@@ -99,8 +99,6 @@ import (
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/eventsink"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
"github.com/coder/serpent"
)
@@ -417,8 +415,7 @@ func New(options *Options) *API {
options.NetworkTelemetryBatchMaxSize = 1_000
}
if options.TailnetCoordinator == nil {
eventSink := eventsink.NewEventSink(context.Background(), options.Database, options.Logger)
options.TailnetCoordinator = tailnet.NewCoordinator(options.Logger, eventSink)
options.TailnetCoordinator = tailnet.NewCoordinator(options.Logger)
}
if options.Auditor == nil {
options.Auditor = audit.NewNop()
@@ -465,6 +462,10 @@ func New(options *Options) *API {
if siteCacheDir != "" {
siteCacheDir = filepath.Join(siteCacheDir, "site")
}
binFS, binHashes, err := site.ExtractOrReadBinFS(siteCacheDir, site.FS())
if err != nil {
panic(xerrors.Errorf("read site bin failed: %w", err))
}
metricsCache := metricscache.New(
options.Database,
@@ -657,8 +658,9 @@ func New(options *Options) *API {
WebPushPublicKey: api.WebpushDispatcher.PublicKey(),
Telemetry: api.Telemetry.Enabled(),
}
api.SiteHandler, err = site.New(&site.Options{
CacheDir: siteCacheDir,
api.SiteHandler = site.New(&site.Options{
BinFS: binFS,
BinHashes: binHashes,
Database: options.Database,
SiteFS: site.FS(),
OAuth2Configs: oauthConfigs,
@@ -670,9 +672,6 @@ func New(options *Options) *API {
Logger: options.Logger.Named("site"),
HideAITasks: options.DeploymentValues.HideAITasks.Value(),
})
if err != nil {
options.Logger.Fatal(ctx, "failed to initialize site handler", slog.Error(err))
}
api.SiteHandler.Experiments.Store(&experiments)
if options.UpdateCheckOptions != nil {
@@ -738,23 +737,20 @@ func New(options *Options) *API {
api.Auditor.Store(&options.Auditor)
api.ConnectionLogger.Store(&options.ConnectionLogger)
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
serverTailnetID := uuid.New()
dialer := &InmemTailnetDialer{
CoordPtr: &api.TailnetCoordinator,
DERPFn: api.DERPMap,
Logger: options.Logger,
ClientID: serverTailnetID,
ClientID: uuid.New(),
DatabaseHealthCheck: api.Database,
}
stn, err := NewServerTailnet(api.ctx,
options.Logger,
options.DERPServer,
serverTailnetID,
dialer,
options.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
options.DeploymentValues.DERP.Config.BlockDirect.Value(),
api.TracerProvider,
"Coder Server",
)
if err != nil {
panic("failed to setup server tailnet: " + err.Error())
@@ -762,7 +758,6 @@ func New(options *Options) *API {
api.agentProvider = stn
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(stn)
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
}
api.NetworkTelemetryBatcher = tailnet.NewNetworkTelemetryBatcher(
quartz.NewReal(),
@@ -770,19 +765,17 @@ func New(options *Options) *API {
api.Options.NetworkTelemetryBatchMaxSize,
api.handleNetworkTelemetry,
)
api.PeerNetworkTelemetryStore = NewPeerNetworkTelemetryStore()
if options.CoordinatorResumeTokenProvider == nil {
panic("CoordinatorResumeTokenProvider is nil")
}
api.TailnetClientService, err = tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: api.Logger.Named("tailnetclient"),
CoordPtr: &api.TailnetCoordinator,
DERPMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
DERPMapFn: api.DERPMap,
NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,
IdentifiedTelemetryHandler: api.handleIdentifiedTelemetry,
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
WorkspaceUpdatesProvider: api.UpdatesProvider,
Logger: api.Logger.Named("tailnetclient"),
CoordPtr: &api.TailnetCoordinator,
DERPMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
DERPMapFn: api.DERPMap,
NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
WorkspaceUpdatesProvider: api.UpdatesProvider,
})
if err != nil {
api.Logger.Fatal(context.Background(), "failed to initialize tailnet client service", slog.Error(err))
@@ -1526,7 +1519,6 @@ func New(options *Options) *API {
r.Delete("/", api.deleteWorkspaceAgentPortShare)
})
r.Get("/timings", api.workspaceTimings)
r.Get("/sessions", api.workspaceSessions)
r.Route("/acl", func(r chi.Router) {
r.Use(
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing),
@@ -1838,7 +1830,6 @@ type API struct {
WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool]
TailnetCoordinator atomic.Pointer[tailnet.Coordinator]
NetworkTelemetryBatcher *tailnet.NetworkTelemetryBatcher
PeerNetworkTelemetryStore *PeerNetworkTelemetryStore
TailnetClientService *tailnet.ClientService
// WebpushDispatcher is a way to send notifications to users via Web Push.
WebpushDispatcher webpush.Dispatcher
@@ -1901,9 +1892,8 @@ type API struct {
healthCheckCache atomic.Pointer[healthsdk.HealthcheckReport]
healthCheckProgress healthcheck.Progress
statsReporter *workspacestats.Reporter
metadataBatcher *metadatabatcher.Batcher
lifecycleMetrics *agentapi.LifecycleMetrics
statsReporter *workspacestats.Reporter
metadataBatcher *metadatabatcher.Batcher
Acquirer *provisionerdserver.Acquirer
// dbRolluper rolls up template usage stats from raw agent and app
@@ -1973,36 +1963,6 @@ func (api *API) Close() error {
return nil
}
// handleIdentifiedTelemetry stores peer telemetry events and publishes a
// workspace update so watch subscribers see fresh data.
func (api *API) handleIdentifiedTelemetry(agentID, peerID uuid.UUID, events []*tailnetproto.TelemetryEvent) {
if len(events) == 0 {
return
}
for _, event := range events {
api.PeerNetworkTelemetryStore.Update(agentID, peerID, event)
}
// Telemetry callback runs outside any user request, so we use a system
// context to look up the workspace for the pubsub notification.
ctx := dbauthz.AsSystemRestricted(context.Background()) //nolint:gocritic // Telemetry callback has no user context.
workspace, err := api.Database.GetWorkspaceByAgentID(ctx, agentID)
if err != nil {
api.Logger.Warn(ctx, "failed to resolve workspace for telemetry update",
slog.F("agent_id", agentID),
slog.Error(err),
)
return
}
api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{
Kind: wspubsub.WorkspaceEventKindConnectionLogUpdate,
WorkspaceID: workspace.ID,
AgentID: &agentID,
})
}
func compressHandler(h http.Handler) http.Handler {
level := 5
if flag.Lookup("test.v") != nil {
@@ -2014,13 +1974,16 @@ func compressHandler(h http.Handler) http.Handler {
"application/*",
"image/*",
)
for encoding := range site.StandardEncoders {
writeCloserFn := site.StandardEncoders[encoding]
cmp.SetEncoder(encoding, func(w io.Writer, level int) io.Writer {
writeCloser := writeCloserFn(w, level)
return writeCloser
})
}
cmp.SetEncoder("br", func(w io.Writer, level int) io.Writer {
return brotli.NewWriterLevel(w, level)
})
cmp.SetEncoder("zstd", func(w io.Writer, level int) io.Writer {
zw, err := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(level)))
if err != nil {
panic("invalid zstd compressor: " + err.Error())
}
return zw
})
return cmp.Handler(h)
}
-4
View File
@@ -82,10 +82,6 @@ func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertCo
t.Logf("connection log %d: expected AgentName %s, got %s", idx+1, expected.AgentName, cl.AgentName)
continue
}
if expected.AgentID.Valid && cl.AgentID.UUID != expected.AgentID.UUID {
t.Logf("connection log %d: expected AgentID %s, got %s", idx+1, expected.AgentID.UUID, cl.AgentID.UUID)
continue
}
if expected.Type != "" && cl.Type != expected.Type {
t.Logf("connection log %d: expected Type %s, got %s", idx+1, expected.Type, cl.Type)
continue
@@ -1,938 +0,0 @@
package database_test
import (
"context"
"database/sql"
"fmt"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
)
func TestCloseOpenAgentConnectionLogsForWorkspace(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws1 := dbgen.Workspace(t, db, database.WorkspaceTable{
ID: uuid.New(),
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
ws2 := dbgen.Workspace(t, db, database.WorkspaceTable{
ID: uuid.New(),
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
ip := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
// Simulate agent clock skew by using a connect time in the future.
connectTime := dbtime.Now().Add(time.Hour)
sshLog1, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws1.OrganizationID,
WorkspaceOwnerID: ws1.OwnerID,
WorkspaceID: ws1.ID,
WorkspaceName: ws1.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
appLog, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: dbtime.Now(),
OrganizationID: ws1.OrganizationID,
WorkspaceOwnerID: ws1.OwnerID,
WorkspaceID: ws1.ID,
WorkspaceName: ws1.Name,
AgentName: "agent",
Type: database.ConnectionTypeWorkspaceApp,
Ip: ip,
UserAgent: sql.NullString{String: "test", Valid: true},
UserID: uuid.NullUUID{UUID: ws1.OwnerID, Valid: true},
SlugOrPort: sql.NullString{String: "app", Valid: true},
Code: sql.NullInt32{Int32: 200, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
sshLog2, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: dbtime.Now(),
OrganizationID: ws2.OrganizationID,
WorkspaceOwnerID: ws2.OwnerID,
WorkspaceID: ws2.ID,
WorkspaceName: ws2.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
rowsClosed, err := db.CloseOpenAgentConnectionLogsForWorkspace(ctx, database.CloseOpenAgentConnectionLogsForWorkspaceParams{
WorkspaceID: ws1.ID,
ClosedAt: dbtime.Now(),
Reason: "workspace stopped",
Types: []database.ConnectionType{
database.ConnectionTypeSsh,
database.ConnectionTypeVscode,
database.ConnectionTypeJetbrains,
database.ConnectionTypeReconnectingPty,
},
})
require.NoError(t, err)
require.EqualValues(t, 1, rowsClosed)
ws1Rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{WorkspaceID: ws1.ID})
require.NoError(t, err)
require.Len(t, ws1Rows, 2)
for _, row := range ws1Rows {
switch row.ConnectionLog.ID {
case sshLog1.ID:
updated := row.ConnectionLog
require.True(t, updated.DisconnectTime.Valid)
require.True(t, updated.DisconnectReason.Valid)
require.Equal(t, "workspace stopped", updated.DisconnectReason.String)
require.False(t, updated.DisconnectTime.Time.Before(updated.ConnectTime), "disconnect_time should never be before connect_time")
case appLog.ID:
notClosed := row.ConnectionLog
require.False(t, notClosed.DisconnectTime.Valid)
require.False(t, notClosed.DisconnectReason.Valid)
default:
t.Fatalf("unexpected connection log id: %s", row.ConnectionLog.ID)
}
}
ws2Rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{WorkspaceID: ws2.ID})
require.NoError(t, err)
require.Len(t, ws2Rows, 1)
require.Equal(t, sshLog2.ID, ws2Rows[0].ConnectionLog.ID)
require.False(t, ws2Rows[0].ConnectionLog.DisconnectTime.Valid)
}
// Regression test: CloseConnectionLogsAndCreateSessions must not fail
// when connection_logs have NULL IPs (e.g., disconnect-only tunnel
// events). NULL-IP logs should be closed but no session created for
// them.
func TestCloseConnectionLogsAndCreateSessions_NullIP(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
validIP := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(10, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
now := dbtime.Now()
// Connection with a valid IP.
sshLog, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-30 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: validIP,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
// Connection with a NULL IP — simulates a disconnect-only tunnel
// event where the source node info is unavailable.
nullIPLog, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-25 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSystem,
Ip: pqtype.Inet{Valid: false},
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
// This previously failed with: "pq: null value in column ip of
// relation workspace_sessions violates not-null constraint".
closedAt := now.Add(-5 * time.Minute)
_, err = db.CloseConnectionLogsAndCreateSessions(ctx, database.CloseConnectionLogsAndCreateSessionsParams{
ClosedAt: sql.NullTime{Time: closedAt, Valid: true},
Reason: sql.NullString{String: "workspace stopped", Valid: true},
WorkspaceID: ws.ID,
Types: []database.ConnectionType{
database.ConnectionTypeSsh,
database.ConnectionTypeSystem,
},
})
require.NoError(t, err)
// Verify both logs were closed.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
WorkspaceID: ws.ID,
})
require.NoError(t, err)
require.Len(t, rows, 2)
for _, row := range rows {
cl := row.ConnectionLog
require.True(t, cl.DisconnectTime.Valid,
"connection log %s (type=%s) should be closed", cl.ID, cl.Type)
switch cl.ID {
case sshLog.ID:
// Valid-IP log should have a session.
require.True(t, cl.SessionID.Valid,
"valid-IP log should be linked to a session")
case nullIPLog.ID:
// NULL-IP system connection overlaps with the SSH
// session, so it gets attached to that session.
require.True(t, cl.SessionID.Valid,
"NULL-IP system log overlapping with SSH session should be linked to a session")
default:
t.Fatalf("unexpected connection log id: %s", cl.ID)
}
}
}
// Regression test: CloseConnectionLogsAndCreateSessions must handle
// connections that are already disconnected but have no session_id
// (e.g., system/tunnel connections disconnected by dbsink). It must
// also avoid creating duplicate sessions when assignSessionForDisconnect
// has already created one for the same IP/time range.
func TestCloseConnectionLogsAndCreateSessions_AlreadyDisconnectedGetsSession(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
ip := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
now := dbtime.Now()
// A system connection that was already disconnected (by dbsink)
// but has no session_id — dbsink doesn't assign sessions.
sysConnID := uuid.New()
_, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-10 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSystem,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: sysConnID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
_, err = db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-5 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSystem,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: sysConnID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
})
require.NoError(t, err)
// Run CloseConnectionLogsAndCreateSessions (workspace stop).
closedAt := now
_, err = db.CloseConnectionLogsAndCreateSessions(ctx, database.CloseConnectionLogsAndCreateSessionsParams{
ClosedAt: sql.NullTime{Time: closedAt, Valid: true},
Reason: sql.NullString{String: "workspace stopped", Valid: true},
WorkspaceID: ws.ID,
Types: []database.ConnectionType{
database.ConnectionTypeSsh,
database.ConnectionTypeSystem,
},
})
require.NoError(t, err)
// The system connection should now have a session_id.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
WorkspaceID: ws.ID,
})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, rows[0].ConnectionLog.SessionID.Valid,
"already-disconnected system connection should be assigned to a session")
}
// Regression test: when assignSessionForDisconnect has already
// created a session for an SSH connection,
// CloseConnectionLogsAndCreateSessions must reuse that session
// instead of creating a duplicate.
func TestCloseConnectionLogsAndCreateSessions_ReusesExistingSession(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
ip := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
now := dbtime.Now()
// Simulate an SSH connection where assignSessionForDisconnect
// already created a session but the connection log's session_id
// was set (the normal successful path).
sshConnID := uuid.New()
_, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-10 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: sshConnID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
sshLog, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-5 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: sshConnID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
})
require.NoError(t, err)
// Create the session that assignSessionForDisconnect would have
// created, and link the connection log to it.
existingSessionIDRaw, err := db.FindOrCreateSessionForDisconnect(ctx, database.FindOrCreateSessionForDisconnectParams{
WorkspaceID: ws.ID.String(),
Ip: ip,
ConnectTime: sshLog.ConnectTime,
DisconnectTime: sshLog.DisconnectTime.Time,
})
require.NoError(t, err)
existingSessionID, err := uuid.Parse(fmt.Sprintf("%s", existingSessionIDRaw))
require.NoError(t, err)
err = db.UpdateConnectionLogSessionID(ctx, database.UpdateConnectionLogSessionIDParams{
ID: sshLog.ID,
SessionID: uuid.NullUUID{UUID: existingSessionID, Valid: true},
})
require.NoError(t, err)
// Also add a system connection (no session, already disconnected).
sysConnID := uuid.New()
_, err = db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-10 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSystem,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: sysConnID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
_, err = db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-5 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSystem,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: sysConnID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
})
require.NoError(t, err)
// Run CloseConnectionLogsAndCreateSessions.
closedAt := now
_, err = db.CloseConnectionLogsAndCreateSessions(ctx, database.CloseConnectionLogsAndCreateSessionsParams{
ClosedAt: sql.NullTime{Time: closedAt, Valid: true},
Reason: sql.NullString{String: "workspace stopped", Valid: true},
WorkspaceID: ws.ID,
Types: []database.ConnectionType{
database.ConnectionTypeSsh,
database.ConnectionTypeSystem,
},
})
require.NoError(t, err)
// Verify: the system connection should be assigned to the
// EXISTING session (reused), not a new one.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
WorkspaceID: ws.ID,
})
require.NoError(t, err)
require.Len(t, rows, 2)
for _, row := range rows {
cl := row.ConnectionLog
require.True(t, cl.SessionID.Valid,
"connection log %s (type=%s) should have a session", cl.ID, cl.Type)
require.Equal(t, existingSessionID, cl.SessionID.UUID,
"connection log %s should reuse the existing session, not create a new one", cl.ID)
}
}
// Test: connections with different IPs but same hostname get grouped
// into one session.
func TestCloseConnectionLogsAndCreateSessions_GroupsByHostname(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
now := dbtime.Now()
hostname := sql.NullString{String: "my-laptop", Valid: true}
// Create 3 SSH connections with different IPs but same hostname,
// overlapping in time.
for i := 0; i < 3; i++ {
ip := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(10, 0, 0, byte(i+1)),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
_, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(time.Duration(-30+i*5) * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ClientHostname: hostname,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
}
closedAt := now
_, err := db.CloseConnectionLogsAndCreateSessions(ctx, database.CloseConnectionLogsAndCreateSessionsParams{
ClosedAt: sql.NullTime{Time: closedAt, Valid: true},
Reason: sql.NullString{String: "workspace stopped", Valid: true},
WorkspaceID: ws.ID,
Types: []database.ConnectionType{
database.ConnectionTypeSsh,
},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
WorkspaceID: ws.ID,
})
require.NoError(t, err)
require.Len(t, rows, 3)
// All 3 connections should have the same session_id.
var sessionID uuid.UUID
for i, row := range rows {
cl := row.ConnectionLog
require.True(t, cl.SessionID.Valid,
"connection %d should have a session", i)
if i == 0 {
sessionID = cl.SessionID.UUID
} else {
require.Equal(t, sessionID, cl.SessionID.UUID,
"all connections with same hostname should share one session")
}
}
}
// Test: a long-running system connection gets attached to the first
// overlapping primary session, not the second.
func TestCloseConnectionLogsAndCreateSessions_SystemAttachesToFirstSession(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
ip := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(10, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
now := dbtime.Now()
// System connection spanning the full workspace lifetime.
sysLog, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-3 * time.Hour),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSystem,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
// SSH session 1: -3h to -2h.
ssh1ConnID := uuid.New()
_, err = db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-3 * time.Hour),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: ssh1ConnID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
ssh1Disc, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-2 * time.Hour),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: ssh1ConnID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
})
require.NoError(t, err)
_ = ssh1Disc
// SSH session 2: -30min to now (>30min gap from session 1).
_, err = db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-30 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
closedAt := now
_, err = db.CloseConnectionLogsAndCreateSessions(ctx, database.CloseConnectionLogsAndCreateSessionsParams{
ClosedAt: sql.NullTime{Time: closedAt, Valid: true},
Reason: sql.NullString{String: "workspace stopped", Valid: true},
WorkspaceID: ws.ID,
Types: []database.ConnectionType{
database.ConnectionTypeSsh,
database.ConnectionTypeSystem,
},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
WorkspaceID: ws.ID,
})
require.NoError(t, err)
// Find the system connection and its assigned session.
var sysSessionID uuid.UUID
// Collect all session IDs from SSH connections to verify 2
// distinct sessions were created.
sshSessionIDs := make(map[uuid.UUID]bool)
for _, row := range rows {
cl := row.ConnectionLog
if cl.ID == sysLog.ID {
require.True(t, cl.SessionID.Valid,
"system connection should have a session")
sysSessionID = cl.SessionID.UUID
}
if cl.Type == database.ConnectionTypeSsh && cl.SessionID.Valid {
sshSessionIDs[cl.SessionID.UUID] = true
}
}
// Two distinct SSH sessions should exist (>30min gap).
require.Len(t, sshSessionIDs, 2, "should have 2 distinct SSH sessions")
// System connection should be attached to the first (earliest)
// session.
require.True(t, sshSessionIDs[sysSessionID],
"system connection should be attached to one of the SSH sessions")
}
// Test: an orphaned system connection (no overlapping primary sessions)
// with an IP gets its own session.
func TestCloseConnectionLogsAndCreateSessions_OrphanSystemGetsOwnSession(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
ip := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(10, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
now := dbtime.Now()
// System connection with an IP but no overlapping primary
// connections.
_, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-10 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSystem,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
closedAt := now
_, err = db.CloseConnectionLogsAndCreateSessions(ctx, database.CloseConnectionLogsAndCreateSessionsParams{
ClosedAt: sql.NullTime{Time: closedAt, Valid: true},
Reason: sql.NullString{String: "workspace stopped", Valid: true},
WorkspaceID: ws.ID,
Types: []database.ConnectionType{
database.ConnectionTypeSystem,
},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
WorkspaceID: ws.ID,
})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, rows[0].ConnectionLog.SessionID.Valid,
"orphaned system connection with IP should get its own session")
}
// Test: a system connection with NULL IP and no overlapping primary
// sessions gets no session (can't create a useful session without IP).
func TestCloseConnectionLogsAndCreateSessions_SystemNoIPNoSession(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
now := dbtime.Now()
// System connection with NULL IP and no overlapping primary.
_, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-10 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSystem,
Ip: pqtype.Inet{Valid: false},
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
closedAt := now
_, err = db.CloseConnectionLogsAndCreateSessions(ctx, database.CloseConnectionLogsAndCreateSessionsParams{
ClosedAt: sql.NullTime{Time: closedAt, Valid: true},
Reason: sql.NullString{String: "workspace stopped", Valid: true},
WorkspaceID: ws.ID,
Types: []database.ConnectionType{
database.ConnectionTypeSystem,
},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
WorkspaceID: ws.ID,
})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, rows[0].ConnectionLog.DisconnectTime.Valid,
"system connection should be closed")
require.False(t, rows[0].ConnectionLog.SessionID.Valid,
"NULL-IP system connection with no primary overlap should not get a session")
}
// Test: connections from the same hostname with a >30-minute gap
// create separate sessions.
func TestCloseConnectionLogsAndCreateSessions_SeparateSessionsForLargeGap(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
ip := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(10, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
now := dbtime.Now()
// SSH connection 1: -3h to -2h.
conn1ID := uuid.New()
_, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-3 * time.Hour),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: conn1ID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
_, err = db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-2 * time.Hour),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: conn1ID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
})
require.NoError(t, err)
// SSH connection 2: -30min to now (>30min gap from connection 1).
_, err = db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now.Add(-30 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "agent",
Type: database.ConnectionTypeSsh,
Ip: ip,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
})
require.NoError(t, err)
closedAt := now
_, err = db.CloseConnectionLogsAndCreateSessions(ctx, database.CloseConnectionLogsAndCreateSessionsParams{
ClosedAt: sql.NullTime{Time: closedAt, Valid: true},
Reason: sql.NullString{String: "workspace stopped", Valid: true},
WorkspaceID: ws.ID,
Types: []database.ConnectionType{
database.ConnectionTypeSsh,
},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
WorkspaceID: ws.ID,
})
require.NoError(t, err)
sessionIDs := make(map[uuid.UUID]bool)
for _, row := range rows {
cl := row.ConnectionLog
if cl.SessionID.Valid {
sessionIDs[cl.SessionID.UUID] = true
}
}
require.Len(t, sessionIDs, 2,
"connections with >30min gap should create 2 separate sessions")
}
@@ -1,239 +0,0 @@
package database_test
import (
"context"
"database/sql"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
)
func TestGetOngoingAgentConnectionsLast24h(t *testing.T) {
t.Parallel()
ctx := context.Background()
db, _ := dbtestutil.NewDB(t)
org := dbfake.Organization(t, db).Do()
user := dbgen.User(t, db, database.User{})
tpl := dbgen.Template(t, db, database.Template{OrganizationID: org.Org.ID, CreatedBy: user.ID})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OrganizationID: org.Org.ID,
OwnerID: user.ID,
TemplateID: tpl.ID,
Name: "ws",
})
now := dbtime.Now()
since := now.Add(-24 * time.Hour)
const (
agent1 = "agent1"
agent2 = "agent2"
)
// Insert a disconnected log that should be excluded.
disconnectedConnID := uuid.New()
disconnected := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
Time: now.Add(-30 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agent1,
Type: database.ConnectionTypeSsh,
ConnectionStatus: database.ConnectionStatusConnected,
ConnectionID: uuid.NullUUID{UUID: disconnectedConnID, Valid: true},
})
_ = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
Time: now.Add(-20 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
AgentName: disconnected.AgentName,
ConnectionStatus: database.ConnectionStatusDisconnected,
ConnectionID: disconnected.ConnectionID,
DisconnectReason: sql.NullString{String: "closed", Valid: true},
})
// Insert an old log that should be excluded by the 24h window.
_ = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
Time: now.Add(-25 * time.Hour),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agent1,
Type: database.ConnectionTypeSsh,
ConnectionStatus: database.ConnectionStatusConnected,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
})
// Insert a web log that should be excluded by the types filter.
_ = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
Time: now.Add(-10 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agent1,
Type: database.ConnectionTypeWorkspaceApp,
ConnectionStatus: database.ConnectionStatusConnected,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
})
// Insert 55 active logs for agent1 (should be capped to 50).
for i := 0; i < 55; i++ {
_ = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
Time: now.Add(-time.Duration(i) * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agent1,
Type: database.ConnectionTypeVscode,
ConnectionStatus: database.ConnectionStatusConnected,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
})
}
// Insert one active log for agent2.
agent2Log := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
Time: now.Add(-5 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agent2,
Type: database.ConnectionTypeJetbrains,
ConnectionStatus: database.ConnectionStatusConnected,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
})
logs, err := db.GetOngoingAgentConnectionsLast24h(ctx, database.GetOngoingAgentConnectionsLast24hParams{
WorkspaceIds: []uuid.UUID{ws.ID},
AgentNames: []string{agent1, agent2},
Types: []database.ConnectionType{database.ConnectionTypeSsh, database.ConnectionTypeVscode, database.ConnectionTypeJetbrains, database.ConnectionTypeReconnectingPty},
Since: since,
PerAgentLimit: 50,
})
require.NoError(t, err)
byAgent := map[string][]database.GetOngoingAgentConnectionsLast24hRow{}
for _, l := range logs {
byAgent[l.AgentName] = append(byAgent[l.AgentName], l)
}
// Agent1 should be capped at 50 and contain only active logs within the window.
require.Len(t, byAgent[agent1], 50)
for i, l := range byAgent[agent1] {
require.False(t, l.DisconnectTime.Valid, "expected log to be ongoing")
require.True(t, l.ConnectTime.After(since) || l.ConnectTime.Equal(since), "expected log to be within window")
if i > 0 {
require.True(t, byAgent[agent1][i-1].ConnectTime.After(l.ConnectTime) || byAgent[agent1][i-1].ConnectTime.Equal(l.ConnectTime), "expected logs to be ordered by connect_time desc")
}
}
// Agent2 should include its single active log.
require.Equal(t, []uuid.UUID{agent2Log.ID}, []uuid.UUID{byAgent[agent2][0].ID})
}
func TestGetOngoingAgentConnectionsLast24h_PortForwarding(t *testing.T) {
t.Parallel()
ctx := context.Background()
db, _ := dbtestutil.NewDB(t)
org := dbfake.Organization(t, db).Do()
user := dbgen.User(t, db, database.User{})
tpl := dbgen.Template(t, db, database.Template{OrganizationID: org.Org.ID, CreatedBy: user.ID})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OrganizationID: org.Org.ID,
OwnerID: user.ID,
TemplateID: tpl.ID,
Name: "ws-pf",
})
now := dbtime.Now()
since := now.Add(-24 * time.Hour)
const agentName = "agent-pf"
// Agent-reported: NULL user_agent, included unconditionally.
agentReported := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
Time: now.Add(-10 * time.Minute),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypePortForwarding,
ConnectionStatus: database.ConnectionStatusConnected,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
SlugOrPort: sql.NullString{String: "8080", Valid: true},
Ip: database.ParseIP("fd7a:115c:a1e0:4353:89d9:4ca8:9c42:8d2d"),
})
// Stale proxy-reported: non-NULL user_agent, bumped but older than AppActiveSince.
// Use a non-localhost IP to verify the fix works even behind a reverse proxy.
staleConnID := uuid.New()
staleConnectTime := now.Add(-15 * time.Minute)
_ = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
Time: staleConnectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypePortForwarding,
ConnectionStatus: database.ConnectionStatusConnected,
ConnectionID: uuid.NullUUID{UUID: staleConnID, Valid: true},
SlugOrPort: sql.NullString{String: "3000", Valid: true},
Ip: database.ParseIP("203.0.113.45"),
UserAgent: sql.NullString{String: "Mozilla/5.0", Valid: true},
})
// Bump updated_at to simulate a proxy refresh.
staleBumpTime := now.Add(-8 * time.Minute)
_, err := db.UpsertConnectionLog(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: staleBumpTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypePortForwarding,
ConnectionStatus: database.ConnectionStatusConnected,
ConnectionID: uuid.NullUUID{UUID: staleConnID, Valid: true},
SlugOrPort: sql.NullString{String: "3000", Valid: true},
})
require.NoError(t, err)
appActiveSince := now.Add(-5 * time.Minute)
logs, err := db.GetOngoingAgentConnectionsLast24h(ctx, database.GetOngoingAgentConnectionsLast24hParams{
WorkspaceIds: []uuid.UUID{ws.ID},
AgentNames: []string{agentName},
Types: []database.ConnectionType{database.ConnectionTypePortForwarding},
Since: since,
PerAgentLimit: 50,
AppActiveSince: appActiveSince,
})
require.NoError(t, err)
// Only the agent-reported connection should appear.
require.Len(t, logs, 1)
require.Equal(t, agentReported.ID, logs[0].ID)
require.Equal(t, database.ConnectionTypePortForwarding, logs[0].Type)
require.True(t, logs[0].SlugOrPort.Valid)
require.Equal(t, "8080", logs[0].SlugOrPort.String)
}
-9
View File
@@ -3,12 +3,3 @@ package database
import "github.com/google/uuid"
var PrebuildsSystemUserID = uuid.MustParse("c42fdf75-3097-471c-8c33-fb52454d81c0")
const (
TailnetPeeringEventTypeAddedTunnel = "added_tunnel"
TailnetPeeringEventTypeRemovedTunnel = "removed_tunnel"
TailnetPeeringEventTypePeerUpdateNode = "peer_update_node"
TailnetPeeringEventTypePeerUpdateDisconnected = "peer_update_disconnected"
TailnetPeeringEventTypePeerUpdateLost = "peer_update_lost"
TailnetPeeringEventTypePeerUpdateReadyForHandshake = "peer_update_ready_for_handshake"
)
-4
View File
@@ -849,10 +849,6 @@ func ConnectionLogConnectionTypeFromAgentProtoConnectionType(typ agentproto.Conn
return database.ConnectionTypeVscode, nil
case agentproto.Connection_RECONNECTING_PTY:
return database.ConnectionTypeReconnectingPty, nil
case agentproto.Connection_WORKSPACE_APP:
return database.ConnectionTypeWorkspaceApp, nil
case agentproto.Connection_PORT_FORWARDING:
return database.ConnectionTypePortForwarding, nil
default:
// Also Connection_TYPE_UNSPECIFIED, no mapping.
return "", xerrors.Errorf("unknown agent connection type %q", typ)
+23 -141
View File
@@ -461,24 +461,6 @@ var (
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectTailnetCoordinator = rbac.Subject{
Type: rbac.SubjectTypeTailnetCoordinator,
FriendlyName: "Tailnet Coordinator",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "tailnetcoordinator"},
DisplayName: "Tailnet Coordinator",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceTailnetCoordinator.Type: {policy.WildcardSymbol},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectSystemOAuth2 = rbac.Subject{
Type: rbac.SubjectTypeSystemOAuth,
FriendlyName: "System OAuth2",
@@ -744,12 +726,6 @@ func AsSystemRestricted(ctx context.Context) context.Context {
return As(ctx, subjectSystemRestricted)
}
// AsTailnetCoordinator returns a context with an actor that has permissions
// required for tailnet coordinator operations.
func AsTailnetCoordinator(ctx context.Context) context.Context {
return As(ctx, subjectTailnetCoordinator)
}
// AsSystemOAuth2 returns a context with an actor that has permissions
// required for OAuth2 provider operations (token revocation, device codes, registration).
func AsSystemOAuth2(ctx context.Context) context.Context {
@@ -1612,20 +1588,6 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
return q.db.CleanTailnetTunnels(ctx)
}
func (q *querier) CloseConnectionLogsAndCreateSessions(ctx context.Context, arg database.CloseConnectionLogsAndCreateSessionsParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return 0, err
}
return q.db.CloseConnectionLogsAndCreateSessions(ctx, arg)
}
func (q *querier) CloseOpenAgentConnectionLogsForWorkspace(ctx context.Context, arg database.CloseOpenAgentConnectionLogsForWorkspaceParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return 0, err
}
return q.db.CloseOpenAgentConnectionLogsForWorkspace(ctx, arg)
}
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -1661,13 +1623,6 @@ func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountCon
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
}
func (q *querier) CountGlobalWorkspaceSessions(ctx context.Context, arg database.CountGlobalWorkspaceSessionsParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog); err != nil {
return 0, err
}
return q.db.CountGlobalWorkspaceSessions(ctx, arg)
}
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
return nil, err
@@ -1689,13 +1644,6 @@ func (q *querier) CountUnreadInboxNotificationsByUserID(ctx context.Context, use
return q.db.CountUnreadInboxNotificationsByUserID(ctx, userID)
}
func (q *querier) CountWorkspaceSessions(ctx context.Context, arg database.CountWorkspaceSessionsParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog); err != nil {
return 0, err
}
return q.db.CountWorkspaceSessions(ctx, arg)
}
func (q *querier) CreateUserSecret(ctx context.Context, arg database.CreateUserSecretParams) (database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionCreate, obj); err != nil {
@@ -1755,6 +1703,13 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteBoundaryUsageStatsByReplicaID(ctx context.Context, replicaID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceBoundaryUsage); err != nil {
return err
}
return q.db.DeleteBoundaryUsageStatsByReplicaID(ctx, replicaID)
}
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -1977,14 +1932,14 @@ func (q *querier) DeleteTailnetTunnel(ctx context.Context, arg database.DeleteTa
return q.db.DeleteTailnetTunnel(ctx, arg)
}
func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (uuid.UUID, error) {
func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
task, err := q.db.GetTaskByID(ctx, arg.ID)
if err != nil {
return uuid.UUID{}, err
return database.TaskTable{}, err
}
if err := q.authorizeContext(ctx, policy.ActionDelete, task.RBACObject()); err != nil {
return uuid.UUID{}, err
return database.TaskTable{}, err
}
return q.db.DeleteTask(ctx, arg)
@@ -2170,13 +2125,6 @@ func (q *querier) FindMatchingPresetID(ctx context.Context, arg database.FindMat
return q.db.FindMatchingPresetID(ctx, arg)
}
func (q *querier) FindOrCreateSessionForDisconnect(ctx context.Context, arg database.FindOrCreateSessionForDisconnectParams) (interface{}, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return nil, err
}
return q.db.FindOrCreateSessionForDisconnect(ctx, arg)
}
func (q *querier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID) (database.AIBridgeInterception, error) {
return fetch(q.log, q.auth, q.db.GetAIBridgeInterceptionByID)(ctx, id)
}
@@ -2261,13 +2209,6 @@ func (q *querier) GetAllTailnetCoordinators(ctx context.Context) ([]database.Tai
return q.db.GetAllTailnetCoordinators(ctx)
}
func (q *querier) GetAllTailnetPeeringEventsByPeerID(ctx context.Context, srcPeerID uuid.NullUUID) ([]database.TailnetPeeringEvent, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
}
return q.db.GetAllTailnetPeeringEventsByPeerID(ctx, srcPeerID)
}
func (q *querier) GetAllTailnetPeers(ctx context.Context) ([]database.TailnetPeer, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
@@ -2282,13 +2223,6 @@ func (q *querier) GetAllTailnetTunnels(ctx context.Context) ([]database.TailnetT
return q.db.GetAllTailnetTunnels(ctx)
}
func (q *querier) GetAndResetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (database.GetAndResetBoundaryUsageSummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceBoundaryUsage); err != nil {
return database.GetAndResetBoundaryUsageSummaryRow{}, err
}
return q.db.GetAndResetBoundaryUsageSummary(ctx, maxStalenessMs)
}
func (q *querier) GetAnnouncementBanners(ctx context.Context) (string, error) {
// No authz checks
return q.db.GetAnnouncementBanners(ctx)
@@ -2337,18 +2271,11 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
func (q *querier) GetConnectionLogByConnectionID(ctx context.Context, arg database.GetConnectionLogByConnectionIDParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
func (q *querier) GetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (database.GetBoundaryUsageSummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceBoundaryUsage); err != nil {
return database.GetBoundaryUsageSummaryRow{}, err
}
return q.db.GetConnectionLogByConnectionID(ctx, arg)
}
func (q *querier) GetConnectionLogsBySessionIDs(ctx context.Context, sessionIDs []uuid.UUID) ([]database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog); err != nil {
return nil, err
}
return q.db.GetConnectionLogsBySessionIDs(ctx, sessionIDs)
return q.db.GetBoundaryUsageSummary(ctx, maxStalenessMs)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
@@ -2526,13 +2453,6 @@ func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetGitSSHKey)(ctx, userID)
}
func (q *querier) GetGlobalWorkspaceSessionsOffset(ctx context.Context, arg database.GetGlobalWorkspaceSessionsOffsetParams) ([]database.GetGlobalWorkspaceSessionsOffsetRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog); err != nil {
return nil, err
}
return q.db.GetGlobalWorkspaceSessionsOffset(ctx, arg)
}
func (q *querier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) {
return fetch(q.log, q.auth, q.db.GetGroupByID)(ctx, id)
}
@@ -2799,15 +2719,6 @@ func (q *querier) GetOAuthSigningKey(ctx context.Context) (string, error) {
return q.db.GetOAuthSigningKey(ctx)
}
func (q *querier) GetOngoingAgentConnectionsLast24h(ctx context.Context, arg database.GetOngoingAgentConnectionsLast24hParams) ([]database.GetOngoingAgentConnectionsLast24hRow, error) {
// This is a system-level read; authorization comes from the
// caller using dbauthz.AsSystemRestricted(ctx).
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetOngoingAgentConnectionsLast24h(ctx, arg)
}
func (q *querier) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) {
return fetch(q.log, q.auth, q.db.GetOrganizationByID)(ctx, id)
}
@@ -3177,13 +3088,6 @@ func (q *querier) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.U
return q.db.GetTailnetTunnelPeerBindings(ctx, srcID)
}
func (q *querier) GetTailnetTunnelPeerBindingsByDstID(ctx context.Context, dstID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsByDstIDRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
}
return q.db.GetTailnetTunnelPeerBindingsByDstID(ctx, dstID)
}
func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
@@ -3989,14 +3893,6 @@ func (q *querier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Conte
return q.db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, arg)
}
func (q *querier) GetWorkspaceBuildMetricsByResourceID(ctx context.Context, id uuid.UUID) (database.GetWorkspaceBuildMetricsByResourceIDRow, error) {
// Verify access to the resource first.
if _, err := q.GetWorkspaceResourceByID(ctx, id); err != nil {
return database.GetWorkspaceBuildMetricsByResourceIDRow{}, err
}
return q.db.GetWorkspaceBuildMetricsByResourceID(ctx, id)
}
func (q *querier) GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) {
// Authorized call to get the workspace build. If we can read the build,
// we can read the params.
@@ -4189,13 +4085,6 @@ func (q *querier) GetWorkspaceResourcesCreatedAfter(ctx context.Context, created
return q.db.GetWorkspaceResourcesCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceSessionsOffset(ctx context.Context, arg database.GetWorkspaceSessionsOffsetParams) ([]database.GetWorkspaceSessionsOffsetRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog); err != nil {
return nil, err
}
return q.db.GetWorkspaceSessionsOffset(ctx, arg)
}
func (q *querier) GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx context.Context, templateIDs []uuid.UUID) ([]database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
@@ -4510,13 +4399,6 @@ func (q *querier) InsertReplica(ctx context.Context, arg database.InsertReplicaP
return q.db.InsertReplica(ctx, arg)
}
func (q *querier) InsertTailnetPeeringEvent(ctx context.Context, arg database.InsertTailnetPeeringEventParams) error {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceTailnetCoordinator); err != nil {
return err
}
return q.db.InsertTailnetPeeringEvent(ctx, arg)
}
func (q *querier) InsertTask(ctx context.Context, arg database.InsertTaskParams) (database.TaskTable, error) {
// Ensure the actor can access the specified template version (and thus its template).
if _, err := q.GetTemplateVersionByID(ctx, arg.TemplateVersionID); err != nil {
@@ -5009,6 +4891,13 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU
return q.db.RemoveUserFromGroups(ctx, arg)
}
func (q *querier) ResetBoundaryUsageStats(ctx context.Context) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceBoundaryUsage); err != nil {
return err
}
return q.db.ResetBoundaryUsageStats(ctx)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -5065,13 +4954,6 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
}
func (q *querier) UpdateConnectionLogSessionID(ctx context.Context, arg database.UpdateConnectionLogSessionIDParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return err
}
return q.db.UpdateConnectionLogSessionID(ctx, arg)
}
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -6326,9 +6208,9 @@ func (q *querier) UpsertWorkspaceApp(ctx context.Context, arg database.UpsertWor
return q.db.UpsertWorkspaceApp(ctx, arg)
}
func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (database.UpsertWorkspaceAppAuditSessionRow, error) {
func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return database.UpsertWorkspaceAppAuditSessionRow{}, err
return false, err
}
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
}
+15 -27
View File
@@ -277,6 +277,11 @@ func (s *MethodTestSuite) TestAPIKey() {
dbm.EXPECT().DeleteApplicationConnectAPIKeysByUserID(gomock.Any(), a.UserID).Return(nil).AnyTimes()
check.Args(a.UserID).Asserts(rbac.ResourceApiKey.WithOwner(a.UserID.String()), policy.ActionDelete).Returns()
}))
s.Run("DeleteBoundaryUsageStatsByReplicaID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
replicaID := uuid.New()
dbm.EXPECT().DeleteBoundaryUsageStatsByReplicaID(gomock.Any(), replicaID).Return(nil).AnyTimes()
check.Args(replicaID).Asserts(rbac.ResourceBoundaryUsage, policy.ActionDelete)
}))
s.Run("DeleteExternalAuthLink", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
a := testutil.Fake(s.T(), faker, database.ExternalAuthLink{})
dbm.EXPECT().GetExternalAuthLink(gomock.Any(), database.GetExternalAuthLinkParams{ProviderID: a.ProviderID, UserID: a.UserID}).Return(a, nil).AnyTimes()
@@ -362,11 +367,6 @@ func (s *MethodTestSuite) TestConnectionLogs() {
dbm.EXPECT().DeleteOldConnectionLogs(gomock.Any(), database.DeleteOldConnectionLogsParams{}).Return(int64(0), nil).AnyTimes()
check.Args(database.DeleteOldConnectionLogsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
}))
s.Run("CloseOpenAgentConnectionLogsForWorkspace", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.CloseOpenAgentConnectionLogsForWorkspaceParams{}
dbm.EXPECT().CloseOpenAgentConnectionLogsForWorkspace(gomock.Any(), arg).Return(int64(0), nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceConnectionLog, policy.ActionUpdate)
}))
}
func (s *MethodTestSuite) TestFile() {
@@ -532,9 +532,9 @@ func (s *MethodTestSuite) TestGroup() {
dbm.EXPECT().RemoveUserFromGroups(gomock.Any(), arg).Return(slice.New(g1.ID, g2.ID), nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
}))
s.Run("GetAndResetBoundaryUsageSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetAndResetBoundaryUsageSummary(gomock.Any(), int64(1000)).Return(database.GetAndResetBoundaryUsageSummaryRow{}, nil).AnyTimes()
check.Args(int64(1000)).Asserts(rbac.ResourceBoundaryUsage, policy.ActionDelete)
s.Run("ResetBoundaryUsageStats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().ResetBoundaryUsageStats(gomock.Any()).Return(nil).AnyTimes()
check.Args().Asserts(rbac.ResourceBoundaryUsage, policy.ActionDelete)
}))
s.Run("UpdateGroupByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
@@ -2041,18 +2041,6 @@ func (s *MethodTestSuite) TestWorkspace() {
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), build.WorkspaceID).Return(ws, nil).AnyTimes()
check.Args(res.ID).Asserts(ws, policy.ActionRead).Returns(res)
}))
s.Run("GetWorkspaceBuildMetricsByResourceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID})
job := testutil.Fake(s.T(), faker, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
res := testutil.Fake(s.T(), faker, database.WorkspaceResource{JobID: build.JobID})
dbm.EXPECT().GetWorkspaceResourceByID(gomock.Any(), res.ID).Return(res, nil).AnyTimes()
dbm.EXPECT().GetProvisionerJobByID(gomock.Any(), res.JobID).Return(job, nil).AnyTimes()
dbm.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), res.JobID).Return(build, nil).AnyTimes()
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), build.WorkspaceID).Return(ws, nil).AnyTimes()
dbm.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), res.ID).Return(database.GetWorkspaceBuildMetricsByResourceIDRow{}, nil).AnyTimes()
check.Args(res.ID).Asserts(ws, policy.ActionRead).Returns(database.GetWorkspaceBuildMetricsByResourceIDRow{})
}))
s.Run("Build/GetWorkspaceResourcesByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID})
@@ -2529,8 +2517,8 @@ func (s *MethodTestSuite) TestTasks() {
DeletedAt: dbtime.Now(),
}
dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes()
dbm.EXPECT().DeleteTask(gomock.Any(), arg).Return(task.ID, nil).AnyTimes()
check.Args(arg).Asserts(task, policy.ActionDelete).Returns(task.ID)
dbm.EXPECT().DeleteTask(gomock.Any(), arg).Return(database.TaskTable{}, nil).AnyTimes()
check.Args(arg).Asserts(task, policy.ActionDelete).Returns(database.TaskTable{})
}))
s.Run("InsertTask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
tpl := testutil.Fake(s.T(), faker, database.Template{})
@@ -2846,10 +2834,6 @@ func (s *MethodTestSuite) TestTailnetFunctions() {
check.Args(uuid.New()).
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
}))
s.Run("GetTailnetTunnelPeerBindingsByDstID", s.Subtest(func(_ database.Store, check *expects) {
check.Args(uuid.New()).
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
}))
s.Run("GetTailnetTunnelPeerIDs", s.Subtest(func(_ database.Store, check *expects) {
check.Args(uuid.New()).
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
@@ -3007,6 +2991,10 @@ func (s *MethodTestSuite) TestSystemFunctions() {
dbm.EXPECT().GetAuthorizationUserRoles(gomock.Any(), u.ID).Return(database.GetAuthorizationUserRolesRow{}, nil).AnyTimes()
check.Args(u.ID).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetBoundaryUsageSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetBoundaryUsageSummary(gomock.Any(), int64(1000)).Return(database.GetBoundaryUsageSummaryRow{}, nil).AnyTimes()
check.Args(int64(1000)).Asserts(rbac.ResourceBoundaryUsage, policy.ActionRead)
}))
s.Run("GetDERPMeshKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetDERPMeshKey(gomock.Any()).Return("testing", nil).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
@@ -3318,7 +3306,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
agent := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
app := testutil.Fake(s.T(), faker, database.WorkspaceApp{})
arg := database.UpsertWorkspaceAppAuditSessionParams{AgentID: agent.ID, AppID: app.ID, UserID: u.ID, Ip: "127.0.0.1"}
dbm.EXPECT().UpsertWorkspaceAppAuditSession(gomock.Any(), arg).Return(database.UpsertWorkspaceAppAuditSessionRow{NewOrStale: true}, nil).AnyTimes()
dbm.EXPECT().UpsertWorkspaceAppAuditSession(gomock.Any(), arg).Return(true, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("InsertWorkspaceAgentScriptTimings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
+6 -19
View File
@@ -35,25 +35,12 @@ import (
var errMatchAny = xerrors.New("match any error")
var skipMethods = map[string]string{
"InTx": "Not relevant",
"Ping": "Not relevant",
"PGLocks": "Not relevant",
"Wrappers": "Not relevant",
"AcquireLock": "Not relevant",
"TryAcquireLock": "Not relevant",
"GetOngoingAgentConnectionsLast24h": "Hackathon",
"InsertTailnetPeeringEvent": "Hackathon",
"CloseConnectionLogsAndCreateSessions": "Hackathon",
"CountGlobalWorkspaceSessions": "Hackathon",
"CountWorkspaceSessions": "Hackathon",
"FindOrCreateSessionForDisconnect": "Hackathon",
"GetConnectionLogByConnectionID": "Hackathon",
"GetConnectionLogsBySessionIDs": "Hackathon",
"GetGlobalWorkspaceSessionsOffset": "Hackathon",
"GetWorkspaceSessionsOffset": "Hackathon",
"UpdateConnectionLogSessionID": "Hackathon",
"GetAllTailnetPeeringEventsByPeerID": "Hackathon",
"InTx": "Not relevant",
"Ping": "Not relevant",
"PGLocks": "Not relevant",
"Wrappers": "Not relevant",
"AcquireLock": "Not relevant",
"TryAcquireLock": "Not relevant",
}
// TestMethodTestSuite runs MethodTestSuite.
+11 -140
View File
@@ -58,61 +58,6 @@ type WorkspaceBuildBuilder struct {
jobStatus database.ProvisionerJobStatus
taskAppID uuid.UUID
taskSeed database.TaskTable
// Individual timestamp fields for job customization.
jobCreatedAt time.Time
jobStartedAt time.Time
jobUpdatedAt time.Time
jobCompletedAt time.Time
jobError string // Error message for failed jobs
jobErrorCode string // Error code for failed jobs
}
// BuilderOption is a functional option for customizing job timestamps
// on status methods.
type BuilderOption func(*WorkspaceBuildBuilder)
// WithJobCreatedAt sets the CreatedAt timestamp for the provisioner job.
func WithJobCreatedAt(t time.Time) BuilderOption {
return func(b *WorkspaceBuildBuilder) {
b.jobCreatedAt = t
}
}
// WithJobStartedAt sets the StartedAt timestamp for the provisioner job.
func WithJobStartedAt(t time.Time) BuilderOption {
return func(b *WorkspaceBuildBuilder) {
b.jobStartedAt = t
}
}
// WithJobUpdatedAt sets the UpdatedAt timestamp for the provisioner job.
func WithJobUpdatedAt(t time.Time) BuilderOption {
return func(b *WorkspaceBuildBuilder) {
b.jobUpdatedAt = t
}
}
// WithJobCompletedAt sets the CompletedAt timestamp for the provisioner job.
func WithJobCompletedAt(t time.Time) BuilderOption {
return func(b *WorkspaceBuildBuilder) {
b.jobCompletedAt = t
}
}
// WithJobError sets the error message for the provisioner job.
func WithJobError(msg string) BuilderOption {
return func(b *WorkspaceBuildBuilder) {
b.jobError = msg
}
}
// WithJobErrorCode sets the error code for the provisioner job.
func WithJobErrorCode(code string) BuilderOption {
return func(b *WorkspaceBuildBuilder) {
b.jobErrorCode = code
}
}
// WorkspaceBuild generates a workspace build for the provided workspace.
@@ -196,59 +141,18 @@ func (b WorkspaceBuildBuilder) WithTask(taskSeed database.TaskTable, appSeed *sd
})
}
// Starting sets the job to running status.
func (b WorkspaceBuildBuilder) Starting(opts ...BuilderOption) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
func (b WorkspaceBuildBuilder) Starting() WorkspaceBuildBuilder {
b.jobStatus = database.ProvisionerJobStatusRunning
for _, opt := range opts {
opt(&b)
}
return b
}
// Pending sets the job to pending status.
func (b WorkspaceBuildBuilder) Pending(opts ...BuilderOption) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
func (b WorkspaceBuildBuilder) Pending() WorkspaceBuildBuilder {
b.jobStatus = database.ProvisionerJobStatusPending
for _, opt := range opts {
opt(&b)
}
return b
}
// Canceled sets the job to canceled status.
func (b WorkspaceBuildBuilder) Canceled(opts ...BuilderOption) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
func (b WorkspaceBuildBuilder) Canceled() WorkspaceBuildBuilder {
b.jobStatus = database.ProvisionerJobStatusCanceled
for _, opt := range opts {
opt(&b)
}
return b
}
// Succeeded sets the job to succeeded status.
// This is the default status.
func (b WorkspaceBuildBuilder) Succeeded(opts ...BuilderOption) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.jobStatus = database.ProvisionerJobStatusSucceeded
for _, opt := range opts {
opt(&b)
}
return b
}
// Failed sets the provisioner job to a failed state. Use WithJobError and
// WithJobErrorCode options to set the error message and code. If no error
// message is provided, "failed" is used as the default.
func (b WorkspaceBuildBuilder) Failed(opts ...BuilderOption) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.jobStatus = database.ProvisionerJobStatusFailed
for _, opt := range opts {
opt(&b)
}
if b.jobError == "" {
b.jobError = "failed"
}
return b
}
@@ -363,8 +267,8 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
job, err := b.db.InsertProvisionerJob(ownerCtx, database.InsertProvisionerJobParams{
ID: jobID,
CreatedAt: takeFirstTime(b.jobCreatedAt, b.ws.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirstTime(b.jobCreatedAt, b.ws.CreatedAt, dbtime.Now()),
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
OrganizationID: b.ws.OrganizationID,
InitiatorID: b.ws.OwnerID,
Provisioner: database.ProvisionerTypeEcho,
@@ -387,12 +291,11 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
// might need to do this multiple times if we got a template version
// import job as well
b.logger.Debug(context.Background(), "looping to acquire provisioner job")
startedAt := takeFirstTime(b.jobStartedAt, dbtime.Now())
for {
j, err := b.db.AcquireProvisionerJob(ownerCtx, database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: startedAt,
Time: dbtime.Now(),
Valid: true,
},
WorkerID: uuid.NullUUID{
@@ -408,54 +311,32 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
break
}
}
if !b.jobUpdatedAt.IsZero() {
err = b.db.UpdateProvisionerJobByID(ownerCtx, database.UpdateProvisionerJobByIDParams{
ID: job.ID,
UpdatedAt: b.jobUpdatedAt,
})
require.NoError(b.t, err, "update job updated_at")
}
case database.ProvisionerJobStatusCanceled:
// Set provisioner job status to 'canceled'
b.logger.Debug(context.Background(), "canceling the provisioner job")
completedAt := takeFirstTime(b.jobCompletedAt, dbtime.Now())
now := dbtime.Now()
err = b.db.UpdateProvisionerJobWithCancelByID(ownerCtx, database.UpdateProvisionerJobWithCancelByIDParams{
ID: jobID,
CanceledAt: sql.NullTime{
Time: completedAt,
Time: now,
Valid: true,
},
CompletedAt: sql.NullTime{
Time: completedAt,
Time: now,
Valid: true,
},
})
require.NoError(b.t, err, "cancel job")
case database.ProvisionerJobStatusFailed:
b.logger.Debug(context.Background(), "failing the provisioner job")
completedAt := takeFirstTime(b.jobCompletedAt, dbtime.Now())
err = b.db.UpdateProvisionerJobWithCompleteByID(ownerCtx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: job.ID,
UpdatedAt: completedAt,
Error: sql.NullString{String: b.jobError, Valid: b.jobError != ""},
ErrorCode: sql.NullString{String: b.jobErrorCode, Valid: b.jobErrorCode != ""},
CompletedAt: sql.NullTime{
Time: completedAt,
Valid: true,
},
})
require.NoError(b.t, err, "fail job")
default:
// By default, consider jobs in 'succeeded' status
b.logger.Debug(context.Background(), "completing the provisioner job")
completedAt := takeFirstTime(b.jobCompletedAt, dbtime.Now())
err = b.db.UpdateProvisionerJobWithCompleteByID(ownerCtx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: job.ID,
UpdatedAt: completedAt,
UpdatedAt: dbtime.Now(),
Error: sql.NullString{},
ErrorCode: sql.NullString{},
CompletedAt: sql.NullTime{
Time: completedAt,
Time: dbtime.Now(),
Valid: true,
},
})
@@ -870,16 +751,6 @@ func takeFirst[Value comparable](values ...Value) Value {
})
}
// takeFirstTime returns the first non-zero time.Time.
func takeFirstTime(values ...time.Time) time.Time {
for _, v := range values {
if !v.IsZero() {
return v
}
}
return time.Time{}
}
// mustWorkspaceAppByWorkspaceAndBuildAndAppID finds a workspace app by
// workspace ID, build number, and app ID. It returns the workspace app
// if found, otherwise fails the test.
+8 -29
View File
@@ -86,27 +86,18 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
WorkspaceID: takeFirst(seed.WorkspaceID, uuid.New()),
WorkspaceName: takeFirst(seed.WorkspaceName, testutil.GetRandomName(t)),
AgentName: takeFirst(seed.AgentName, testutil.GetRandomName(t)),
AgentID: uuid.NullUUID{
UUID: takeFirst(seed.AgentID.UUID, uuid.Nil),
Valid: takeFirst(seed.AgentID.Valid, false),
},
Type: takeFirst(seed.Type, database.ConnectionTypeSsh),
Type: takeFirst(seed.Type, database.ConnectionTypeSsh),
Code: sql.NullInt32{
Int32: takeFirst(seed.Code.Int32, 0),
Valid: takeFirst(seed.Code.Valid, false),
},
Ip: func() pqtype.Inet {
if seed.Ip.Valid {
return seed.Ip
}
return pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
}(),
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
UserAgent: sql.NullString{
String: takeFirst(seed.UserAgent.String, ""),
Valid: takeFirst(seed.UserAgent.Valid, false),
@@ -127,18 +118,6 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
String: takeFirst(seed.DisconnectReason.String, ""),
Valid: takeFirst(seed.DisconnectReason.Valid, false),
},
SessionID: uuid.NullUUID{
UUID: takeFirst(seed.SessionID.UUID, uuid.Nil),
Valid: takeFirst(seed.SessionID.Valid, false),
},
ClientHostname: sql.NullString{
String: takeFirst(seed.ClientHostname.String, ""),
Valid: takeFirst(seed.ClientHostname.Valid, false),
},
ShortDescription: sql.NullString{
String: takeFirst(seed.ShortDescription.String, ""),
Valid: takeFirst(seed.ShortDescription.Valid, false),
},
ConnectionStatus: takeFirst(seed.ConnectionStatus, database.ConnectionStatusConnected),
})
require.NoError(t, err, "insert connection log")
+22 -126
View File
@@ -231,22 +231,6 @@ func (m queryMetricsStore) CleanTailnetTunnels(ctx context.Context) error {
return r0
}
func (m queryMetricsStore) CloseConnectionLogsAndCreateSessions(ctx context.Context, arg database.CloseConnectionLogsAndCreateSessionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CloseConnectionLogsAndCreateSessions(ctx, arg)
m.queryLatencies.WithLabelValues("CloseConnectionLogsAndCreateSessions").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CloseConnectionLogsAndCreateSessions").Inc()
return r0, r1
}
func (m queryMetricsStore) CloseOpenAgentConnectionLogsForWorkspace(ctx context.Context, arg database.CloseOpenAgentConnectionLogsForWorkspaceParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CloseOpenAgentConnectionLogsForWorkspace(ctx, arg)
m.queryLatencies.WithLabelValues("CloseOpenAgentConnectionLogsForWorkspace").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CloseOpenAgentConnectionLogsForWorkspace").Inc()
return r0, r1
}
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
@@ -271,14 +255,6 @@ func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) CountGlobalWorkspaceSessions(ctx context.Context, arg database.CountGlobalWorkspaceSessionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountGlobalWorkspaceSessions(ctx, arg)
m.queryLatencies.WithLabelValues("CountGlobalWorkspaceSessions").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountGlobalWorkspaceSessions").Inc()
return r0, r1
}
func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
start := time.Now()
r0, r1 := m.s.CountInProgressPrebuilds(ctx)
@@ -303,14 +279,6 @@ func (m queryMetricsStore) CountUnreadInboxNotificationsByUserID(ctx context.Con
return r0, r1
}
func (m queryMetricsStore) CountWorkspaceSessions(ctx context.Context, arg database.CountWorkspaceSessionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountWorkspaceSessions(ctx, arg)
m.queryLatencies.WithLabelValues("CountWorkspaceSessions").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountWorkspaceSessions").Inc()
return r0, r1
}
func (m queryMetricsStore) CreateUserSecret(ctx context.Context, arg database.CreateUserSecretParams) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.CreateUserSecret(ctx, arg)
@@ -367,6 +335,14 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
return r0
}
func (m queryMetricsStore) DeleteBoundaryUsageStatsByReplicaID(ctx context.Context, replicaID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteBoundaryUsageStatsByReplicaID(ctx, replicaID)
m.queryLatencies.WithLabelValues("DeleteBoundaryUsageStatsByReplicaID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteBoundaryUsageStatsByReplicaID").Inc()
return r0
}
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
@@ -599,7 +575,7 @@ func (m queryMetricsStore) DeleteTailnetTunnel(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (uuid.UUID, error) {
func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
start := time.Now()
r0, r1 := m.s.DeleteTask(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteTask").Observe(time.Since(start).Seconds())
@@ -750,14 +726,6 @@ func (m queryMetricsStore) FindMatchingPresetID(ctx context.Context, arg databas
return r0, r1
}
func (m queryMetricsStore) FindOrCreateSessionForDisconnect(ctx context.Context, arg database.FindOrCreateSessionForDisconnectParams) (interface{}, error) {
start := time.Now()
r0, r1 := m.s.FindOrCreateSessionForDisconnect(ctx, arg)
m.queryLatencies.WithLabelValues("FindOrCreateSessionForDisconnect").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "FindOrCreateSessionForDisconnect").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID) (database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.GetAIBridgeInterceptionByID(ctx, id)
@@ -870,14 +838,6 @@ func (m queryMetricsStore) GetAllTailnetCoordinators(ctx context.Context) ([]dat
return r0, r1
}
func (m queryMetricsStore) GetAllTailnetPeeringEventsByPeerID(ctx context.Context, srcPeerID uuid.NullUUID) ([]database.TailnetPeeringEvent, error) {
start := time.Now()
r0, r1 := m.s.GetAllTailnetPeeringEventsByPeerID(ctx, srcPeerID)
m.queryLatencies.WithLabelValues("GetAllTailnetPeeringEventsByPeerID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAllTailnetPeeringEventsByPeerID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAllTailnetPeers(ctx context.Context) ([]database.TailnetPeer, error) {
start := time.Now()
r0, r1 := m.s.GetAllTailnetPeers(ctx)
@@ -894,14 +854,6 @@ func (m queryMetricsStore) GetAllTailnetTunnels(ctx context.Context) ([]database
return r0, r1
}
func (m queryMetricsStore) GetAndResetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (database.GetAndResetBoundaryUsageSummaryRow, error) {
start := time.Now()
r0, r1 := m.s.GetAndResetBoundaryUsageSummary(ctx, maxStalenessMs)
m.queryLatencies.WithLabelValues("GetAndResetBoundaryUsageSummary").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAndResetBoundaryUsageSummary").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAnnouncementBanners(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetAnnouncementBanners(ctx)
@@ -950,19 +902,11 @@ func (m queryMetricsStore) GetAuthorizationUserRoles(ctx context.Context, userID
return r0, r1
}
func (m queryMetricsStore) GetConnectionLogByConnectionID(ctx context.Context, arg database.GetConnectionLogByConnectionIDParams) (database.ConnectionLog, error) {
func (m queryMetricsStore) GetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (database.GetBoundaryUsageSummaryRow, error) {
start := time.Now()
r0, r1 := m.s.GetConnectionLogByConnectionID(ctx, arg)
m.queryLatencies.WithLabelValues("GetConnectionLogByConnectionID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetConnectionLogByConnectionID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetConnectionLogsBySessionIDs(ctx context.Context, sessionIds []uuid.UUID) ([]database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.GetConnectionLogsBySessionIDs(ctx, sessionIds)
m.queryLatencies.WithLabelValues("GetConnectionLogsBySessionIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetConnectionLogsBySessionIDs").Inc()
r0, r1 := m.s.GetBoundaryUsageSummary(ctx, maxStalenessMs)
m.queryLatencies.WithLabelValues("GetBoundaryUsageSummary").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetBoundaryUsageSummary").Inc()
return r0, r1
}
@@ -1158,14 +1102,6 @@ func (m queryMetricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (
return r0, r1
}
func (m queryMetricsStore) GetGlobalWorkspaceSessionsOffset(ctx context.Context, arg database.GetGlobalWorkspaceSessionsOffsetParams) ([]database.GetGlobalWorkspaceSessionsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetGlobalWorkspaceSessionsOffset(ctx, arg)
m.queryLatencies.WithLabelValues("GetGlobalWorkspaceSessionsOffset").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetGlobalWorkspaceSessionsOffset").Inc()
return r0, r1
}
func (m queryMetricsStore) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) {
start := time.Now()
r0, r1 := m.s.GetGroupByID(ctx, id)
@@ -1462,14 +1398,6 @@ func (m queryMetricsStore) GetOAuthSigningKey(ctx context.Context) (string, erro
return r0, r1
}
func (m queryMetricsStore) GetOngoingAgentConnectionsLast24h(ctx context.Context, arg database.GetOngoingAgentConnectionsLast24hParams) ([]database.GetOngoingAgentConnectionsLast24hRow, error) {
start := time.Now()
r0, r1 := m.s.GetOngoingAgentConnectionsLast24h(ctx, arg)
m.queryLatencies.WithLabelValues("GetOngoingAgentConnectionsLast24h").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetOngoingAgentConnectionsLast24h").Inc()
return r0, r1
}
func (m queryMetricsStore) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) {
start := time.Now()
r0, r1 := m.s.GetOrganizationByID(ctx, id)
@@ -1814,14 +1742,6 @@ func (m queryMetricsStore) GetTailnetTunnelPeerBindings(ctx context.Context, src
return r0, r1
}
func (m queryMetricsStore) GetTailnetTunnelPeerBindingsByDstID(ctx context.Context, dstID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsByDstIDRow, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetTunnelPeerBindingsByDstID(ctx, dstID)
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerBindingsByDstID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerBindingsByDstID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetTunnelPeerIDs(ctx, srcID)
@@ -2494,14 +2414,6 @@ func (m queryMetricsStore) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx cont
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildMetricsByResourceID(ctx context.Context, id uuid.UUID) (database.GetWorkspaceBuildMetricsByResourceIDRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildMetricsByResourceID(ctx, id)
m.queryLatencies.WithLabelValues("GetWorkspaceBuildMetricsByResourceID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceBuildMetricsByResourceID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildParameters(ctx, workspaceBuildID)
@@ -2678,14 +2590,6 @@ func (m queryMetricsStore) GetWorkspaceResourcesCreatedAfter(ctx context.Context
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceSessionsOffset(ctx context.Context, arg database.GetWorkspaceSessionsOffsetParams) ([]database.GetWorkspaceSessionsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceSessionsOffset(ctx, arg)
m.queryLatencies.WithLabelValues("GetWorkspaceSessionsOffset").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceSessionsOffset").Inc()
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx context.Context, templateIds []uuid.UUID) ([]database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx, templateIds)
@@ -3014,14 +2918,6 @@ func (m queryMetricsStore) InsertReplica(ctx context.Context, arg database.Inser
return r0, r1
}
func (m queryMetricsStore) InsertTailnetPeeringEvent(ctx context.Context, arg database.InsertTailnetPeeringEventParams) error {
start := time.Now()
r0 := m.s.InsertTailnetPeeringEvent(ctx, arg)
m.queryLatencies.WithLabelValues("InsertTailnetPeeringEvent").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertTailnetPeeringEvent").Inc()
return r0
}
func (m queryMetricsStore) InsertTask(ctx context.Context, arg database.InsertTaskParams) (database.TaskTable, error) {
start := time.Now()
r0, r1 := m.s.InsertTask(ctx, arg)
@@ -3438,6 +3334,14 @@ func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg databas
return r0, r1
}
func (m queryMetricsStore) ResetBoundaryUsageStats(ctx context.Context) error {
start := time.Now()
r0 := m.s.ResetBoundaryUsageStats(ctx)
m.queryLatencies.WithLabelValues("ResetBoundaryUsageStats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ResetBoundaryUsageStats").Inc()
return r0
}
func (m queryMetricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
start := time.Now()
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
@@ -3494,14 +3398,6 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
return r0
}
func (m queryMetricsStore) UpdateConnectionLogSessionID(ctx context.Context, arg database.UpdateConnectionLogSessionIDParams) error {
start := time.Now()
r0 := m.s.UpdateConnectionLogSessionID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateConnectionLogSessionID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateConnectionLogSessionID").Inc()
return r0
}
func (m queryMetricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.UpdateCryptoKeyDeletesAt(ctx, arg)
@@ -4397,7 +4293,7 @@ func (m queryMetricsStore) UpsertWorkspaceApp(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (database.UpsertWorkspaceAppAuditSessionRow, error) {
func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) {
start := time.Now()
r0, r1 := m.s.UpsertWorkspaceAppAuditSession(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertWorkspaceAppAuditSession").Observe(time.Since(start).Seconds())
+39 -234
View File
@@ -276,36 +276,6 @@ func (mr *MockStoreMockRecorder) CleanTailnetTunnels(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanTailnetTunnels", reflect.TypeOf((*MockStore)(nil).CleanTailnetTunnels), ctx)
}
// CloseConnectionLogsAndCreateSessions mocks base method.
func (m *MockStore) CloseConnectionLogsAndCreateSessions(ctx context.Context, arg database.CloseConnectionLogsAndCreateSessionsParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CloseConnectionLogsAndCreateSessions", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CloseConnectionLogsAndCreateSessions indicates an expected call of CloseConnectionLogsAndCreateSessions.
func (mr *MockStoreMockRecorder) CloseConnectionLogsAndCreateSessions(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseConnectionLogsAndCreateSessions", reflect.TypeOf((*MockStore)(nil).CloseConnectionLogsAndCreateSessions), ctx, arg)
}
// CloseOpenAgentConnectionLogsForWorkspace mocks base method.
func (m *MockStore) CloseOpenAgentConnectionLogsForWorkspace(ctx context.Context, arg database.CloseOpenAgentConnectionLogsForWorkspaceParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CloseOpenAgentConnectionLogsForWorkspace", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CloseOpenAgentConnectionLogsForWorkspace indicates an expected call of CloseOpenAgentConnectionLogsForWorkspace.
func (mr *MockStoreMockRecorder) CloseOpenAgentConnectionLogsForWorkspace(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseOpenAgentConnectionLogsForWorkspace", reflect.TypeOf((*MockStore)(nil).CloseOpenAgentConnectionLogsForWorkspace), ctx, arg)
}
// CountAIBridgeInterceptions mocks base method.
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
m.ctrl.T.Helper()
@@ -396,21 +366,6 @@ func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg)
}
// CountGlobalWorkspaceSessions mocks base method.
func (m *MockStore) CountGlobalWorkspaceSessions(ctx context.Context, arg database.CountGlobalWorkspaceSessionsParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountGlobalWorkspaceSessions", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountGlobalWorkspaceSessions indicates an expected call of CountGlobalWorkspaceSessions.
func (mr *MockStoreMockRecorder) CountGlobalWorkspaceSessions(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountGlobalWorkspaceSessions", reflect.TypeOf((*MockStore)(nil).CountGlobalWorkspaceSessions), ctx, arg)
}
// CountInProgressPrebuilds mocks base method.
func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
m.ctrl.T.Helper()
@@ -456,21 +411,6 @@ func (mr *MockStoreMockRecorder) CountUnreadInboxNotificationsByUserID(ctx, user
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountUnreadInboxNotificationsByUserID", reflect.TypeOf((*MockStore)(nil).CountUnreadInboxNotificationsByUserID), ctx, userID)
}
// CountWorkspaceSessions mocks base method.
func (m *MockStore) CountWorkspaceSessions(ctx context.Context, arg database.CountWorkspaceSessionsParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountWorkspaceSessions", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountWorkspaceSessions indicates an expected call of CountWorkspaceSessions.
func (mr *MockStoreMockRecorder) CountWorkspaceSessions(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountWorkspaceSessions", reflect.TypeOf((*MockStore)(nil).CountWorkspaceSessions), ctx, arg)
}
// CreateUserSecret mocks base method.
func (m *MockStore) CreateUserSecret(ctx context.Context, arg database.CreateUserSecretParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
@@ -571,6 +511,20 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
}
// DeleteBoundaryUsageStatsByReplicaID mocks base method.
func (m *MockStore) DeleteBoundaryUsageStatsByReplicaID(ctx context.Context, replicaID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteBoundaryUsageStatsByReplicaID", ctx, replicaID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteBoundaryUsageStatsByReplicaID indicates an expected call of DeleteBoundaryUsageStatsByReplicaID.
func (mr *MockStoreMockRecorder) DeleteBoundaryUsageStatsByReplicaID(ctx, replicaID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteBoundaryUsageStatsByReplicaID", reflect.TypeOf((*MockStore)(nil).DeleteBoundaryUsageStatsByReplicaID), ctx, replicaID)
}
// DeleteCryptoKey mocks base method.
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -987,10 +941,10 @@ func (mr *MockStoreMockRecorder) DeleteTailnetTunnel(ctx, arg any) *gomock.Call
}
// DeleteTask mocks base method.
func (m *MockStore) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (uuid.UUID, error) {
func (m *MockStore) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteTask", ctx, arg)
ret0, _ := ret[0].(uuid.UUID)
ret0, _ := ret[0].(database.TaskTable)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1259,21 +1213,6 @@ func (mr *MockStoreMockRecorder) FindMatchingPresetID(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindMatchingPresetID", reflect.TypeOf((*MockStore)(nil).FindMatchingPresetID), ctx, arg)
}
// FindOrCreateSessionForDisconnect mocks base method.
func (m *MockStore) FindOrCreateSessionForDisconnect(ctx context.Context, arg database.FindOrCreateSessionForDisconnectParams) (any, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindOrCreateSessionForDisconnect", ctx, arg)
ret0, _ := ret[0].(any)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindOrCreateSessionForDisconnect indicates an expected call of FindOrCreateSessionForDisconnect.
func (mr *MockStoreMockRecorder) FindOrCreateSessionForDisconnect(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOrCreateSessionForDisconnect", reflect.TypeOf((*MockStore)(nil).FindOrCreateSessionForDisconnect), ctx, arg)
}
// GetAIBridgeInterceptionByID mocks base method.
func (m *MockStore) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID) (database.AIBridgeInterception, error) {
m.ctrl.T.Helper()
@@ -1484,21 +1423,6 @@ func (mr *MockStoreMockRecorder) GetAllTailnetCoordinators(ctx any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllTailnetCoordinators", reflect.TypeOf((*MockStore)(nil).GetAllTailnetCoordinators), ctx)
}
// GetAllTailnetPeeringEventsByPeerID mocks base method.
func (m *MockStore) GetAllTailnetPeeringEventsByPeerID(ctx context.Context, srcPeerID uuid.NullUUID) ([]database.TailnetPeeringEvent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllTailnetPeeringEventsByPeerID", ctx, srcPeerID)
ret0, _ := ret[0].([]database.TailnetPeeringEvent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllTailnetPeeringEventsByPeerID indicates an expected call of GetAllTailnetPeeringEventsByPeerID.
func (mr *MockStoreMockRecorder) GetAllTailnetPeeringEventsByPeerID(ctx, srcPeerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllTailnetPeeringEventsByPeerID", reflect.TypeOf((*MockStore)(nil).GetAllTailnetPeeringEventsByPeerID), ctx, srcPeerID)
}
// GetAllTailnetPeers mocks base method.
func (m *MockStore) GetAllTailnetPeers(ctx context.Context) ([]database.TailnetPeer, error) {
m.ctrl.T.Helper()
@@ -1529,21 +1453,6 @@ func (mr *MockStoreMockRecorder) GetAllTailnetTunnels(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllTailnetTunnels", reflect.TypeOf((*MockStore)(nil).GetAllTailnetTunnels), ctx)
}
// GetAndResetBoundaryUsageSummary mocks base method.
func (m *MockStore) GetAndResetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (database.GetAndResetBoundaryUsageSummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAndResetBoundaryUsageSummary", ctx, maxStalenessMs)
ret0, _ := ret[0].(database.GetAndResetBoundaryUsageSummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAndResetBoundaryUsageSummary indicates an expected call of GetAndResetBoundaryUsageSummary.
func (mr *MockStoreMockRecorder) GetAndResetBoundaryUsageSummary(ctx, maxStalenessMs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAndResetBoundaryUsageSummary", reflect.TypeOf((*MockStore)(nil).GetAndResetBoundaryUsageSummary), ctx, maxStalenessMs)
}
// GetAnnouncementBanners mocks base method.
func (m *MockStore) GetAnnouncementBanners(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
@@ -1739,34 +1648,19 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspacesAndAgentsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspacesAndAgentsByOwnerID), ctx, ownerID, prepared)
}
// GetConnectionLogByConnectionID mocks base method.
func (m *MockStore) GetConnectionLogByConnectionID(ctx context.Context, arg database.GetConnectionLogByConnectionIDParams) (database.ConnectionLog, error) {
// GetBoundaryUsageSummary mocks base method.
func (m *MockStore) GetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (database.GetBoundaryUsageSummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetConnectionLogByConnectionID", ctx, arg)
ret0, _ := ret[0].(database.ConnectionLog)
ret := m.ctrl.Call(m, "GetBoundaryUsageSummary", ctx, maxStalenessMs)
ret0, _ := ret[0].(database.GetBoundaryUsageSummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetConnectionLogByConnectionID indicates an expected call of GetConnectionLogByConnectionID.
func (mr *MockStoreMockRecorder) GetConnectionLogByConnectionID(ctx, arg any) *gomock.Call {
// GetBoundaryUsageSummary indicates an expected call of GetBoundaryUsageSummary.
func (mr *MockStoreMockRecorder) GetBoundaryUsageSummary(ctx, maxStalenessMs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConnectionLogByConnectionID", reflect.TypeOf((*MockStore)(nil).GetConnectionLogByConnectionID), ctx, arg)
}
// GetConnectionLogsBySessionIDs mocks base method.
func (m *MockStore) GetConnectionLogsBySessionIDs(ctx context.Context, sessionIds []uuid.UUID) ([]database.ConnectionLog, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetConnectionLogsBySessionIDs", ctx, sessionIds)
ret0, _ := ret[0].([]database.ConnectionLog)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetConnectionLogsBySessionIDs indicates an expected call of GetConnectionLogsBySessionIDs.
func (mr *MockStoreMockRecorder) GetConnectionLogsBySessionIDs(ctx, sessionIds any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConnectionLogsBySessionIDs", reflect.TypeOf((*MockStore)(nil).GetConnectionLogsBySessionIDs), ctx, sessionIds)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBoundaryUsageSummary", reflect.TypeOf((*MockStore)(nil).GetBoundaryUsageSummary), ctx, maxStalenessMs)
}
// GetConnectionLogsOffset mocks base method.
@@ -2129,21 +2023,6 @@ func (mr *MockStoreMockRecorder) GetGitSSHKey(ctx, userID any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitSSHKey", reflect.TypeOf((*MockStore)(nil).GetGitSSHKey), ctx, userID)
}
// GetGlobalWorkspaceSessionsOffset mocks base method.
func (m *MockStore) GetGlobalWorkspaceSessionsOffset(ctx context.Context, arg database.GetGlobalWorkspaceSessionsOffsetParams) ([]database.GetGlobalWorkspaceSessionsOffsetRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGlobalWorkspaceSessionsOffset", ctx, arg)
ret0, _ := ret[0].([]database.GetGlobalWorkspaceSessionsOffsetRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGlobalWorkspaceSessionsOffset indicates an expected call of GetGlobalWorkspaceSessionsOffset.
func (mr *MockStoreMockRecorder) GetGlobalWorkspaceSessionsOffset(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalWorkspaceSessionsOffset", reflect.TypeOf((*MockStore)(nil).GetGlobalWorkspaceSessionsOffset), ctx, arg)
}
// GetGroupByID mocks base method.
func (m *MockStore) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) {
m.ctrl.T.Helper()
@@ -2699,21 +2578,6 @@ func (mr *MockStoreMockRecorder) GetOAuthSigningKey(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuthSigningKey", reflect.TypeOf((*MockStore)(nil).GetOAuthSigningKey), ctx)
}
// GetOngoingAgentConnectionsLast24h mocks base method.
func (m *MockStore) GetOngoingAgentConnectionsLast24h(ctx context.Context, arg database.GetOngoingAgentConnectionsLast24hParams) ([]database.GetOngoingAgentConnectionsLast24hRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOngoingAgentConnectionsLast24h", ctx, arg)
ret0, _ := ret[0].([]database.GetOngoingAgentConnectionsLast24hRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOngoingAgentConnectionsLast24h indicates an expected call of GetOngoingAgentConnectionsLast24h.
func (mr *MockStoreMockRecorder) GetOngoingAgentConnectionsLast24h(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOngoingAgentConnectionsLast24h", reflect.TypeOf((*MockStore)(nil).GetOngoingAgentConnectionsLast24h), ctx, arg)
}
// GetOrganizationByID mocks base method.
func (m *MockStore) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) {
m.ctrl.T.Helper()
@@ -3359,21 +3223,6 @@ func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindings(ctx, srcID any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindings", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindings), ctx, srcID)
}
// GetTailnetTunnelPeerBindingsByDstID mocks base method.
func (m *MockStore) GetTailnetTunnelPeerBindingsByDstID(ctx context.Context, dstID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsByDstIDRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerBindingsByDstID", ctx, dstID)
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerBindingsByDstIDRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTailnetTunnelPeerBindingsByDstID indicates an expected call of GetTailnetTunnelPeerBindingsByDstID.
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindingsByDstID(ctx, dstID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindingsByDstID", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindingsByDstID), ctx, dstID)
}
// GetTailnetTunnelPeerIDs mocks base method.
func (m *MockStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
m.ctrl.T.Helper()
@@ -4664,21 +4513,6 @@ func (mr *MockStoreMockRecorder) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ct
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildByWorkspaceIDAndBuildNumber", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildByWorkspaceIDAndBuildNumber), ctx, arg)
}
// GetWorkspaceBuildMetricsByResourceID mocks base method.
func (m *MockStore) GetWorkspaceBuildMetricsByResourceID(ctx context.Context, id uuid.UUID) (database.GetWorkspaceBuildMetricsByResourceIDRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceBuildMetricsByResourceID", ctx, id)
ret0, _ := ret[0].(database.GetWorkspaceBuildMetricsByResourceIDRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetWorkspaceBuildMetricsByResourceID indicates an expected call of GetWorkspaceBuildMetricsByResourceID.
func (mr *MockStoreMockRecorder) GetWorkspaceBuildMetricsByResourceID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildMetricsByResourceID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildMetricsByResourceID), ctx, id)
}
// GetWorkspaceBuildParameters mocks base method.
func (m *MockStore) GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) {
m.ctrl.T.Helper()
@@ -5009,21 +4843,6 @@ func (mr *MockStoreMockRecorder) GetWorkspaceResourcesCreatedAfter(ctx, createdA
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceResourcesCreatedAfter", reflect.TypeOf((*MockStore)(nil).GetWorkspaceResourcesCreatedAfter), ctx, createdAt)
}
// GetWorkspaceSessionsOffset mocks base method.
func (m *MockStore) GetWorkspaceSessionsOffset(ctx context.Context, arg database.GetWorkspaceSessionsOffsetParams) ([]database.GetWorkspaceSessionsOffsetRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceSessionsOffset", ctx, arg)
ret0, _ := ret[0].([]database.GetWorkspaceSessionsOffsetRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetWorkspaceSessionsOffset indicates an expected call of GetWorkspaceSessionsOffset.
func (mr *MockStoreMockRecorder) GetWorkspaceSessionsOffset(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceSessionsOffset", reflect.TypeOf((*MockStore)(nil).GetWorkspaceSessionsOffset), ctx, arg)
}
// GetWorkspaceUniqueOwnerCountByTemplateIDs mocks base method.
func (m *MockStore) GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx context.Context, templateIds []uuid.UUID) ([]database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow, error) {
m.ctrl.T.Helper()
@@ -5649,20 +5468,6 @@ func (mr *MockStoreMockRecorder) InsertReplica(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertReplica", reflect.TypeOf((*MockStore)(nil).InsertReplica), ctx, arg)
}
// InsertTailnetPeeringEvent mocks base method.
func (m *MockStore) InsertTailnetPeeringEvent(ctx context.Context, arg database.InsertTailnetPeeringEventParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertTailnetPeeringEvent", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// InsertTailnetPeeringEvent indicates an expected call of InsertTailnetPeeringEvent.
func (mr *MockStoreMockRecorder) InsertTailnetPeeringEvent(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTailnetPeeringEvent", reflect.TypeOf((*MockStore)(nil).InsertTailnetPeeringEvent), ctx, arg)
}
// InsertTask mocks base method.
func (m *MockStore) InsertTask(ctx context.Context, arg database.InsertTaskParams) (database.TaskTable, error) {
m.ctrl.T.Helper()
@@ -6473,6 +6278,20 @@ func (mr *MockStoreMockRecorder) RemoveUserFromGroups(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), ctx, arg)
}
// ResetBoundaryUsageStats mocks base method.
func (m *MockStore) ResetBoundaryUsageStats(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResetBoundaryUsageStats", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// ResetBoundaryUsageStats indicates an expected call of ResetBoundaryUsageStats.
func (mr *MockStoreMockRecorder) ResetBoundaryUsageStats(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).ResetBoundaryUsageStats), ctx)
}
// RevokeDBCryptKey mocks base method.
func (m *MockStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
m.ctrl.T.Helper()
@@ -6574,20 +6393,6 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg)
}
// UpdateConnectionLogSessionID mocks base method.
func (m *MockStore) UpdateConnectionLogSessionID(ctx context.Context, arg database.UpdateConnectionLogSessionIDParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateConnectionLogSessionID", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateConnectionLogSessionID indicates an expected call of UpdateConnectionLogSessionID.
func (mr *MockStoreMockRecorder) UpdateConnectionLogSessionID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateConnectionLogSessionID", reflect.TypeOf((*MockStore)(nil).UpdateConnectionLogSessionID), ctx, arg)
}
// UpdateCryptoKeyDeletesAt mocks base method.
func (m *MockStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -8201,10 +8006,10 @@ func (mr *MockStoreMockRecorder) UpsertWorkspaceApp(ctx, arg any) *gomock.Call {
}
// UpsertWorkspaceAppAuditSession mocks base method.
func (m *MockStore) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (database.UpsertWorkspaceAppAuditSessionRow, error) {
func (m *MockStore) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertWorkspaceAppAuditSession", ctx, arg)
ret0, _ := ret[0].(database.UpsertWorkspaceAppAuditSessionRow)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
+4 -64
View File
@@ -271,8 +271,7 @@ CREATE TYPE connection_type AS ENUM (
'jetbrains',
'reconnecting_pty',
'workspace_app',
'port_forwarding',
'system'
'port_forwarding'
);
CREATE TYPE cors_behavior AS ENUM (
@@ -1016,11 +1015,6 @@ BEGIN
END;
$$;
CREATE TABLE agent_peering_ids (
agent_id uuid NOT NULL,
peering_id bytea NOT NULL
);
CREATE TABLE aibridge_interceptions (
id uuid NOT NULL,
initiator_id uuid NOT NULL,
@@ -1165,12 +1159,7 @@ CREATE TABLE connection_logs (
slug_or_port text,
connection_id uuid,
disconnect_time timestamp with time zone,
disconnect_reason text,
agent_id uuid,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
session_id uuid,
client_hostname text,
short_description text
disconnect_reason text
);
COMMENT ON COLUMN connection_logs.code IS 'Either the HTTP status code of the web request, or the exit code of an SSH connection. For non-web connections, this is Null until we receive a disconnect event for the same connection_id.';
@@ -1187,8 +1176,6 @@ COMMENT ON COLUMN connection_logs.disconnect_time IS 'The time the connection wa
COMMENT ON COLUMN connection_logs.disconnect_reason IS 'The reason the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.';
COMMENT ON COLUMN connection_logs.updated_at IS 'Last time this connection log was confirmed active. For agent connections, equals connect_time. For web connections, bumped while the session is active.';
CREATE TABLE crypto_keys (
feature crypto_key_feature NOT NULL,
sequence integer NOT NULL,
@@ -1784,15 +1771,6 @@ CREATE UNLOGGED TABLE tailnet_coordinators (
COMMENT ON TABLE tailnet_coordinators IS 'We keep this separate from replicas in case we need to break the coordinator out into its own service';
CREATE TABLE tailnet_peering_events (
peering_id bytea NOT NULL,
event_type text NOT NULL,
src_peer_id uuid,
dst_peer_id uuid,
node bytea,
occurred_at timestamp with time zone NOT NULL
);
CREATE UNLOGGED TABLE tailnet_peers (
id uuid NOT NULL,
coordinator_id uuid NOT NULL,
@@ -2312,8 +2290,7 @@ CREATE TABLE templates (
activity_bump bigint DEFAULT '3600000000000'::bigint NOT NULL,
max_port_sharing_level app_sharing_level DEFAULT 'owner'::app_sharing_level NOT NULL,
use_classic_parameter_flow boolean DEFAULT false NOT NULL,
cors_behavior cors_behavior DEFAULT 'simple'::cors_behavior NOT NULL,
disable_module_cache boolean DEFAULT false NOT NULL
cors_behavior cors_behavior DEFAULT 'simple'::cors_behavior NOT NULL
);
COMMENT ON COLUMN templates.default_ttl IS 'The default duration for autostop for workspaces created from this template.';
@@ -2367,7 +2344,6 @@ CREATE VIEW template_with_names AS
templates.max_port_sharing_level,
templates.use_classic_parameter_flow,
templates.cors_behavior,
templates.disable_module_cache,
COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS created_by_username,
COALESCE(visible_users.name, ''::text) AS created_by_name,
@@ -2626,8 +2602,7 @@ CREATE UNLOGGED TABLE workspace_app_audit_sessions (
status_code integer NOT NULL,
started_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
id uuid NOT NULL,
connection_id uuid
id uuid NOT NULL
);
COMMENT ON TABLE workspace_app_audit_sessions IS 'Audit sessions for workspace apps, the data in this table is ephemeral and is used to deduplicate audit log entries for workspace apps. While a session is active, the same data will not be logged again. This table does not store historical data.';
@@ -2921,18 +2896,6 @@ CREATE SEQUENCE workspace_resource_metadata_id_seq
ALTER SEQUENCE workspace_resource_metadata_id_seq OWNED BY workspace_resource_metadata.id;
CREATE TABLE workspace_sessions (
id uuid DEFAULT gen_random_uuid() NOT NULL,
workspace_id uuid NOT NULL,
agent_id uuid,
ip inet,
client_hostname text,
short_description text,
started_at timestamp with time zone NOT NULL,
ended_at timestamp with time zone NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL
);
CREATE VIEW workspaces_expanded AS
SELECT workspaces.id,
workspaces.created_at,
@@ -2990,9 +2953,6 @@ ALTER TABLE ONLY workspace_proxies ALTER COLUMN region_id SET DEFAULT nextval('w
ALTER TABLE ONLY workspace_resource_metadata ALTER COLUMN id SET DEFAULT nextval('workspace_resource_metadata_id_seq'::regclass);
ALTER TABLE ONLY agent_peering_ids
ADD CONSTRAINT agent_peering_ids_pkey PRIMARY KEY (agent_id, peering_id);
ALTER TABLE ONLY workspace_agent_stats
ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
@@ -3299,9 +3259,6 @@ ALTER TABLE ONLY workspace_resource_metadata
ALTER TABLE ONLY workspace_resources
ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
ALTER TABLE ONLY workspace_sessions
ADD CONSTRAINT workspace_sessions_pkey PRIMARY KEY (id);
ALTER TABLE ONLY workspaces
ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
@@ -3353,8 +3310,6 @@ COMMENT ON INDEX idx_connection_logs_connection_id_workspace_id_agent_name IS 'C
CREATE INDEX idx_connection_logs_organization_id ON connection_logs USING btree (organization_id);
CREATE INDEX idx_connection_logs_session ON connection_logs USING btree (session_id) WHERE (session_id IS NOT NULL);
CREATE INDEX idx_connection_logs_workspace_id ON connection_logs USING btree (workspace_id);
CREATE INDEX idx_connection_logs_workspace_owner_id ON connection_logs USING btree (workspace_owner_id);
@@ -3409,12 +3364,6 @@ CREATE INDEX idx_workspace_app_statuses_workspace_id_created_at ON workspace_app
CREATE INDEX idx_workspace_builds_initiator_id ON workspace_builds USING btree (initiator_id);
CREATE INDEX idx_workspace_sessions_hostname_lookup ON workspace_sessions USING btree (workspace_id, client_hostname, started_at) WHERE (client_hostname IS NOT NULL);
CREATE INDEX idx_workspace_sessions_ip_lookup ON workspace_sessions USING btree (workspace_id, ip, started_at) WHERE ((ip IS NOT NULL) AND (client_hostname IS NULL));
CREATE INDEX idx_workspace_sessions_workspace ON workspace_sessions USING btree (workspace_id, started_at DESC);
CREATE UNIQUE INDEX notification_messages_dedupe_hash_idx ON notification_messages USING btree (dedupe_hash);
CREATE UNIQUE INDEX organizations_single_default_org ON organizations USING btree (is_default) WHERE (is_default = true);
@@ -3602,9 +3551,6 @@ ALTER TABLE ONLY api_keys
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_session_id_fkey FOREIGN KEY (session_id) REFERENCES workspace_sessions(id) ON DELETE SET NULL;
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
@@ -3878,12 +3824,6 @@ ALTER TABLE ONLY workspace_resource_metadata
ALTER TABLE ONLY workspace_resources
ADD CONSTRAINT workspace_resources_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_sessions
ADD CONSTRAINT workspace_sessions_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
ALTER TABLE ONLY workspace_sessions
ADD CONSTRAINT workspace_sessions_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspaces
ADD CONSTRAINT workspaces_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE RESTRICT;
@@ -9,7 +9,6 @@ const (
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsOrganizationID ForeignKeyConstraint = "connection_logs_organization_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsSessionID ForeignKeyConstraint = "connection_logs_session_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_session_id_fkey FOREIGN KEY (session_id) REFERENCES workspace_sessions(id) ON DELETE SET NULL;
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
@@ -101,8 +100,6 @@ const (
ForeignKeyWorkspaceModulesJobID ForeignKeyConstraint = "workspace_modules_job_id_fkey" // ALTER TABLE ONLY workspace_modules ADD CONSTRAINT workspace_modules_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
ForeignKeyWorkspaceResourceMetadataWorkspaceResourceID ForeignKeyConstraint = "workspace_resource_metadata_workspace_resource_id_fkey" // ALTER TABLE ONLY workspace_resource_metadata ADD CONSTRAINT workspace_resource_metadata_workspace_resource_id_fkey FOREIGN KEY (workspace_resource_id) REFERENCES workspace_resources(id) ON DELETE CASCADE;
ForeignKeyWorkspaceResourcesJobID ForeignKeyConstraint = "workspace_resources_job_id_fkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
ForeignKeyWorkspaceSessionsAgentID ForeignKeyConstraint = "workspace_sessions_agent_id_fkey" // ALTER TABLE ONLY workspace_sessions ADD CONSTRAINT workspace_sessions_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
ForeignKeyWorkspaceSessionsWorkspaceID ForeignKeyConstraint = "workspace_sessions_workspace_id_fkey" // ALTER TABLE ONLY workspace_sessions ADD CONSTRAINT workspace_sessions_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ForeignKeyWorkspacesOrganizationID ForeignKeyConstraint = "workspaces_organization_id_fkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE RESTRICT;
ForeignKeyWorkspacesOwnerID ForeignKeyConstraint = "workspaces_owner_id_fkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE RESTRICT;
ForeignKeyWorkspacesTemplateID ForeignKeyConstraint = "workspaces_template_id_fkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE RESTRICT;
-1
View File
@@ -14,7 +14,6 @@ const (
LockIDCryptoKeyRotation
LockIDReconcilePrebuilds
LockIDReconcileSystemRoles
LockIDBoundaryUsageStats
)
// GenLockID generates a unique and consistent lock ID from a given string.
@@ -1,16 +0,0 @@
DROP VIEW template_with_names;
ALTER TABLE templates DROP COLUMN disable_module_cache;
CREATE VIEW template_with_names AS
SELECT templates.*,
COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS created_by_username,
COALESCE(visible_users.name, ''::text) AS created_by_name,
COALESCE(organizations.name, ''::text) AS organization_name,
COALESCE(organizations.display_name, ''::text) AS organization_display_name,
COALESCE(organizations.icon, ''::text) AS organization_icon
FROM ((templates
LEFT JOIN visible_users ON ((templates.created_by = visible_users.id)))
LEFT JOIN organizations ON ((templates.organization_id = organizations.id)));
COMMENT ON VIEW template_with_names IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -1,16 +0,0 @@
DROP VIEW template_with_names;
ALTER TABLE templates ADD COLUMN disable_module_cache BOOL NOT NULL DEFAULT false;
CREATE VIEW template_with_names AS
SELECT templates.*,
COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS created_by_username,
COALESCE(visible_users.name, ''::text) AS created_by_name,
COALESCE(organizations.name, ''::text) AS organization_name,
COALESCE(organizations.display_name, ''::text) AS organization_display_name,
COALESCE(organizations.icon, ''::text) AS organization_icon
FROM ((templates
LEFT JOIN visible_users ON ((templates.created_by = visible_users.id)))
LEFT JOIN organizations ON ((templates.organization_id = organizations.id)));
COMMENT ON VIEW template_with_names IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -1,2 +0,0 @@
ALTER TABLE connection_logs
DROP COLUMN agent_id;
@@ -1,2 +0,0 @@
ALTER TABLE connection_logs
ADD COLUMN agent_id uuid;
@@ -1,2 +0,0 @@
ALTER TABLE workspace_app_audit_sessions DROP COLUMN IF EXISTS connection_id;
ALTER TABLE connection_logs DROP COLUMN IF EXISTS updated_at;
@@ -1,14 +0,0 @@
ALTER TABLE workspace_app_audit_sessions
ADD COLUMN connection_id uuid;
ALTER TABLE connection_logs
ADD COLUMN updated_at timestamp with time zone;
UPDATE connection_logs SET updated_at = connect_time WHERE updated_at IS NULL;
ALTER TABLE connection_logs
ALTER COLUMN updated_at SET NOT NULL,
ALTER COLUMN updated_at SET DEFAULT now();
COMMENT ON COLUMN connection_logs.updated_at IS
'Last time this connection log was confirmed active. For agent connections, equals connect_time. For web connections, bumped while the session is active.';
@@ -1,2 +0,0 @@
DROP TABLE IF EXISTS tailnet_peering_events;
DROP TABLE IF EXISTS agent_peering_ids;
@@ -1,14 +0,0 @@
CREATE TABLE agent_peering_ids (
agent_id uuid NOT NULL,
peering_id bytea NOT NULL,
PRIMARY KEY (agent_id, peering_id)
);
CREATE TABLE tailnet_peering_events (
peering_id bytea NOT NULL,
event_type text NOT NULL,
src_peer_id uuid,
dst_peer_id uuid,
node bytea,
occurred_at timestamp with time zone NOT NULL
);
@@ -1,8 +0,0 @@
DROP INDEX IF EXISTS idx_connection_logs_session;
ALTER TABLE connection_logs
DROP COLUMN IF EXISTS short_description,
DROP COLUMN IF EXISTS client_hostname,
DROP COLUMN IF EXISTS session_id;
DROP TABLE IF EXISTS workspace_sessions;
@@ -1,21 +0,0 @@
CREATE TABLE workspace_sessions (
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
workspace_id uuid NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE,
agent_id uuid REFERENCES workspace_agents(id) ON DELETE SET NULL,
ip inet NOT NULL,
client_hostname text,
short_description text,
started_at timestamp with time zone NOT NULL,
ended_at timestamp with time zone NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now()
);
CREATE INDEX idx_workspace_sessions_workspace ON workspace_sessions (workspace_id, started_at DESC);
CREATE INDEX idx_workspace_sessions_lookup ON workspace_sessions (workspace_id, ip, started_at);
ALTER TABLE connection_logs
ADD COLUMN session_id uuid REFERENCES workspace_sessions(id) ON DELETE SET NULL,
ADD COLUMN client_hostname text,
ADD COLUMN short_description text;
CREATE INDEX idx_connection_logs_session ON connection_logs (session_id) WHERE session_id IS NOT NULL;
@@ -1 +0,0 @@
-- No-op: PostgreSQL does not support removing enum values.
@@ -1 +0,0 @@
ALTER TYPE connection_type ADD VALUE IF NOT EXISTS 'system';
@@ -1,6 +0,0 @@
UPDATE workspace_sessions SET ip = '0.0.0.0'::inet WHERE ip IS NULL;
ALTER TABLE workspace_sessions ALTER COLUMN ip SET NOT NULL;
DROP INDEX IF EXISTS idx_workspace_sessions_hostname_lookup;
DROP INDEX IF EXISTS idx_workspace_sessions_ip_lookup;
CREATE INDEX idx_workspace_sessions_lookup ON workspace_sessions (workspace_id, ip, started_at);
@@ -1,13 +0,0 @@
-- Make workspace_sessions.ip nullable since sessions now group by
-- hostname (with IP fallback), and a session may span multiple IPs.
ALTER TABLE workspace_sessions ALTER COLUMN ip DROP NOT NULL;
-- Replace the IP-based lookup index with hostname-based indexes
-- to support the new grouping logic.
DROP INDEX IF EXISTS idx_workspace_sessions_lookup;
CREATE INDEX idx_workspace_sessions_hostname_lookup
ON workspace_sessions (workspace_id, client_hostname, started_at)
WHERE client_hostname IS NOT NULL;
CREATE INDEX idx_workspace_sessions_ip_lookup
ON workspace_sessions (workspace_id, ip, started_at)
WHERE ip IS NOT NULL AND client_hostname IS NULL;
@@ -1,17 +0,0 @@
INSERT INTO agent_peering_ids
(agent_id, peering_id)
VALUES (
'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
'\xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef'
);
INSERT INTO tailnet_peering_events
(peering_id, event_type, src_peer_id, dst_peer_id, node, occurred_at)
VALUES (
'\xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef',
'added_tunnel',
'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
'a fake protobuf byte string',
'2025-01-15 10:23:54+00'
);
@@ -1,11 +0,0 @@
INSERT INTO workspace_sessions
(id, workspace_id, agent_id, ip, started_at, ended_at, created_at)
VALUES (
'a1b2c3d4-e5f6-7890-abcd-ef1234567890',
'3a9a1feb-e89d-457c-9d53-ac751b198ebe',
'5f8e48e4-1304-45bd-b91a-ab12c8bfc20f',
'127.0.0.1',
'2025-01-01 10:00:00+00',
'2025-01-01 11:00:00+00',
'2025-01-01 11:00:00+00'
);
-6
View File
@@ -127,7 +127,6 @@ func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplate
&i.MaxPortSharingLevel,
&i.UseClassicParameterFlow,
&i.CorsBehavior,
&i.DisableModuleCache,
&i.CreatedByAvatarURL,
&i.CreatedByUsername,
&i.CreatedByName,
@@ -689,11 +688,6 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
&i.ConnectionLog.ConnectionID,
&i.ConnectionLog.DisconnectTime,
&i.ConnectionLog.DisconnectReason,
&i.ConnectionLog.AgentID,
&i.ConnectionLog.UpdatedAt,
&i.ConnectionLog.SessionID,
&i.ConnectionLog.ClientHostname,
&i.ConnectionLog.ShortDescription,
&i.UserUsername,
&i.UserName,
&i.UserEmail,
+3 -41
View File
@@ -1101,7 +1101,6 @@ const (
ConnectionTypeReconnectingPty ConnectionType = "reconnecting_pty"
ConnectionTypeWorkspaceApp ConnectionType = "workspace_app"
ConnectionTypePortForwarding ConnectionType = "port_forwarding"
ConnectionTypeSystem ConnectionType = "system"
)
func (e *ConnectionType) Scan(src interface{}) error {
@@ -1146,8 +1145,7 @@ func (e ConnectionType) Valid() bool {
ConnectionTypeJetbrains,
ConnectionTypeReconnectingPty,
ConnectionTypeWorkspaceApp,
ConnectionTypePortForwarding,
ConnectionTypeSystem:
ConnectionTypePortForwarding:
return true
}
return false
@@ -1161,7 +1159,6 @@ func AllConnectionTypeValues() []ConnectionType {
ConnectionTypeReconnectingPty,
ConnectionTypeWorkspaceApp,
ConnectionTypePortForwarding,
ConnectionTypeSystem,
}
}
@@ -3705,11 +3702,6 @@ type APIKey struct {
AllowList AllowList `db:"allow_list" json:"allow_list"`
}
type AgentPeeringID struct {
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
PeeringID []byte `db:"peering_id" json:"peering_id"`
}
type AuditLog struct {
ID uuid.UUID `db:"id" json:"id"`
Time time.Time `db:"time" json:"time"`
@@ -3770,12 +3762,6 @@ type ConnectionLog struct {
DisconnectTime sql.NullTime `db:"disconnect_time" json:"disconnect_time"`
// The reason the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
// Last time this connection log was confirmed active. For agent connections, equals connect_time. For web connections, bumped while the session is active.
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
SessionID uuid.NullUUID `db:"session_id" json:"session_id"`
ClientHostname sql.NullString `db:"client_hostname" json:"client_hostname"`
ShortDescription sql.NullString `db:"short_description" json:"short_description"`
}
type CryptoKey struct {
@@ -4242,15 +4228,6 @@ type TailnetPeer struct {
Status TailnetStatus `db:"status" json:"status"`
}
type TailnetPeeringEvent struct {
PeeringID []byte `db:"peering_id" json:"peering_id"`
EventType string `db:"event_type" json:"event_type"`
SrcPeerID uuid.NullUUID `db:"src_peer_id" json:"src_peer_id"`
DstPeerID uuid.NullUUID `db:"dst_peer_id" json:"dst_peer_id"`
Node []byte `db:"node" json:"node"`
OccurredAt time.Time `db:"occurred_at" json:"occurred_at"`
}
type TailnetTunnel struct {
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
SrcID uuid.UUID `db:"src_id" json:"src_id"`
@@ -4361,7 +4338,6 @@ type Template struct {
MaxPortSharingLevel AppSharingLevel `db:"max_port_sharing_level" json:"max_port_sharing_level"`
UseClassicParameterFlow bool `db:"use_classic_parameter_flow" json:"use_classic_parameter_flow"`
CorsBehavior CorsBehavior `db:"cors_behavior" json:"cors_behavior"`
DisableModuleCache bool `db:"disable_module_cache" json:"disable_module_cache"`
CreatedByAvatarURL string `db:"created_by_avatar_url" json:"created_by_avatar_url"`
CreatedByUsername string `db:"created_by_username" json:"created_by_username"`
CreatedByName string `db:"created_by_name" json:"created_by_name"`
@@ -4411,7 +4387,6 @@ type TemplateTable struct {
// Determines whether to default to the dynamic parameter creation flow for this template or continue using the legacy classic parameter creation flow.This is a template wide setting, the template admin can revert to the classic flow if there are any issues. An escape hatch is required, as workspace creation is a core workflow and cannot break. This column will be removed when the dynamic parameter creation flow is stable.
UseClassicParameterFlow bool `db:"use_classic_parameter_flow" json:"use_classic_parameter_flow"`
CorsBehavior CorsBehavior `db:"cors_behavior" json:"cors_behavior"`
DisableModuleCache bool `db:"disable_module_cache" json:"disable_module_cache"`
}
// Records aggregated usage statistics for templates/users. All usage is rounded up to the nearest minute.
@@ -4956,9 +4931,8 @@ type WorkspaceAppAuditSession struct {
// The time the user started the session.
StartedAt time.Time `db:"started_at" json:"started_at"`
// The time the session was last updated.
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ID uuid.UUID `db:"id" json:"id"`
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ID uuid.UUID `db:"id" json:"id"`
}
// A record of workspace app usage statistics
@@ -5133,18 +5107,6 @@ type WorkspaceResourceMetadatum struct {
ID int64 `db:"id" json:"id"`
}
type WorkspaceSession struct {
ID uuid.UUID `db:"id" json:"id"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
Ip pqtype.Inet `db:"ip" json:"ip"`
ClientHostname sql.NullString `db:"client_hostname" json:"client_hostname"`
ShortDescription sql.NullString `db:"short_description" json:"short_description"`
StartedAt time.Time `db:"started_at" json:"started_at"`
EndedAt time.Time `db:"ended_at" json:"ended_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
type WorkspaceTable struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
+14 -57
View File
@@ -68,43 +68,15 @@ type sqlcQuerier interface {
CleanTailnetCoordinators(ctx context.Context) error
CleanTailnetLostPeers(ctx context.Context) error
CleanTailnetTunnels(ctx context.Context) error
// Atomically closes open connections and creates sessions grouped by
// client_hostname (with IP fallback) and time overlap. Non-system
// connections drive session boundaries; system connections attach to
// the first overlapping session or get their own if orphaned.
//
// Processes connections that are still open (disconnect_time IS NULL) OR
// already disconnected but not yet assigned to a session (session_id IS
// NULL). The latter covers system/tunnel connections whose disconnect is
// recorded by dbsink but which have no session-assignment code path.
// Phase 1: Group non-system connections by hostname+time overlap.
// System connections persist for the entire workspace lifetime and
// would create mega-sessions if included in boundary computation.
// Check for pre-existing sessions that match by hostname (or IP
// fallback) and overlap in time, to avoid duplicates from the race
// with FindOrCreateSessionForDisconnect.
// Combine existing and newly created sessions.
// Phase 2: Assign system connections to the earliest overlapping
// primary session. First check sessions from this batch, then fall
// back to pre-existing workspace_sessions.
// Also match system connections to pre-existing sessions (created
// by FindOrCreateSessionForDisconnect) that aren't in this batch.
// Create sessions for orphaned system connections (no overlapping
// primary session) that have an IP.
// Combine all session sources for the final UPDATE.
CloseConnectionLogsAndCreateSessions(ctx context.Context, arg CloseConnectionLogsAndCreateSessionsParams) (int64, error)
CloseOpenAgentConnectionLogsForWorkspace(ctx context.Context, arg CloseOpenAgentConnectionLogsForWorkspaceParams) (int64, error)
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
CountGlobalWorkspaceSessions(ctx context.Context, arg CountGlobalWorkspaceSessionsParams) (int64, error)
// CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition.
// Prebuild considered in-progress if it's in the "pending", "starting", "stopping", or "deleting" state.
CountInProgressPrebuilds(ctx context.Context) ([]CountInProgressPrebuildsRow, error)
// CountPendingNonActivePrebuilds returns the number of pending prebuilds for non-active template versions
CountPendingNonActivePrebuilds(ctx context.Context) ([]CountPendingNonActivePrebuildsRow, error)
CountUnreadInboxNotificationsByUserID(ctx context.Context, userID uuid.UUID) (int64, error)
CountWorkspaceSessions(ctx context.Context, arg CountWorkspaceSessionsParams) (int64, error)
CreateUserSecret(ctx context.Context, arg CreateUserSecretParams) (UserSecret, error)
CustomRoles(ctx context.Context, arg CustomRolesParams) ([]CustomRole, error)
DeleteAPIKeyByID(ctx context.Context, id string) error
@@ -116,6 +88,8 @@ type sqlcQuerier interface {
// be recreated.
DeleteAllWebpushSubscriptions(ctx context.Context) error
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
// Deletes boundary usage statistics for a specific replica.
DeleteBoundaryUsageStatsByReplicaID(ctx context.Context, replicaID uuid.UUID) error
DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error)
DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error
DeleteExpiredAPIKeys(ctx context.Context, arg DeleteExpiredAPIKeysParams) (int64, error)
@@ -158,7 +132,7 @@ type sqlcQuerier interface {
DeleteRuntimeConfig(ctx context.Context, key string) error
DeleteTailnetPeer(ctx context.Context, arg DeleteTailnetPeerParams) (DeleteTailnetPeerRow, error)
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteTask(ctx context.Context, arg DeleteTaskParams) (TaskTable, error)
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
@@ -189,11 +163,6 @@ type sqlcQuerier interface {
// The query finds presets where all preset parameters are present in the provided parameters,
// and returns the preset with the most parameters (largest subset).
FindMatchingPresetID(ctx context.Context, arg FindMatchingPresetIDParams) (uuid.UUID, error)
// Find existing session within time window, or create new one.
// Uses advisory lock to prevent duplicate sessions from concurrent disconnects.
// Groups by client_hostname (with IP fallback) to match the live session
// grouping in mergeWorkspaceConnectionsIntoSessions.
FindOrCreateSessionForDisconnect(ctx context.Context, arg FindOrCreateSessionForDisconnectParams) (interface{}, error)
GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID) (AIBridgeInterception, error)
GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeInterception, error)
GetAIBridgeTokenUsagesByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeTokenUsage, error)
@@ -210,14 +179,8 @@ type sqlcQuerier interface {
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
// For PG Coordinator HTMLDebug
GetAllTailnetCoordinators(ctx context.Context) ([]TailnetCoordinator, error)
GetAllTailnetPeeringEventsByPeerID(ctx context.Context, srcPeerID uuid.NullUUID) ([]TailnetPeeringEvent, error)
GetAllTailnetPeers(ctx context.Context) ([]TailnetPeer, error)
GetAllTailnetTunnels(ctx context.Context) ([]TailnetTunnel, error)
// Atomic read+delete prevents replicas that flush between a separate read and
// reset from having their data deleted before the next snapshot. Uses a common
// table expression with DELETE...RETURNING so the rows we sum are exactly the
// rows we delete. Stale rows are excluded from the sum but still deleted.
GetAndResetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (GetAndResetBoundaryUsageSummaryRow, error)
GetAnnouncementBanners(ctx context.Context) (string, error)
GetAppSecurityKey(ctx context.Context) (string, error)
GetApplicationName(ctx context.Context) (string, error)
@@ -233,8 +196,10 @@ type sqlcQuerier interface {
// This function returns roles for authorization purposes. Implied member roles
// are included.
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
GetConnectionLogByConnectionID(ctx context.Context, arg GetConnectionLogByConnectionIDParams) (ConnectionLog, error)
GetConnectionLogsBySessionIDs(ctx context.Context, sessionIds []uuid.UUID) ([]ConnectionLog, error)
// Aggregates boundary usage statistics across all replicas. Filters to only
// include data where window_start is within the given interval to exclude
// stale data.
GetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (GetBoundaryUsageSummaryRow, error)
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
@@ -267,7 +232,6 @@ type sqlcQuerier interface {
// param limit_opt: The limit of notifications to fetch. If the limit is not specified, it defaults to 25
GetFilteredInboxNotificationsByUserID(ctx context.Context, arg GetFilteredInboxNotificationsByUserIDParams) ([]InboxNotification, error)
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
GetGlobalWorkspaceSessionsOffset(ctx context.Context, arg GetGlobalWorkspaceSessionsOffsetParams) ([]GetGlobalWorkspaceSessionsOffsetRow, error)
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
GetGroupMembers(ctx context.Context, includeSystem bool) ([]GroupMember, error)
@@ -315,7 +279,6 @@ type sqlcQuerier interface {
GetOAuth2ProviderApps(ctx context.Context) ([]OAuth2ProviderApp, error)
GetOAuth2ProviderAppsByUserID(ctx context.Context, userID uuid.UUID) ([]GetOAuth2ProviderAppsByUserIDRow, error)
GetOAuthSigningKey(ctx context.Context) (string, error)
GetOngoingAgentConnectionsLast24h(ctx context.Context, arg GetOngoingAgentConnectionsLast24hParams) ([]GetOngoingAgentConnectionsLast24hRow, error)
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
GetOrganizationByName(ctx context.Context, arg GetOrganizationByNameParams) (Organization, error)
GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error)
@@ -392,7 +355,6 @@ type sqlcQuerier interface {
GetRuntimeConfig(ctx context.Context, key string) (string, error)
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error)
GetTailnetTunnelPeerBindingsByDstID(ctx context.Context, dstID uuid.UUID) ([]GetTailnetTunnelPeerBindingsByDstIDRow, error)
GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error)
GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error)
GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error)
@@ -540,9 +502,6 @@ type sqlcQuerier interface {
GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Context, arg GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (WorkspaceBuild, error)
// Returns build metadata for e2e workspace build duration metrics.
// Also checks if all agents are ready and returns the worst status.
GetWorkspaceBuildMetricsByResourceID(ctx context.Context, id uuid.UUID) (GetWorkspaceBuildMetricsByResourceIDRow, error)
GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]WorkspaceBuildParameter, error)
GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIds []uuid.UUID) ([]WorkspaceBuildParameter, error)
GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]GetWorkspaceBuildStatsByTemplatesRow, error)
@@ -572,7 +531,6 @@ type sqlcQuerier interface {
GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]WorkspaceResource, error)
GetWorkspaceResourcesByJobIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceResource, error)
GetWorkspaceResourcesCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceResource, error)
GetWorkspaceSessionsOffset(ctx context.Context, arg GetWorkspaceSessionsOffsetParams) ([]GetWorkspaceSessionsOffsetRow, error)
GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx context.Context, templateIds []uuid.UUID) ([]GetWorkspaceUniqueOwnerCountByTemplateIDsRow, error)
// build_params is used to filter by build parameters if present.
// It has to be a CTE because the set returning function 'unnest' cannot
@@ -624,7 +582,6 @@ type sqlcQuerier interface {
InsertProvisionerJobTimings(ctx context.Context, arg InsertProvisionerJobTimingsParams) ([]ProvisionerJobTiming, error)
InsertProvisionerKey(ctx context.Context, arg InsertProvisionerKeyParams) (ProvisionerKey, error)
InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error)
InsertTailnetPeeringEvent(ctx context.Context, arg InsertTailnetPeeringEventParams) error
InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error)
InsertTelemetryItemIfNotExists(ctx context.Context, arg InsertTelemetryItemIfNotExistsParams) error
// Inserts a new lock row into the telemetry_locks table. Replicas should call
@@ -695,6 +652,9 @@ type sqlcQuerier interface {
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
// Deletes all boundary usage statistics. Called after telemetry reports the
// aggregated stats. Each replica will insert a fresh row on its next flush.
ResetBoundaryUsageStats(ctx context.Context) error
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
// Note that this selects from the CTE, not the original table. The CTE is named
// the same as the original table to trick sqlc into reusing the existing struct
@@ -711,8 +671,6 @@ type sqlcQuerier interface {
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
// Links a connection log row to its workspace session.
UpdateConnectionLogSessionID(ctx context.Context, arg UpdateConnectionLogSessionIDParams) error
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
@@ -842,11 +800,10 @@ type sqlcQuerier interface {
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error)
//
// The returned columns, new_or_stale and connection_id, can be used to deduce
// if a new session was started and which connection_id to use. new_or_stale is
// true when a new row was inserted (no previous session) or the updated_at is
// older than the stale interval.
UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (UpsertWorkspaceAppAuditSessionRow, error)
// The returned boolean, new_or_stale, can be used to deduce if a new session
// was started. This means that a new row was inserted (no previous session) or
// the updated_at is older than stale interval.
UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (bool, error)
ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (ValidateGroupIDsRow, error)
ValidateUserIDs(ctx context.Context, userIds []uuid.UUID) (ValidateUserIDsRow, error)
}
+2 -8
View File
@@ -3081,7 +3081,7 @@ func TestConnectionLogsOffsetFilters(t *testing.T) {
params: database.GetConnectionLogsOffsetParams{
Status: string(codersdk.ConnectionLogStatusOngoing),
},
expectedLogIDs: []uuid.UUID{log1.ID, log4.ID},
expectedLogIDs: []uuid.UUID{log4.ID},
},
{
name: "StatusCompleted",
@@ -3308,16 +3308,12 @@ func TestUpsertConnectionLog(t *testing.T) {
origLog, err := db.UpsertConnectionLog(ctx, connectParams2)
require.NoError(t, err)
// updated_at is always bumped on conflict to track activity.
require.True(t, connectTime2.Equal(origLog.UpdatedAt), "expected updated_at %s, got %s", connectTime2, origLog.UpdatedAt)
origLog.UpdatedAt = log.UpdatedAt
require.Equal(t, log, origLog, "connect update should be a no-op except updated_at")
require.Equal(t, log, origLog, "connect update should be a no-op")
// Check that still only one row exists.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, rows, 1)
rows[0].ConnectionLog.UpdatedAt = log.UpdatedAt
require.Equal(t, log, rows[0].ConnectionLog)
})
@@ -3399,8 +3395,6 @@ func TestUpsertConnectionLog(t *testing.T) {
secondRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, secondRows, 1)
// updated_at is always bumped on conflict to track activity.
secondRows[0].ConnectionLog.UpdatedAt = firstRows[0].ConnectionLog.UpdatedAt
require.Equal(t, firstRows, secondRows)
// Upsert a disconnection, which should also be a no op
File diff suppressed because it is too large Load Diff
+19 -22
View File
@@ -27,26 +27,23 @@ INSERT INTO boundary_usage_stats (
updated_at = NOW()
RETURNING (xmax = 0) AS new_period;
-- name: GetAndResetBoundaryUsageSummary :one
-- Atomic read+delete prevents replicas that flush between a separate read and
-- reset from having their data deleted before the next snapshot. Uses a common
-- table expression with DELETE...RETURNING so the rows we sum are exactly the
-- rows we delete. Stale rows are excluded from the sum but still deleted.
WITH deleted AS (
DELETE FROM boundary_usage_stats
RETURNING *
)
-- name: GetBoundaryUsageSummary :one
-- Aggregates boundary usage statistics across all replicas. Filters to only
-- include data where window_start is within the given interval to exclude
-- stale data.
SELECT
COALESCE(SUM(unique_workspaces_count) FILTER (
WHERE window_start >= NOW() - (@max_staleness_ms::bigint || ' ms')::interval
), 0)::bigint AS unique_workspaces,
COALESCE(SUM(unique_users_count) FILTER (
WHERE window_start >= NOW() - (@max_staleness_ms::bigint || ' ms')::interval
), 0)::bigint AS unique_users,
COALESCE(SUM(allowed_requests) FILTER (
WHERE window_start >= NOW() - (@max_staleness_ms::bigint || ' ms')::interval
), 0)::bigint AS allowed_requests,
COALESCE(SUM(denied_requests) FILTER (
WHERE window_start >= NOW() - (@max_staleness_ms::bigint || ' ms')::interval
), 0)::bigint AS denied_requests
FROM deleted;
COALESCE(SUM(unique_workspaces_count), 0)::bigint AS unique_workspaces,
COALESCE(SUM(unique_users_count), 0)::bigint AS unique_users,
COALESCE(SUM(allowed_requests), 0)::bigint AS allowed_requests,
COALESCE(SUM(denied_requests), 0)::bigint AS denied_requests
FROM boundary_usage_stats
WHERE window_start >= NOW() - (@max_staleness_ms::bigint || ' ms')::interval;
-- name: ResetBoundaryUsageStats :exec
-- Deletes all boundary usage statistics. Called after telemetry reports the
-- aggregated stats. Each replica will insert a fresh row on its next flush.
DELETE FROM boundary_usage_stats;
-- name: DeleteBoundaryUsageStatsByReplicaID :exec
-- Deletes boundary usage statistics for a specific replica.
DELETE FROM boundary_usage_stats WHERE replica_id = @replica_id;
+10 -282
View File
@@ -114,7 +114,9 @@ WHERE
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL))
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
@@ -227,7 +229,9 @@ WHERE
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL))
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
@@ -256,7 +260,6 @@ INSERT INTO connection_logs (
workspace_id,
workspace_name,
agent_name,
agent_id,
type,
code,
ip,
@@ -265,24 +268,18 @@ INSERT INTO connection_logs (
slug_or_port,
connection_id,
disconnect_reason,
disconnect_time,
updated_at,
session_id,
client_hostname,
short_description
disconnect_time
) VALUES
($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15,
($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
-- If we've only received a disconnect event, mark the event as immediately
-- closed.
CASE
WHEN @connection_status::connection_status = 'disconnected'
THEN @time :: timestamp with time zone
ELSE NULL
END,
@time, $16, $17, $18)
END)
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
updated_at = @time,
-- No-op if the connection is still open.
disconnect_time = CASE
WHEN @connection_status::connection_status = 'disconnected'
@@ -304,274 +301,5 @@ DO UPDATE SET
AND connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END,
agent_id = COALESCE(connection_logs.agent_id, EXCLUDED.agent_id)
END
RETURNING *;
-- name: CloseOpenAgentConnectionLogsForWorkspace :execrows
UPDATE connection_logs
SET
disconnect_time = GREATEST(connect_time, @closed_at :: timestamp with time zone),
-- Do not overwrite any existing reason.
disconnect_reason = COALESCE(disconnect_reason, @reason :: text)
WHERE disconnect_time IS NULL
AND workspace_id = @workspace_id :: uuid
AND type = ANY(@types :: connection_type[]);
-- name: GetOngoingAgentConnectionsLast24h :many
WITH ranked AS (
SELECT
id,
connect_time,
organization_id,
workspace_owner_id,
workspace_id,
workspace_name,
agent_name,
type,
ip,
code,
user_agent,
user_id,
slug_or_port,
connection_id,
disconnect_time,
disconnect_reason,
agent_id,
updated_at,
session_id,
client_hostname,
short_description,
row_number() OVER (
PARTITION BY workspace_id, agent_name
ORDER BY connect_time DESC
) AS rn
FROM
connection_logs
WHERE
workspace_id = ANY(@workspace_ids :: uuid[])
AND agent_name = ANY(@agent_names :: text[])
AND type = ANY(@types :: connection_type[])
AND disconnect_time IS NULL
AND (
-- Non-web types always included while connected.
type NOT IN ('workspace_app', 'port_forwarding')
-- Agent-reported web connections have NULL user_agent
-- and carry proper disconnect lifecycle tracking.
OR user_agent IS NULL
-- Proxy-reported web connections (non-NULL user_agent)
-- rely on updated_at being bumped on each token refresh.
OR updated_at >= @app_active_since :: timestamp with time zone
)
AND connect_time >= @since :: timestamp with time zone
)
SELECT
id,
connect_time,
organization_id,
workspace_owner_id,
workspace_id,
workspace_name,
agent_name,
type,
ip,
code,
user_agent,
user_id,
slug_or_port,
connection_id,
disconnect_time,
disconnect_reason,
updated_at,
session_id,
client_hostname,
short_description
FROM
ranked
WHERE
rn <= @per_agent_limit
ORDER BY
workspace_id,
agent_name,
connect_time DESC;
-- name: UpdateConnectionLogSessionID :exec
-- Links a connection log row to its workspace session.
UPDATE connection_logs SET session_id = @session_id WHERE id = @id;
-- name: CloseConnectionLogsAndCreateSessions :execrows
-- Atomically closes open connections and creates sessions grouped by
-- client_hostname (with IP fallback) and time overlap. Non-system
-- connections drive session boundaries; system connections attach to
-- the first overlapping session or get their own if orphaned.
--
-- Processes connections that are still open (disconnect_time IS NULL) OR
-- already disconnected but not yet assigned to a session (session_id IS
-- NULL). The latter covers system/tunnel connections whose disconnect is
-- recorded by dbsink but which have no session-assignment code path.
WITH connections_to_close AS (
SELECT id, ip, connect_time, disconnect_time, agent_id,
client_hostname, short_description, type
FROM connection_logs
WHERE (disconnect_time IS NULL OR session_id IS NULL)
AND workspace_id = @workspace_id
AND type = ANY(@types::connection_type[])
),
-- Phase 1: Group non-system connections by hostname+time overlap.
-- System connections persist for the entire workspace lifetime and
-- would create mega-sessions if included in boundary computation.
primary_connections AS (
SELECT *,
COALESCE(client_hostname, host(ip), 'unknown') AS group_key
FROM connections_to_close
WHERE type != 'system'
),
ordered AS (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY group_key ORDER BY connect_time) AS rn,
MAX(COALESCE(disconnect_time, @closed_at::timestamptz))
OVER (PARTITION BY group_key ORDER BY connect_time
ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) AS running_max_end
FROM primary_connections
),
with_boundaries AS (
SELECT *,
SUM(CASE
WHEN rn = 1 THEN 1
WHEN connect_time > running_max_end + INTERVAL '30 minutes' THEN 1
ELSE 0
END) OVER (PARTITION BY group_key ORDER BY connect_time) AS group_id
FROM ordered
),
session_groups AS (
SELECT
group_key,
group_id,
MIN(connect_time) AS started_at,
MAX(COALESCE(disconnect_time, @closed_at::timestamptz)) AS ended_at,
(array_agg(agent_id ORDER BY connect_time) FILTER (WHERE agent_id IS NOT NULL))[1] AS agent_id,
(array_agg(ip ORDER BY connect_time) FILTER (WHERE ip IS NOT NULL))[1] AS ip,
(array_agg(client_hostname ORDER BY connect_time) FILTER (WHERE client_hostname IS NOT NULL))[1] AS client_hostname,
(array_agg(short_description ORDER BY connect_time) FILTER (WHERE short_description IS NOT NULL))[1] AS short_description
FROM with_boundaries
GROUP BY group_key, group_id
),
-- Check for pre-existing sessions that match by hostname (or IP
-- fallback) and overlap in time, to avoid duplicates from the race
-- with FindOrCreateSessionForDisconnect.
existing_sessions AS (
SELECT DISTINCT ON (sg.group_key, sg.group_id)
sg.group_key, sg.group_id, ws.id AS session_id
FROM session_groups sg
JOIN workspace_sessions ws
ON ws.workspace_id = @workspace_id
AND (
(sg.client_hostname IS NOT NULL AND ws.client_hostname = sg.client_hostname)
OR (sg.client_hostname IS NULL AND sg.ip IS NOT NULL AND ws.ip = sg.ip AND ws.client_hostname IS NULL)
)
AND sg.started_at <= ws.ended_at + INTERVAL '30 minutes'
AND sg.ended_at >= ws.started_at - INTERVAL '30 minutes'
ORDER BY sg.group_key, sg.group_id, ws.started_at DESC
),
new_sessions AS (
INSERT INTO workspace_sessions (workspace_id, agent_id, ip, client_hostname, short_description, started_at, ended_at)
SELECT @workspace_id, sg.agent_id, sg.ip, sg.client_hostname, sg.short_description, sg.started_at, sg.ended_at
FROM session_groups sg
WHERE NOT EXISTS (
SELECT 1 FROM existing_sessions es
WHERE es.group_key = sg.group_key AND es.group_id = sg.group_id
)
RETURNING id, ip, started_at
),
-- Combine existing and newly created sessions.
all_sessions AS (
SELECT ns.id, sg.group_key, sg.started_at
FROM new_sessions ns
JOIN session_groups sg
ON sg.started_at = ns.started_at
AND (sg.ip IS NOT DISTINCT FROM ns.ip)
UNION ALL
SELECT es.session_id AS id, es.group_key, sg.started_at
FROM existing_sessions es
JOIN session_groups sg ON es.group_key = sg.group_key AND es.group_id = sg.group_id
),
-- Phase 2: Assign system connections to the earliest overlapping
-- primary session. First check sessions from this batch, then fall
-- back to pre-existing workspace_sessions.
system_batch_match AS (
SELECT DISTINCT ON (c.id)
c.id AS connection_id,
alls.id AS session_id,
sg.started_at AS session_start
FROM connections_to_close c
JOIN all_sessions alls ON true
JOIN session_groups sg ON alls.group_key = sg.group_key AND alls.started_at = sg.started_at
WHERE c.type = 'system'
AND COALESCE(c.disconnect_time, @closed_at::timestamptz) >= sg.started_at
AND c.connect_time <= sg.ended_at
ORDER BY c.id, sg.started_at
),
-- Also match system connections to pre-existing sessions (created
-- by FindOrCreateSessionForDisconnect) that aren't in this batch.
system_existing_match AS (
SELECT DISTINCT ON (c.id)
c.id AS connection_id,
ws.id AS session_id
FROM connections_to_close c
JOIN workspace_sessions ws
ON ws.workspace_id = @workspace_id
AND COALESCE(c.disconnect_time, @closed_at::timestamptz) >= ws.started_at
AND c.connect_time <= ws.ended_at
WHERE c.type = 'system'
AND NOT EXISTS (SELECT 1 FROM system_batch_match sbm WHERE sbm.connection_id = c.id)
ORDER BY c.id, ws.started_at
),
system_session_match AS (
SELECT connection_id, session_id FROM system_batch_match
UNION ALL
SELECT connection_id, session_id FROM system_existing_match
),
-- Create sessions for orphaned system connections (no overlapping
-- primary session) that have an IP.
orphan_system AS (
SELECT c.*
FROM connections_to_close c
LEFT JOIN system_session_match ssm ON ssm.connection_id = c.id
WHERE c.type = 'system'
AND ssm.connection_id IS NULL
AND c.ip IS NOT NULL
),
orphan_system_sessions AS (
INSERT INTO workspace_sessions (workspace_id, agent_id, ip, client_hostname, short_description, started_at, ended_at)
SELECT @workspace_id, os.agent_id, os.ip, os.client_hostname, os.short_description,
os.connect_time, COALESCE(os.disconnect_time, @closed_at::timestamptz)
FROM orphan_system os
RETURNING id, ip, started_at
),
-- Combine all session sources for the final UPDATE.
final_sessions AS (
-- Primary sessions matched to non-system connections.
SELECT wb.id AS connection_id, alls.id AS session_id
FROM with_boundaries wb
JOIN session_groups sg ON wb.group_key = sg.group_key AND wb.group_id = sg.group_id
JOIN all_sessions alls ON sg.group_key = alls.group_key AND sg.started_at = alls.started_at
UNION ALL
-- System connections matched to primary sessions.
SELECT ssm.connection_id, ssm.session_id
FROM system_session_match ssm
UNION ALL
-- Orphaned system connections with their own sessions.
SELECT os.id, oss.id
FROM orphan_system os
JOIN orphan_system_sessions oss ON os.ip = oss.ip AND os.connect_time = oss.started_at
)
UPDATE connection_logs cl
SET
disconnect_time = COALESCE(cl.disconnect_time, @closed_at),
disconnect_reason = COALESCE(cl.disconnect_reason, @reason),
session_id = COALESCE(cl.session_id, fs.session_id)
FROM connections_to_close ctc
LEFT JOIN final_sessions fs ON ctc.id = fs.connection_id
WHERE cl.id = ctc.id;
-24
View File
@@ -126,29 +126,5 @@ SELECT * FROM tailnet_coordinators;
-- name: GetAllTailnetPeers :many
SELECT * FROM tailnet_peers;
-- name: GetTailnetTunnelPeerBindingsByDstID :many
SELECT tp.id AS peer_id, tp.coordinator_id, tp.updated_at, tp.node, tp.status
FROM tailnet_peers tp
INNER JOIN tailnet_tunnels tt ON tp.id = tt.src_id
WHERE tt.dst_id = @dst_id;
-- name: GetAllTailnetTunnels :many
SELECT * FROM tailnet_tunnels;
-- name: InsertTailnetPeeringEvent :exec
INSERT INTO tailnet_peering_events (
peering_id,
event_type,
src_peer_id,
dst_peer_id,
node,
occurred_at
)
VALUES
($1, $2, $3, $4, $5, $6);
-- name: GetAllTailnetPeeringEventsByPeerID :many
SELECT *
FROM tailnet_peering_events
WHERE src_peer_id = $1 OR dst_peer_id = $1
ORDER BY peering_id, occurred_at;
+7 -13
View File
@@ -57,19 +57,13 @@ AND CASE WHEN @status::text != '' THEN tws.status = @status::task_status ELSE TR
ORDER BY tws.created_at DESC;
-- name: DeleteTask :one
WITH deleted_task AS (
UPDATE tasks
SET
deleted_at = @deleted_at::timestamptz
WHERE
id = @id::uuid
AND deleted_at IS NULL
RETURNING id
), deleted_snapshot AS (
DELETE FROM task_snapshots
WHERE task_id = @id::uuid
)
SELECT id FROM deleted_task;
UPDATE tasks
SET
deleted_at = @deleted_at::timestamptz
WHERE
id = @id::uuid
AND deleted_at IS NULL
RETURNING *;
-- name: UpdateTaskPrompt :one
+1 -2
View File
@@ -173,8 +173,7 @@ SET
group_acl = $8,
max_port_sharing_level = $9,
use_classic_parameter_flow = $10,
cors_behavior = $11,
disable_module_cache = $12
cors_behavior = $11
WHERE
id = $1
;
+6 -15
View File
@@ -1,9 +1,8 @@
-- name: UpsertWorkspaceAppAuditSession :one
--
-- The returned columns, new_or_stale and connection_id, can be used to deduce
-- if a new session was started and which connection_id to use. new_or_stale is
-- true when a new row was inserted (no previous session) or the updated_at is
-- older than the stale interval.
-- The returned boolean, new_or_stale, can be used to deduce if a new session
-- was started. This means that a new row was inserted (no previous session) or
-- the updated_at is older than stale interval.
INSERT INTO
workspace_app_audit_sessions (
id,
@@ -15,8 +14,7 @@ INSERT INTO
slug_or_port,
status_code,
started_at,
updated_at,
connection_id
updated_at
)
VALUES
(
@@ -29,8 +27,7 @@ VALUES
$7,
$8,
$9,
$10,
$11
$10
)
ON CONFLICT
(agent_id, app_id, user_id, ip, user_agent, slug_or_port, status_code)
@@ -48,12 +45,6 @@ DO
THEN workspace_app_audit_sessions.started_at
ELSE EXCLUDED.started_at
END,
connection_id = CASE
WHEN workspace_app_audit_sessions.updated_at > NOW() - (@stale_interval_ms::bigint || ' ms')::interval
THEN workspace_app_audit_sessions.connection_id
ELSE EXCLUDED.connection_id
END,
updated_at = EXCLUDED.updated_at
RETURNING
id = $1 AS new_or_stale,
connection_id;
id = $1 AS new_or_stale;
@@ -243,31 +243,3 @@ SET
has_external_agent = @has_external_agent,
updated_at = @updated_at::timestamptz
WHERE id = @id::uuid;
-- name: GetWorkspaceBuildMetricsByResourceID :one
-- Returns build metadata for e2e workspace build duration metrics.
-- Also checks if all agents are ready and returns the worst status.
SELECT
wb.created_at,
wb.transition,
t.name AS template_name,
o.name AS organization_name,
(w.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0') AS is_prebuild,
-- All agents must have ready_at set (terminal startup state)
COUNT(*) FILTER (WHERE wa.ready_at IS NULL) = 0 AS all_agents_ready,
-- Latest ready_at across all agents (for duration calculation)
MAX(wa.ready_at)::timestamptz AS last_agent_ready_at,
-- Worst status: error > timeout > ready
CASE
WHEN bool_or(wa.lifecycle_state = 'start_error') THEN 'error'
WHEN bool_or(wa.lifecycle_state = 'start_timeout') THEN 'timeout'
ELSE 'success'
END AS worst_status
FROM workspace_builds wb
JOIN workspaces w ON wb.workspace_id = w.id
JOIN templates t ON w.template_id = t.id
JOIN organizations o ON t.organization_id = o.id
JOIN workspace_resources wr ON wr.job_id = wb.job_id
JOIN workspace_agents wa ON wa.resource_id = wr.id
WHERE wb.job_id = (SELECT job_id FROM workspace_resources WHERE workspace_resources.id = $1)
GROUP BY wb.created_at, wb.transition, t.name, o.name, w.owner_id;
+7 -7
View File
@@ -399,13 +399,13 @@ WHERE
filtered_workspaces fw
ORDER BY
-- To ensure that 'favorite' workspaces show up first in the list only for their owner.
CASE WHEN fw.owner_id = @requester_id AND fw.favorite THEN 0 ELSE 1 END ASC,
(fw.latest_build_completed_at IS NOT NULL AND
fw.latest_build_canceled_at IS NULL AND
fw.latest_build_error IS NULL AND
fw.latest_build_transition = 'start'::workspace_transition) DESC,
LOWER(fw.owner_username) ASC,
LOWER(fw.name) ASC
CASE WHEN owner_id = @requester_id AND favorite THEN 0 ELSE 1 END ASC,
(latest_build_completed_at IS NOT NULL AND
latest_build_canceled_at IS NULL AND
latest_build_error IS NULL AND
latest_build_transition = 'start'::workspace_transition) DESC,
LOWER(owner_username) ASC,
LOWER(name) ASC
LIMIT
CASE
WHEN @limit_ :: integer > 0 THEN
@@ -1,110 +0,0 @@
-- name: FindOrCreateSessionForDisconnect :one
-- Find existing session within time window, or create new one.
-- Uses advisory lock to prevent duplicate sessions from concurrent disconnects.
-- Groups by client_hostname (with IP fallback) to match the live session
-- grouping in mergeWorkspaceConnectionsIntoSessions.
WITH lock AS (
SELECT pg_advisory_xact_lock(
hashtext(@workspace_id::text || COALESCE(@client_hostname, host(@ip::inet), 'unknown'))
)
),
existing AS (
SELECT id FROM workspace_sessions
WHERE workspace_id = @workspace_id::uuid
AND (
(@client_hostname IS NOT NULL AND client_hostname = @client_hostname)
OR
(@client_hostname IS NULL AND client_hostname IS NULL AND ip = @ip::inet)
)
AND @connect_time BETWEEN started_at - INTERVAL '30 minutes' AND ended_at + INTERVAL '30 minutes'
ORDER BY started_at DESC
LIMIT 1
),
new_session AS (
INSERT INTO workspace_sessions (workspace_id, agent_id, ip, client_hostname, short_description, started_at, ended_at)
SELECT @workspace_id::uuid, @agent_id, @ip::inet, @client_hostname, @short_description, @connect_time, @disconnect_time
WHERE NOT EXISTS (SELECT 1 FROM existing)
RETURNING id
),
updated_session AS (
UPDATE workspace_sessions
SET started_at = LEAST(started_at, @connect_time),
ended_at = GREATEST(ended_at, @disconnect_time)
WHERE id = (SELECT id FROM existing)
RETURNING id
)
SELECT COALESCE(
(SELECT id FROM updated_session),
(SELECT id FROM new_session)
) AS id;
-- name: GetWorkspaceSessionsOffset :many
SELECT
ws.*,
(SELECT COUNT(*) FROM connection_logs cl WHERE cl.session_id = ws.id) AS connection_count
FROM workspace_sessions ws
WHERE ws.workspace_id = @workspace_id
AND CASE WHEN @started_after::timestamptz != '0001-01-01 00:00:00Z'::timestamptz
THEN ws.started_at >= @started_after ELSE true END
AND CASE WHEN @started_before::timestamptz != '0001-01-01 00:00:00Z'::timestamptz
THEN ws.started_at <= @started_before ELSE true END
ORDER BY ws.started_at DESC
LIMIT @limit_count
OFFSET @offset_count;
-- name: CountWorkspaceSessions :one
SELECT COUNT(*) FROM workspace_sessions ws
WHERE ws.workspace_id = @workspace_id
AND CASE WHEN @started_after::timestamptz != '0001-01-01 00:00:00Z'::timestamptz
THEN ws.started_at >= @started_after ELSE true END
AND CASE WHEN @started_before::timestamptz != '0001-01-01 00:00:00Z'::timestamptz
THEN ws.started_at <= @started_before ELSE true END;
-- name: GetConnectionLogsBySessionIDs :many
SELECT * FROM connection_logs
WHERE session_id = ANY(@session_ids::uuid[])
ORDER BY session_id, connect_time DESC;
-- name: GetConnectionLogByConnectionID :one
SELECT * FROM connection_logs
WHERE connection_id = @connection_id
AND workspace_id = @workspace_id
AND agent_name = @agent_name
LIMIT 1;
-- name: GetGlobalWorkspaceSessionsOffset :many
SELECT
ws.*,
w.name AS workspace_name,
workspace_owner.username AS workspace_owner_username,
(SELECT COUNT(*) FROM connection_logs cl WHERE cl.session_id = ws.id) AS connection_count
FROM workspace_sessions ws
JOIN workspaces w ON w.id = ws.workspace_id
JOIN users workspace_owner ON workspace_owner.id = w.owner_id
WHERE
CASE WHEN @workspace_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid
THEN ws.workspace_id = @workspace_id ELSE true END
AND CASE WHEN @workspace_owner::text != ''
THEN workspace_owner.username = @workspace_owner ELSE true END
AND CASE WHEN @started_after::timestamptz != '0001-01-01 00:00:00Z'::timestamptz
THEN ws.started_at >= @started_after ELSE true END
AND CASE WHEN @started_before::timestamptz != '0001-01-01 00:00:00Z'::timestamptz
THEN ws.started_at <= @started_before ELSE true END
ORDER BY ws.started_at DESC
LIMIT @limit_count
OFFSET @offset_count;
-- name: CountGlobalWorkspaceSessions :one
SELECT COUNT(*) FROM workspace_sessions ws
JOIN workspaces w ON w.id = ws.workspace_id
JOIN users workspace_owner ON workspace_owner.id = w.owner_id
WHERE
CASE WHEN @workspace_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid
THEN ws.workspace_id = @workspace_id ELSE true END
AND CASE WHEN @workspace_owner::text != ''
THEN workspace_owner.username = @workspace_owner ELSE true END
AND CASE WHEN @started_after::timestamptz != '0001-01-01 00:00:00Z'::timestamptz
THEN ws.started_at >= @started_after ELSE true END
AND CASE WHEN @started_before::timestamptz != '0001-01-01 00:00:00Z'::timestamptz
THEN ws.started_at <= @started_before ELSE true END;
-2
View File
@@ -6,7 +6,6 @@ type UniqueConstraint string
// UniqueConstraint enums.
const (
UniqueAgentPeeringIDsPkey UniqueConstraint = "agent_peering_ids_pkey" // ALTER TABLE ONLY agent_peering_ids ADD CONSTRAINT agent_peering_ids_pkey PRIMARY KEY (agent_id, peering_id);
UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
UniqueAibridgeInterceptionsPkey UniqueConstraint = "aibridge_interceptions_pkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_pkey PRIMARY KEY (id);
UniqueAibridgeTokenUsagesPkey UniqueConstraint = "aibridge_token_usages_pkey" // ALTER TABLE ONLY aibridge_token_usages ADD CONSTRAINT aibridge_token_usages_pkey PRIMARY KEY (id);
@@ -109,7 +108,6 @@ const (
UniqueWorkspaceResourceMetadataName UniqueConstraint = "workspace_resource_metadata_name" // ALTER TABLE ONLY workspace_resource_metadata ADD CONSTRAINT workspace_resource_metadata_name UNIQUE (workspace_resource_id, key);
UniqueWorkspaceResourceMetadataPkey UniqueConstraint = "workspace_resource_metadata_pkey" // ALTER TABLE ONLY workspace_resource_metadata ADD CONSTRAINT workspace_resource_metadata_pkey PRIMARY KEY (id);
UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
UniqueWorkspaceSessionsPkey UniqueConstraint = "workspace_sessions_pkey" // ALTER TABLE ONLY workspace_sessions ADD CONSTRAINT workspace_sessions_pkey PRIMARY KEY (id);
UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);

Some files were not shown because too many files have changed in this diff Show More