Compare commits


18 Commits

Author SHA1 Message Date
Jon Ayers cfd7730194 chore(enterprise/tailnet): add debug logging for LOST and DISCONNECTED peer updates 2026-04-10 00:40:57 +00:00
Jon Ayers 1937ada0cd fix: use enriched logger for HeartbeatClose, reduce AwaitReachable backoff to 5s 2026-04-09 23:45:59 +00:00
Jon Ayers d64cd6415d revert: move HeartbeatClose back before agent dial 2026-04-09 22:37:31 +00:00
Jon Ayers c1851d9453 chore(coderd/workspaceapps): add workspace_id and elapsed time to PTY dial logs 2026-04-09 22:35:44 +00:00
Jon Ayers 8f73453681 fix(coderd/workspaceapps): move HeartbeatClose after agent dial, add 1m setup timeout 2026-04-09 21:39:28 +00:00
Jon Ayers 165db3d31c perf(enterprise/tailnet): increase coordinator worker counts and batch size for 10k scale 2026-04-09 21:26:52 +00:00
Jon Ayers 1bd1516fd1 perf(tailnet): singleflight AwaitReachable to deduplicate concurrent ping storms 2026-04-09 19:34:22 +00:00
Jon Ayers 81ba35a987 fix(coderd/tailnet): move ensureAgent Send outside mutex using singleflight 2026-04-09 19:12:27 +00:00
Jon Ayers 53d63cf8e9 perf(coderd/database/pubsub): batch-drain msgQueue to amortize lock overhead
Replace the one-at-a-time dequeue loop in msgQueue.run() with a batch
drain that copies up to 256 messages per lock acquisition. This
amortizes mutex acquire/release and cond.Wait costs across many
messages, improving drain throughput during bursts and reducing the
likelihood of ring buffer overflow.
2026-04-08 00:02:29 +00:00
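The batch-drain idea in the commit above can be sketched with a minimal stdlib-only queue (illustrative names, not the actual coder/coder pubsub code): the consumer copies up to 256 messages out per lock acquisition instead of one, so the mutex and `cond.Wait` costs are amortized across the whole batch.

```go
package main

import (
	"fmt"
	"sync"
)

const batchSize = 256

// batchQueue is a hypothetical sketch of a producer/consumer queue whose
// consumer drains messages in batches rather than one at a time.
type batchQueue struct {
	mu   sync.Mutex
	cond *sync.Cond
	msgs []string
	done bool
}

func newBatchQueue() *batchQueue {
	q := &batchQueue{}
	q.cond = sync.NewCond(&q.mu)
	return q
}

func (q *batchQueue) push(msg string) {
	q.mu.Lock()
	q.msgs = append(q.msgs, msg)
	q.mu.Unlock()
	q.cond.Signal()
}

func (q *batchQueue) close() {
	q.mu.Lock()
	q.done = true
	q.mu.Unlock()
	q.cond.Broadcast()
}

// drainBatch blocks until at least one message is queued (or the queue is
// closed), then moves up to batchSize messages out in a single critical
// section, amortizing lock acquire/release across the batch.
func (q *batchQueue) drainBatch() []string {
	q.mu.Lock()
	defer q.mu.Unlock()
	for len(q.msgs) == 0 && !q.done {
		q.cond.Wait()
	}
	n := len(q.msgs)
	if n > batchSize {
		n = batchSize
	}
	batch := make([]string, n)
	copy(batch, q.msgs[:n])
	q.msgs = q.msgs[n:]
	return batch
}

func main() {
	q := newBatchQueue()
	for i := 0; i < 3; i++ {
		q.push(fmt.Sprintf("msg-%d", i))
	}
	fmt.Println(q.drainBatch())
}
```

During a burst, one wakeup now services many messages, which is what reduces the chance of the ring buffer overflowing upstream.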
Jon Ayers 4213a43b53 fix(enterprise/tailnet): async singleflight-coalesced resyncPeerMappings in pubsub callbacks
Replace synchronous resyncPeerMappings() calls in listenPeer and
listenTunnel with async goroutines using singleflight.Do. This
prevents blocking the pubsub drain goroutine when ErrDroppedMessages
arrives, avoiding cascading buffer overflows.
2026-04-08 00:02:21 +00:00
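The coalescing pattern above can be sketched with the standard library alone (the real change uses `golang.org/x/sync/singleflight`; all names here are illustrative): callers return immediately, and while one resync is in flight further requests only mark it dirty, so a storm of dropped-message notifications collapses into a few runs without ever blocking the caller.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// coalescer deduplicates concurrent resync requests: at most one resync
// runs at a time, and requests arriving mid-run are folded into a single
// follow-up run.
type coalescer struct {
	mu       sync.Mutex
	inFlight bool
	dirty    bool
	wg       sync.WaitGroup
	runs     atomic.Int64
}

func (c *coalescer) resync() {
	c.runs.Add(1) // stand-in for the expensive peer-mapping resync
}

// request returns immediately; the work happens on a background goroutine,
// so the calling drain goroutine is never blocked.
func (c *coalescer) request() {
	c.mu.Lock()
	if c.inFlight {
		c.dirty = true // coalesce into the in-flight run's follow-up
		c.mu.Unlock()
		return
	}
	c.inFlight = true
	c.mu.Unlock()
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		for {
			c.resync()
			c.mu.Lock()
			if !c.dirty {
				c.inFlight = false
				c.mu.Unlock()
				return
			}
			c.dirty = false
			c.mu.Unlock()
		}
	}()
}

func main() {
	var c coalescer
	for i := 0; i < 1000; i++ {
		c.request() // a burst of ErrDroppedMessages notifications
	}
	c.wg.Wait()
	fmt.Println("resync runs:", c.runs.Load()) // typically just a few
}
```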
Garrett Delfosse 5453a6c6d6 fix(scripts/releaser): simplify branch regex and fix changelog range (#23947)
Two fixes for the release script:

**1. Branch regex cleanup** — Simplified to only match `release/X.Y`. Removed support for `release/X.Y.Z` and `release/X.Y-rc.N` branch formats. RCs are now tagged from main (not from release branches), and the three-segment `release/X.Y.Z` format will not be used going forward.

**2. Changelog range for first release on a new minor** — When no tags match the branch's major.minor, the commit range fell back to `HEAD` (entire git history, ~13k lines of changelog). Now computes `git merge-base` with the previous minor's release branch (e.g. `origin/release/2.32`) as the changelog starting point. This works even when that branch has no tags pushed yet. Falls back to the latest reachable tag from a previous minor if the branch doesn't exist.
2026-04-07 17:07:21 +00:00
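The fallback order for the changelog starting point can be sketched as a pure function (illustrative names; the actual script implements this in bash with git commands):

```go
package main

import "fmt"

// changelogStartRef picks the ref to diff the changelog from, in the order
// described above: the latest tag on this major.minor; else the merge-base
// with the previous minor's release branch; else the newest reachable tag
// from a previous minor. Returns false if no starting point exists.
func changelogStartRef(tagsOnMinor []string, prevBranchExists bool, prevMinorTags []string) (string, bool) {
	if len(tagsOnMinor) > 0 {
		// Normal case: a prior release on this minor.
		return tagsOnMinor[len(tagsOnMinor)-1], true
	}
	if prevBranchExists {
		// e.g. `git merge-base HEAD origin/release/2.32`; works even when
		// the previous branch has no tags pushed yet.
		return "merge-base(HEAD, previous-minor-branch)", true
	}
	if len(prevMinorTags) > 0 {
		return prevMinorTags[len(prevMinorTags)-1], true
	}
	// Previously this case fell back to HEAD's entire history.
	return "", false
}

func main() {
	ref, ok := changelogStartRef(nil, true, nil)
	fmt.Println(ref, ok)
}
```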
Jake Howell 21c08a37d7 feat: de-mui <LogLine /> and <Logs /> (#24043)
Migrated LogLine and Logs components from Emotion CSS-in-JS to Tailwind
CSS classes.

- Replaced Emotion `css` prop and theme-based styling with Tailwind
utility classes in `LogLine` and `LogLinePrefix` components
- Converted CSS-in-JS styles object to conditional Tailwind classes
using the `cn` utility function
- Updated log level styling (error, debug, warn) to use Tailwind classes
with design token references
- Migrated the Logs container component styling from Emotion to Tailwind
classes
- Removed Emotion imports and theme dependencies
2026-04-07 16:35:10 +00:00
Jake Howell 2bd261fbbf fix: cleanup useKebabMenu code (#24042)
Refactored the tab overflow hook by renaming `useTabOverflowKebabMenu`
to `useKebabMenu` and removing the configurable `alwaysVisibleTabsCount`
parameter.

- Renamed `useTabOverflowKebabMenu` to `useKebabMenu` and moved it to a
new file
- Removed the `alwaysVisibleTabsCount` parameter and hardcoded it to 1
tab as `ALWAYS_VISIBLE_TABS_COUNT`
- Removed the `utils/index.ts` export file for the Tabs component
- Updated the import in `AgentRow.tsx` to use the new hook name and
removed the `alwaysVisibleTabsCount` prop
- Refactored the internal logic to use a more functional approach with
`reduce` instead of imperative loops
- Added better performance optimizations to prevent unnecessary
re-renders
2026-04-08 02:25:18 +10:00
Kyle Carberry cffc68df58 feat(site): render read_skill body as markdown (#24069) 2026-04-07 11:50:21 -04:00
Jake Howell 6e5335df1e feat: implement new workspace download logs dropdown (#23963)
This PR improves the agent log download functionality by replacing the
single download button with a comprehensive dropdown menu system.

- Replaced single download button with a dropdown menu offering multiple
download options
- Added ability to download all logs or individual log sources
separately
- Updated download button to show chevron icon indicating dropdown
functionality
- Enhanced download options with appropriate icons for each log source

<img width="370" height="305" alt="image" src="https://github.com/user-attachments/assets/ddf025f5-f936-499a-9165-6e81b62d6860" />
2026-04-07 15:27:43 +00:00
Kyle Carberry 16265e834e chore: update fantasy fork to use github.com/coder/fantasy (#24100)
Moves the `charm.land/fantasy` replace directive from
`github.com/kylecarbs/fantasy` to `github.com/coder/fantasy`, pointing
at the same `cj/go1.25` branch and commit (`112927d9b6d8`).

> Generated by Coder Agents
2026-04-07 16:11:49 +01:00
Zach 565a15bc9b feat: update user secrets queries for REST API and injection (#23998)
Update queries as prep work for user secrets API development:
- Switch all lookups and mutations from ID-based to user_id + name
- Split list query into metadata-only (for API responses) and
with-values (for provisioner/agent)
- Add partial update support using CASE WHEN pattern for write-only
value fields
- Include value_key_id in create for dbcrypt encryption support
- Update dbauthz wrappers and remove stale methods from dbmetrics
2026-04-07 09:03:28 -06:00
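The partial-update behavior described above can be illustrated in plain Go (hypothetical types; the real change expresses this in SQL, roughly `value = CASE WHEN $update_value THEN $value ELSE value END`): each write-only field is overwritten only when its "set" flag is provided, so an update request can omit the secret value entirely.

```go
package main

import "fmt"

// userSecret and updateParams are illustrative stand-ins for the database
// row and update arguments; Value is the write-only field.
type userSecret struct {
	Value       string
	Description string
}

type updateParams struct {
	UpdateValue bool // whether Value should be written at all
	Value       string
	Description string
}

// applyPartialUpdate mirrors the SQL CASE WHEN guard: the write-only value
// changes only when UpdateValue is true; other fields update normally.
func applyPartialUpdate(old userSecret, p updateParams) userSecret {
	out := old
	if p.UpdateValue {
		out.Value = p.Value
	}
	out.Description = p.Description
	return out
}

func main() {
	s := userSecret{Value: "hunter2", Description: "db password"}
	// Update only the description; the write-only value is left untouched.
	s = applyPartialUpdate(s, updateParams{Description: "primary db password"})
	fmt.Println(s.Value, "/", s.Description)
}
```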
Ethan 76a2cb1af5 fix(site/src/pages/AgentsPage): reset provider form after create (#23975)
Previously, after creating a provider config in the agents provider
editor, the Save changes button stayed enabled for the lifetime of the
mounted form. The form kept the pre-create local baseline, so the
freshly-saved values still looked dirty.

Key `ProviderForm` by provider config identity so React remounts the
form when a config is created and re-establishes the pristine state from
the saved provider values.
2026-04-08 00:32:36 +10:00
32 changed files with 1078 additions and 1298 deletions
+3 -4
@@ -31,8 +31,7 @@ updates:
patterns:
- "golang.org/x/*"
ignore:
# Patch updates are handled by the security-patch-prs workflow so this
# lane stays focused on broader dependency updates.
# Ignore patch updates for all dependencies
- dependency-name: "*"
update-types:
- version-update:semver-patch
@@ -57,7 +56,7 @@ updates:
labels: []
ignore:
# We need to coordinate terraform updates with the version hardcoded in
# our Go code. These are handled by the security-patch-prs workflow.
# our Go code.
- dependency-name: "terraform"
- package-ecosystem: "npm"
@@ -118,11 +117,11 @@ updates:
interval: "weekly"
commit-message:
prefix: "chore"
labels: []
groups:
coder-modules:
patterns:
- "coder/*/coder"
labels: []
ignore:
- dependency-name: "*"
update-types:
-354
@@ -1,354 +0,0 @@
name: security-backport
on:
pull_request_target:
types:
- labeled
- unlabeled
- closed
workflow_dispatch:
inputs:
pull_request:
description: Pull request number to backport.
required: true
type: string
permissions:
contents: write
pull-requests: write
issues: write
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
cancel-in-progress: false
env:
LATEST_BRANCH: release/2.31
STABLE_BRANCH: release/2.30
STABLE_1_BRANCH: release/2.29
jobs:
label-policy:
if: github.event_name == 'pull_request_target'
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Apply security backport label policy
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
PR_JSON="${pr_json}" \
python3 - <<'PY'
import json
import os
import subprocess
import sys
pr = json.loads(os.environ["PR_JSON"])
pr_number = pr["number"]
labels = [label["name"] for label in pr.get("labels", [])]
def has(label: str) -> bool:
return label in labels
def ensure_label(label: str) -> None:
if not has(label):
subprocess.run(
["gh", "pr", "edit", str(pr_number), "--add-label", label],
check=False,
)
def remove_label(label: str) -> None:
if has(label):
subprocess.run(
["gh", "pr", "edit", str(pr_number), "--remove-label", label],
check=False,
)
def comment(body: str) -> None:
subprocess.run(
["gh", "pr", "comment", str(pr_number), "--body", body],
check=True,
)
if not has("security:patch"):
remove_label("status:needs-severity")
sys.exit(0)
severity_labels = [
label
for label in ("severity:medium", "severity:high", "severity:critical")
if has(label)
]
if len(severity_labels) == 0:
ensure_label("status:needs-severity")
comment(
"This PR is labeled `security:patch` but is missing a severity "
"label. Add one of `severity:medium`, `severity:high`, or "
"`severity:critical` before backport automation can proceed."
)
sys.exit(0)
if len(severity_labels) > 1:
comment(
"This PR has multiple severity labels. Keep exactly one of "
"`severity:medium`, `severity:high`, or `severity:critical`."
)
sys.exit(1)
remove_label("status:needs-severity")
target_labels = [
label
for label in ("backport:stable", "backport:stable-1")
if has(label)
]
has_none = has("backport:none")
if has_none and target_labels:
comment(
"`backport:none` cannot be combined with other backport labels. "
"Remove `backport:none` or remove the explicit backport targets."
)
sys.exit(1)
if not has_none and not target_labels:
ensure_label("backport:stable")
ensure_label("backport:stable-1")
comment(
"Applied default backport labels `backport:stable` and "
"`backport:stable-1` for a qualifying security patch."
)
PY
backport:
if: >
github.event_name == 'workflow_dispatch' ||
(
github.event_name == 'pull_request_target' &&
github.event.pull_request.merged == true
)
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Resolve PR metadata
id: metadata
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
INPUT_PR_NUMBER: ${{ inputs.pull_request }}
LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
run: |
set -euo pipefail
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
pr_number="${INPUT_PR_NUMBER}"
else
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
fi
case "${pr_number}" in
''|*[!0-9]*)
echo "A valid pull request number is required."
exit 1
;;
esac
pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
PR_JSON="${pr_json}" \
python3 - <<'PY'
import json
import os
import sys
pr = json.loads(os.environ["PR_JSON"])
github_output = os.environ["GITHUB_OUTPUT"]
labels = [label["name"] for label in pr.get("labels", [])]
if "security:patch" not in labels:
print("Not a security patch PR; skipping.")
sys.exit(0)
severity_labels = [
label
for label in ("severity:medium", "severity:high", "severity:critical")
if label in labels
]
if len(severity_labels) != 1:
raise SystemExit(
"Merged security patch PR must have exactly one severity label."
)
if not pr.get("mergedAt"):
raise SystemExit(f"PR #{pr['number']} is not merged.")
if "backport:none" in labels:
target_pairs = []
else:
mapping = {
"backport:stable": os.environ["STABLE_BRANCH"],
"backport:stable-1": os.environ["STABLE_1_BRANCH"],
}
target_pairs = []
for label_name, branch in mapping.items():
if label_name in labels and branch and branch != pr["baseRefName"]:
target_pairs.append({"label": label_name, "branch": branch})
with open(github_output, "a", encoding="utf-8") as f:
f.write(f"pr_number={pr['number']}\n")
f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
f.write(f"title={pr['title']}\n")
f.write(f"url={pr['url']}\n")
f.write(f"author={pr['author']['login']}\n")
f.write(f"severity_label={severity_labels[0]}\n")
f.write(f"target_pairs={json.dumps(target_pairs)}\n")
PY
- name: Backport to release branches
if: ${{ steps.metadata.outputs.target_pairs != '[]' }}
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
PR_TITLE: ${{ steps.metadata.outputs.title }}
PR_URL: ${{ steps.metadata.outputs.url }}
PR_AUTHOR: ${{ steps.metadata.outputs.author }}
SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
git fetch origin --prune
merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
failures=()
successes=()
while IFS=$'\t' read -r backport_label target_branch; do
[ -n "${target_branch}" ] || continue
safe_branch_name="${target_branch//\//-}"
head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
existing_pr="$(gh pr list \
--base "${target_branch}" \
--head "${head_branch}" \
--state all \
--json number,url \
--jq '.[0]')"
if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
successes+=("${target_branch}:existing:${pr_url}")
continue
fi
git checkout -B "${head_branch}" "origin/${target_branch}"
if [ "${merge_parent_count}" -gt 1 ]; then
cherry_pick_args=(-m 1 "${MERGE_SHA}")
else
cherry_pick_args=("${MERGE_SHA}")
fi
if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
git cherry-pick --abort || true
gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
gh pr comment "${PR_NUMBER}" --body \
"Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
failures+=("${target_branch}:cherry-pick failed")
continue
fi
git push --force-with-lease origin "${head_branch}"
body_file="$(mktemp)"
printf '%s\n' \
"Automated backport of [#${PR_NUMBER}](${PR_URL})." \
"" \
"- Source PR: #${PR_NUMBER}" \
"- Source commit: ${MERGE_SHA}" \
"- Target branch: ${target_branch}" \
"- Severity: ${SEVERITY_LABEL}" \
> "${body_file}"
pr_url="$(gh pr create \
--base "${target_branch}" \
--head "${head_branch}" \
--title "${PR_TITLE} (backport to ${target_branch})" \
--body-file "${body_file}")"
backport_pr_number="$(gh pr list \
--base "${target_branch}" \
--head "${head_branch}" \
--state open \
--json number \
--jq '.[0].number')"
gh pr edit "${backport_pr_number}" \
--add-label "security:patch" \
--add-label "${SEVERITY_LABEL}" \
--add-label "${backport_label}" || true
successes+=("${target_branch}:created:${pr_url}")
done < <(
python3 - <<'PY'
import json
import os
for pair in json.loads(os.environ["TARGET_PAIRS"]):
print(f"{pair['label']}\t{pair['branch']}")
PY
)
summary_file="$(mktemp)"
{
echo "## Security backport summary"
echo
if [ "${#successes[@]}" -gt 0 ]; then
echo "### Created or existing"
for entry in "${successes[@]}"; do
echo "- ${entry}"
done
echo
fi
if [ "${#failures[@]}" -gt 0 ]; then
echo "### Failures"
for entry in "${failures[@]}"; do
echo "- ${entry}"
done
fi
} | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
if [ "${#failures[@]}" -gt 0 ]; then
printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
exit 1
fi
-214
@@ -1,214 +0,0 @@
name: security-patch-prs
on:
workflow_dispatch:
schedule:
- cron: "0 3 * * 1-5"
permissions:
contents: write
pull-requests: write
jobs:
patch:
strategy:
fail-fast: false
matrix:
lane:
- gomod
- terraform
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Patch Go dependencies
if: matrix.lane == 'gomod'
shell: bash
run: |
set -euo pipefail
go get -u=patch ./...
go mod tidy
# Guardrail: do not auto-edit replace directives.
if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
echo "Refusing to auto-edit go.mod replace directives"
exit 1
fi
# Guardrail: only go.mod / go.sum may change.
extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
- name: Patch bundled Terraform
if: matrix.lane == 'terraform'
shell: bash
run: |
set -euo pipefail
current="$(
grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
provisioner/terraform/install.go \
| head -1 \
| grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
)"
series="$(echo "$current" | cut -d. -f1,2)"
latest="$(
curl -fsSL https://releases.hashicorp.com/terraform/index.json \
| jq -r --arg series "$series" '
.versions
| keys[]
| select(startswith($series + "."))
' \
| sort -V \
| tail -1
)"
test -n "$latest"
[ "$latest" != "$current" ] || exit 0
CURRENT_TERRAFORM_VERSION="$current" \
LATEST_TERRAFORM_VERSION="$latest" \
python3 - <<'PY'
from pathlib import Path
import os
current = os.environ["CURRENT_TERRAFORM_VERSION"]
latest = os.environ["LATEST_TERRAFORM_VERSION"]
updates = {
"scripts/Dockerfile.base": (
f"terraform/{current}/",
f"terraform/{latest}/",
),
"provisioner/terraform/install.go": (
f'NewVersion("{current}")',
f'NewVersion("{latest}")',
),
"install.sh": (
f'TERRAFORM_VERSION="{current}"',
f'TERRAFORM_VERSION="{latest}"',
),
}
for path_str, (before, after) in updates.items():
path = Path(path_str)
content = path.read_text()
if before not in content:
raise SystemExit(f"did not find expected text in {path_str}: {before}")
path.write_text(content.replace(before, after))
PY
# Guardrail: only the Terraform-version files may change.
extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
- name: Validate Go dependency patch
if: matrix.lane == 'gomod'
shell: bash
run: |
set -euo pipefail
go test ./...
- name: Validate Terraform patch
if: matrix.lane == 'terraform'
shell: bash
run: |
set -euo pipefail
go test ./provisioner/terraform/...
docker build -f scripts/Dockerfile.base .
- name: Skip PR creation when there are no changes
id: changes
shell: bash
run: |
set -euo pipefail
if git diff --quiet; then
echo "has_changes=false" >> "${GITHUB_OUTPUT}"
else
echo "has_changes=true" >> "${GITHUB_OUTPUT}"
fi
- name: Commit changes
if: steps.changes.outputs.has_changes == 'true'
shell: bash
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git checkout -B "secpatch/${{ matrix.lane }}"
git add -A
git commit -m "security: patch ${{ matrix.lane }}"
- name: Push branch
if: steps.changes.outputs.has_changes == 'true'
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
git push --force-with-lease \
"https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
"HEAD:refs/heads/secpatch/${{ matrix.lane }}"
- name: Create or update PR
if: steps.changes.outputs.has_changes == 'true'
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
branch="secpatch/${{ matrix.lane }}"
title="security: patch ${{ matrix.lane }}"
body="$(cat <<'EOF'
Automated security patch PR for `${{ matrix.lane }}`.
Scope:
- gomod: patch-level Go dependency updates only
- terraform: bundled Terraform patch updates only
Guardrails:
- no application-code edits
- no auto-editing of go.mod replace directives
- CI must pass
EOF
)"
existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
if [[ -n "${existing_pr}" ]]; then
gh pr edit "${existing_pr}" \
--title "${title}" \
--body "${body}"
pr_number="${existing_pr}"
else
gh pr create \
--base main \
--head "${branch}" \
--title "${title}" \
--body "${body}"
pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
fi
for label in security dependencies automated-pr; do
if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
gh pr edit "${pr_number}" --add-label "${label}"
fi
done
+19 -32
@@ -2155,17 +2155,12 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
return q.db.DeleteUserChatProviderKey(ctx, arg)
}
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// First get the secret to check ownership
secret, err := q.GetUserSecret(ctx, id)
if err != nil {
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
return err
}
return q.db.DeleteUserSecret(ctx, id)
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
}
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
@@ -4128,19 +4123,6 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
return q.db.GetUserNotificationPreferences(ctx, userID)
}
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
// First get the secret to check ownership
secret, err := q.db.GetUserSecret(ctx, id)
if err != nil {
return database.UserSecret{}, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
return database.UserSecret{}, err
}
return secret, nil
}
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
@@ -5524,7 +5506,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
return q.db.ListUserChatCompactionThresholds(ctx, userID)
}
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
return nil, err
@@ -5532,6 +5514,16 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
return q.db.ListUserSecrets(ctx, userID)
}
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
// This query returns decrypted secret values and must only be called
// from system contexts (provisioner, agent manifest). REST API
// handlers should use ListUserSecrets (metadata only).
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.ListUserSecretsWithValues(ctx, userID)
}
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
if err != nil {
@@ -6632,17 +6624,12 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
return q.db.UpdateUserRoles(ctx, arg)
}
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
// First get the secret to check ownership
secret, err := q.db.GetUserSecret(ctx, arg.ID)
if err != nil {
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
return database.UserSecret{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
return database.UserSecret{}, err
}
return q.db.UpdateUserSecret(ctx, arg)
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
}
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
+22 -22
@@ -5346,19 +5346,20 @@ func (s *MethodTestSuite) TestUserSecrets() {
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
Returns(secret)
}))
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
check.Args(secret.ID).
Asserts(secret, policy.ActionRead).
Returns(secret)
}))
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
check.Args(user.ID).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
Returns([]database.ListUserSecretsRow{row})
}))
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
check.Args(user.ID).
Asserts(rbac.ResourceSystem, policy.ActionRead).
Returns([]database.UserSecret{secret})
}))
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
@@ -5370,22 +5371,21 @@ func (s *MethodTestSuite) TestUserSecrets() {
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
Returns(ret)
}))
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
arg := database.UpdateUserSecretParams{ID: secret.ID}
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
check.Args(arg).
Asserts(secret, policy.ActionUpdate).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
Returns(updated)
}))
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
check.Args(secret.ID).
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
Returns()
}))
}
+1
@@ -1597,6 +1597,7 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
Name: takeFirst(seed.Name, "secret-name"),
Description: takeFirst(seed.Description, "secret description"),
Value: takeFirst(seed.Value, "secret value"),
ValueKeyID: seed.ValueKeyID,
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
})
+17 -17
@@ -712,11 +712,11 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
return r0
}
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
start := time.Now()
r0 := m.s.DeleteUserSecret(ctx, id)
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
return r0
}
@@ -2624,14 +2624,6 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
return r0, r1
}
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.GetUserSecret(ctx, id)
m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
@@ -3920,7 +3912,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
start := time.Now()
r0, r1 := m.s.ListUserSecrets(ctx, userID)
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
@@ -3928,6 +3920,14 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
return r0, r1
}
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
return r0, r1
}
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
start := time.Now()
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
@@ -4696,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
return r0, r1
}
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
return r0, r1
}
+29 -29
@@ -1199,18 +1199,18 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
}
// DeleteUserSecret mocks base method.
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// DeleteUserSecretByUserIDAndName mocks base method.
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
}
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
@@ -4907,21 +4907,6 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
}
// GetUserSecret mocks base method.
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
ret0, _ := ret[0].(database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserSecret indicates an expected call of GetUserSecret.
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
}
// GetUserSecretByUserIDAndName mocks base method.
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
@@ -7412,10 +7397,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
}
// ListUserSecrets mocks base method.
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
ret0, _ := ret[0].([]database.UserSecret)
ret0, _ := ret[0].([]database.ListUserSecretsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -7426,6 +7411,21 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
}
// ListUserSecretsWithValues mocks base method.
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
ret0, _ := ret[0].([]database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
}
// ListWorkspaceAgentPortShares mocks base method.
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
m.ctrl.T.Helper()
@@ -8854,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
}
// UpdateUserSecret mocks base method.
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
// UpdateUserSecretByUserIDAndName mocks base method.
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
ret0, _ := ret[0].(database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
}
// UpdateUserStatus mocks base method.
+27 -17
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
}
func (q *msgQueue) run() {
var batch [maxDrainBatch]msgOrErr
for {
// wait until there is something on the queue or we are closed
q.cond.L.Lock()
for q.size == 0 && !q.closed {
q.cond.Wait()
@@ -91,28 +91,32 @@ func (q *msgQueue) run() {
q.cond.L.Unlock()
return
}
item := q.q[q.front]
q.front = (q.front + 1) % BufferSize
q.size--
// Drain up to maxDrainBatch items while holding the lock.
n := min(q.size, maxDrainBatch)
for i := range n {
batch[i] = q.q[q.front]
q.front = (q.front + 1) % BufferSize
}
q.size -= n
q.cond.L.Unlock()
// process item without holding lock
if item.err == nil {
// real message
if q.l != nil {
q.l(q.ctx, item.msg)
// Dispatch each message individually without holding the lock.
for i := range n {
item := batch[i]
if item.err == nil {
if q.l != nil {
q.l(q.ctx, item.msg)
continue
}
if q.le != nil {
q.le(q.ctx, item.msg, nil)
continue
}
continue
}
if q.le != nil {
q.le(q.ctx, item.msg, nil)
continue
}
// unhittable
continue
}
// if the listener wants errors, send it.
if q.le != nil {
q.le(q.ctx, nil, item.err)
}
}
}
@@ -233,6 +237,12 @@ type PGPubsub struct {
// for a subscriber before dropping messages.
const BufferSize = 2048
// maxDrainBatch is the maximum number of messages to drain from the ring
// buffer per iteration. Batching amortizes the cost of mutex
// acquire/release and cond.Wait across many messages, improving drain
// throughput during bursts.
const maxDrainBatch = 256
// Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
+9 -4

@@ -152,7 +152,7 @@ type sqlcQuerier interface {
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
@@ -598,7 +598,6 @@ type sqlcQuerier interface {
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
// GetUserStatusCounts returns the count of users in each status over time.
// The time range is inclusively defined by the start_time and end_time parameters.
@@ -818,7 +817,13 @@ type sqlcQuerier interface {
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
// Returns metadata only (no value or value_key_id) for the
// REST API list and get endpoints.
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
// Returns all columns including the secret value. Used by the
// provisioner (build-time injection) and the agent manifest
// (runtime injection).
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
@@ -957,7 +962,7 @@ type sqlcQuerier interface {
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
+45 -30
@@ -7339,13 +7339,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, secretID, createdSecret.ID)
// 2. READ by ID
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
require.NoError(t, err)
assert.Equal(t, createdSecret.ID, readSecret.ID)
assert.Equal(t, "workflow-secret", readSecret.Name)
// 3. READ by UserID and Name
// 2. READ by UserID and Name
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
@@ -7353,33 +7347,43 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
require.NoError(t, err)
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
// 4. LIST
// 3. LIST (metadata only)
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
require.NoError(t, err)
require.Len(t, secrets, 1)
assert.Equal(t, createdSecret.ID, secrets[0].ID)
// 5. UPDATE
updateParams := database.UpdateUserSecretParams{
ID: createdSecret.ID,
Description: "Updated workflow description",
Value: "updated-workflow-value",
EnvName: "UPDATED_WORKFLOW_ENV",
FilePath: "/updated/workflow/path",
// 4. LIST with values
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
require.NoError(t, err)
require.Len(t, secretsWithValues, 1)
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
// 5. UPDATE (partial - only description)
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
UpdateDescription: true,
Description: "Updated workflow description",
}
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
require.NoError(t, err)
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
// 6. DELETE
err = db.DeleteUserSecret(ctx, createdSecret.ID)
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
})
require.NoError(t, err)
// Verify deletion
_, err = db.GetUserSecret(ctx, createdSecret.ID)
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
require.Error(t, err)
assert.Contains(t, err.Error(), "no rows in result set")
@@ -7449,9 +7453,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
})
// Verify both secrets exist
_, err = db.GetUserSecret(ctx, secret1.ID)
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID, Name: secret1.Name,
})
require.NoError(t, err)
_, err = db.GetUserSecret(ctx, secret2.ID)
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID, Name: secret2.Name,
})
require.NoError(t, err)
})
}
@@ -7474,14 +7482,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
org := dbgen.Organization(t, db, database.Organization{})
// Create secrets for users
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
_ = dbgen.UserSecret(t, db, database.UserSecret{
UserID: user1.ID,
Name: "user1-secret",
Description: "User 1's secret",
Value: "user1-value",
})
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
_ = dbgen.UserSecret(t, db, database.UserSecret{
UserID: user2.ID,
Name: "user2-secret",
Description: "User 2's secret",
@@ -7491,7 +7499,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
testCases := []struct {
name string
subject rbac.Subject
secretID uuid.UUID
lookupUserID uuid.UUID
lookupName string
expectedAccess bool
}{
{
@@ -7501,7 +7510,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
Scope: rbac.ScopeAll,
},
secretID: user1Secret.ID,
lookupUserID: user1.ID,
lookupName: "user1-secret",
expectedAccess: true,
},
{
@@ -7511,7 +7521,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
Scope: rbac.ScopeAll,
},
secretID: user2Secret.ID,
lookupUserID: user2.ID,
lookupName: "user2-secret",
expectedAccess: false,
},
{
@@ -7521,7 +7532,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
Scope: rbac.ScopeAll,
},
secretID: user1Secret.ID,
lookupUserID: user1.ID,
lookupName: "user1-secret",
expectedAccess: false,
},
{
@@ -7531,7 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
Scope: rbac.ScopeAll,
},
secretID: user1Secret.ID,
lookupUserID: user1.ID,
lookupName: "user1-secret",
expectedAccess: false,
},
}
@@ -7543,8 +7556,10 @@ func TestUserSecretsAuthorization(t *testing.T) {
authCtx := dbauthz.As(ctx, tc.subject)
// Test GetUserSecret
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
UserID: tc.lookupUserID,
Name: tc.lookupName,
})
if tc.expectedAccess {
require.NoError(t, err, "expected access to be granted")
+120 -55
@@ -22639,21 +22639,30 @@ INSERT INTO user_secrets (
name,
description,
value,
value_key_id,
env_name,
file_path
) VALUES (
$1, $2, $3, $4, $5, $6, $7
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8
) RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
`
type CreateUserSecretParams struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
}
func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretParams) (UserSecret, error) {
@@ -22663,6 +22672,7 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
arg.Name,
arg.Description,
arg.Value,
arg.ValueKeyID,
arg.EnvName,
arg.FilePath,
)
@@ -22682,41 +22692,24 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
return i, err
}
const deleteUserSecret = `-- name: DeleteUserSecret :exec
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :exec
DELETE FROM user_secrets
WHERE id = $1
WHERE user_id = $1 AND name = $2
`
func (q *sqlQuerier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteUserSecret, id)
type DeleteUserSecretByUserIDAndNameParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
}
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error {
_, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
return err
}
const getUserSecret = `-- name: GetUserSecret :one
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
WHERE id = $1
`
func (q *sqlQuerier) GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, getUserSecret, id)
var i UserSecret
err := row.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.Description,
&i.Value,
&i.EnvName,
&i.FilePath,
&i.CreatedAt,
&i.UpdatedAt,
&i.ValueKeyID,
)
return i, err
}
const getUserSecretByUserIDAndName = `-- name: GetUserSecretByUserIDAndName :one
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
FROM user_secrets
WHERE user_id = $1 AND name = $2
`
@@ -22744,17 +22737,76 @@ func (q *sqlQuerier) GetUserSecretByUserIDAndName(ctx context.Context, arg GetUs
}
const listUserSecrets = `-- name: ListUserSecrets :many
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
SELECT
id, user_id, name, description,
env_name, file_path,
created_at, updated_at
FROM user_secrets
WHERE user_id = $1
ORDER BY name ASC
`
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
type ListUserSecretsRow struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// Returns metadata only (no value or value_key_id) for the
// REST API list and get endpoints.
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error) {
rows, err := q.db.QueryContext(ctx, listUserSecrets, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListUserSecretsRow
for rows.Next() {
var i ListUserSecretsRow
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.Description,
&i.EnvName,
&i.FilePath,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listUserSecretsWithValues = `-- name: ListUserSecretsWithValues :many
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
FROM user_secrets
WHERE user_id = $1
ORDER BY name ASC
`
// Returns all columns including the secret value. Used by the
// provisioner (build-time injection) and the agent manifest
// (runtime injection).
func (q *sqlQuerier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
rows, err := q.db.QueryContext(ctx, listUserSecretsWithValues, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []UserSecret
for rows.Next() {
var i UserSecret
@@ -22783,33 +22835,46 @@ func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]U
return items, nil
}
const updateUserSecret = `-- name: UpdateUserSecret :one
const updateUserSecretByUserIDAndName = `-- name: UpdateUserSecretByUserIDAndName :one
UPDATE user_secrets
SET
description = $2,
value = $3,
env_name = $4,
file_path = $5,
updated_at = CURRENT_TIMESTAMP
WHERE id = $1
value = CASE WHEN $1::bool THEN $2 ELSE value END,
value_key_id = CASE WHEN $1::bool THEN $3 ELSE value_key_id END,
description = CASE WHEN $4::bool THEN $5 ELSE description END,
env_name = CASE WHEN $6::bool THEN $7 ELSE env_name END,
file_path = CASE WHEN $8::bool THEN $9 ELSE file_path END,
updated_at = CURRENT_TIMESTAMP
WHERE user_id = $10 AND name = $11
RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
`
type UpdateUserSecretParams struct {
ID uuid.UUID `db:"id" json:"id"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
type UpdateUserSecretByUserIDAndNameParams struct {
UpdateValue bool `db:"update_value" json:"update_value"`
Value string `db:"value" json:"value"`
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
UpdateDescription bool `db:"update_description" json:"update_description"`
Description string `db:"description" json:"description"`
UpdateEnvName bool `db:"update_env_name" json:"update_env_name"`
EnvName string `db:"env_name" json:"env_name"`
UpdateFilePath bool `db:"update_file_path" json:"update_file_path"`
FilePath string `db:"file_path" json:"file_path"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
}
func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, updateUserSecret,
arg.ID,
arg.Description,
func (q *sqlQuerier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, updateUserSecretByUserIDAndName,
arg.UpdateValue,
arg.Value,
arg.ValueKeyID,
arg.UpdateDescription,
arg.Description,
arg.UpdateEnvName,
arg.EnvName,
arg.UpdateFilePath,
arg.FilePath,
arg.UserID,
arg.Name,
)
var i UserSecret
err := row.Scan(
+39 -18
@@ -1,14 +1,26 @@
-- name: GetUserSecretByUserIDAndName :one
SELECT * FROM user_secrets
WHERE user_id = $1 AND name = $2;
-- name: GetUserSecret :one
SELECT * FROM user_secrets
WHERE id = $1;
SELECT *
FROM user_secrets
WHERE user_id = @user_id AND name = @name;
-- name: ListUserSecrets :many
SELECT * FROM user_secrets
WHERE user_id = $1
-- Returns metadata only (no value or value_key_id) for the
-- REST API list and get endpoints.
SELECT
id, user_id, name, description,
env_name, file_path,
created_at, updated_at
FROM user_secrets
WHERE user_id = @user_id
ORDER BY name ASC;
-- name: ListUserSecretsWithValues :many
-- Returns all columns including the secret value. Used by the
-- provisioner (build-time injection) and the agent manifest
-- (runtime injection).
SELECT *
FROM user_secrets
WHERE user_id = @user_id
ORDER BY name ASC;
-- name: CreateUserSecret :one
@@ -18,23 +30,32 @@ INSERT INTO user_secrets (
name,
description,
value,
value_key_id,
env_name,
file_path
) VALUES (
$1, $2, $3, $4, $5, $6, $7
@id,
@user_id,
@name,
@description,
@value,
@value_key_id,
@env_name,
@file_path
) RETURNING *;
-- name: UpdateUserSecret :one
-- name: UpdateUserSecretByUserIDAndName :one
UPDATE user_secrets
SET
description = $2,
value = $3,
env_name = $4,
file_path = $5,
updated_at = CURRENT_TIMESTAMP
WHERE id = $1
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
updated_at = CURRENT_TIMESTAMP
WHERE user_id = @user_id AND name = @name
RETURNING *;
-- name: DeleteUserSecret :exec
-- name: DeleteUserSecretByUserIDAndName :exec
DELETE FROM user_secrets
WHERE id = $1;
WHERE user_id = @user_id AND name = @name;
+37 -19
@@ -19,6 +19,7 @@ import (
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"
@@ -389,6 +390,7 @@ type MultiAgentController struct {
// connections to the destination
tickets map[uuid.UUID]map[uuid.UUID]struct{}
coordination *tailnet.BasicCoordination
sendGroup singleflight.Group
cancel context.CancelFunc
expireOldAgentsDone chan struct{}
@@ -418,28 +420,44 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.connectionTimes[agentID]
// If we don't have the agent, subscribe.
if !ok {
m.logger.Debug(context.Background(),
"subscribing to agent", slog.F("agent_id", agentID))
if m.coordination != nil {
err := m.coordination.Client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
err = xerrors.Errorf("subscribe agent: %w", err)
m.coordination.SendErr(err)
_ = m.coordination.Client.Close()
m.coordination = nil
return err
}
}
m.tickets[agentID] = map[uuid.UUID]struct{}{}
if ok {
m.connectionTimes[agentID] = time.Now()
m.mu.Unlock()
return nil
}
m.mu.Unlock()
m.logger.Debug(context.Background(),
"subscribing to agent", slog.F("agent_id", agentID))
_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
m.mu.Lock()
coord := m.coordination
m.mu.Unlock()
if coord == nil {
return nil, xerrors.New("no active coordination")
}
err := coord.Client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
return nil, err
}
m.mu.Lock()
m.tickets[agentID] = map[uuid.UUID]struct{}{}
m.mu.Unlock()
return nil, nil
})
if err != nil {
m.logger.Error(context.Background(), "ensureAgent send failed",
slog.F("agent_id", agentID), slog.Error(err))
return xerrors.Errorf("send AddTunnel: %w", err)
}
m.mu.Lock()
m.connectionTimes[agentID] = time.Now()
m.mu.Unlock()
return nil
}
+13 -8
@@ -730,7 +730,10 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
if !ok {
return
}
log := s.Logger.With(slog.F("agent_id", appToken.AgentID))
log := s.Logger.With(
slog.F("agent_id", appToken.AgentID),
slog.F("workspace_id", appToken.WorkspaceID),
)
log.Debug(ctx, "resolved PTY request")
values := r.URL.Query()
@@ -765,19 +768,21 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
})
return
}
go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn)
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
go httpapi.HeartbeatClose(ctx, log, cancel, conn)
dialStart := time.Now()
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
if err != nil {
log.Debug(ctx, "dial workspace agent", slog.Error(err))
log.Debug(ctx, "dial workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
}
defer release()
log.Debug(ctx, "dialed workspace agent")
log.Debug(ctx, "dialed workspace agent", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
// #nosec G115 - Safe conversion for terminal height/width which are expected to be within uint16 range (0-65535)
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
arp.Container = container
@@ -785,12 +790,12 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
arp.BackendType = backendType
})
if err != nil {
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return
}
defer ptNetConn.Close()
log.Debug(ctx, "obtained PTY")
log.Debug(ctx, "obtained PTY", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
report := newStatsReportFromSignedToken(*appToken)
s.collectStats(report)
@@ -800,7 +805,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
}()
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
log.Debug(ctx, "pty Bicopy finished")
log.Debug(ctx, "pty Bicopy finished", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
}
func (s *Server) collectStats(stats StatsReport) {
@@ -12,6 +12,7 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
gProto "google.golang.org/protobuf/proto"
@@ -33,9 +34,9 @@ const (
eventReadyForHandshake = "tailnet_ready_for_handshake"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
numQuerierWorkers = 10
numQuerierWorkers = 40
numBinderWorkers = 10
numTunnelerWorkers = 10
numTunnelerWorkers = 20
numHandshakerWorkers = 5
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
@@ -770,6 +771,9 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
for k := range m.sent {
if _, ok := best[k]; !ok {
m.logger.Debug(m.ctx, "peer no longer in best mappings, sending DISCONNECTED",
slog.F("peer_id", k),
)
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
Id: agpl.UUIDToByteSlice(k),
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
@@ -820,6 +824,8 @@ type querier struct {
mu sync.Mutex
mappers map[mKey]*mapper
healthy bool
resyncGroup singleflight.Group
}
func newQuerier(ctx context.Context,
@@ -958,7 +964,7 @@ func (q *querier) cleanupConn(c *connIO) {
// maxBatchSize is the maximum number of keys to process in a single batch
// query.
const maxBatchSize = 50
const maxBatchSize = 200
func (q *querier) peerUpdateWorker() {
defer q.wg.Done()
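The hunk above raises `maxBatchSize` from 50 to 200 so each query covers more changed keys per round trip. The batching idea itself can be sketched as a stdlib-only helper (`batchKeys` is a hypothetical name for illustration, not the repo's actual function):

```go
package main

import "fmt"

const maxBatchSize = 200

// batchKeys splits keys into slices of at most maxBatchSize so each
// database query stays bounded no matter how many peers changed at once.
func batchKeys(keys []int) [][]int {
	var batches [][]int
	for len(keys) > 0 {
		n := len(keys)
		if n > maxBatchSize {
			n = maxBatchSize
		}
		batches = append(batches, keys[:n])
		keys = keys[n:]
	}
	return batches
}

func main() {
	// 450 pending keys drain in three queries: 200, 200, 50.
	for _, b := range batchKeys(make([]int, 450)) {
		fmt.Println(len(b))
	}
}
```

A larger batch size amortizes per-query overhead during update storms at the cost of slightly larger individual queries.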
@@ -1207,8 +1213,13 @@ func (q *querier) subscribe() {
func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
q.logger.Warn(q.ctx, "pubsub may have dropped peer updates")
// we need to schedule a full resync of peer mappings
q.resyncPeerMappings()
// Schedule a full resync asynchronously so we don't block the
// pubsub drain goroutine. Singleflight coalesces concurrent
// resync requests.
go q.resyncGroup.Do("resync", func() (any, error) {
q.resyncPeerMappings()
return nil, nil
})
return
}
if err != nil {
@@ -1234,8 +1245,13 @@ func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
q.logger.Warn(q.ctx, "pubsub may have dropped tunnel updates")
// we need to schedule a full resync of peer mappings
q.resyncPeerMappings()
// Schedule a full resync asynchronously so we don't block the
// pubsub drain goroutine. Singleflight coalesces concurrent
// resync requests.
go q.resyncGroup.Do("resync", func() (any, error) {
q.resyncPeerMappings()
return nil, nil
})
return
}
if err != nil {
@@ -1601,6 +1617,10 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
// the only mapping available for it. Newer mappings will take
// precedence.
m.kind = proto.CoordinateResponse_PeerUpdate_LOST
h.logger.Debug(h.ctx, "mapping rewritten to LOST due to missed heartbeats",
slog.F("peer_id", m.peer),
slog.F("coordinator_id", m.coordinator),
)
}
}
@@ -76,11 +76,11 @@ replace github.com/aquasecurity/trivy => github.com/coder/trivy v0.0.0-202603091
// https://github.com/spf13/afero/pull/487
replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696
// Forked from kylecarbs/fantasy (cj/go1.25 branch) which adds:
// Forked from coder/fantasy (cj/go1.25 branch) which adds:
// 1) Anthropic computer use + thinking effort
// 2) Go 1.25 downgrade for Windows CI compat
// 3) ibetitsmike/fantasy#4 — skip ephemeral replay items when store=false
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8
replace charm.land/fantasy => github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
@@ -322,6 +322,8 @@ github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwu
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4=
github.com/coder/clistat v1.2.1 h1:P9/10njXMyj5cWzIU5wkRsSy5LVQH49+tcGMsAgWX0w=
github.com/coder/clistat v1.2.1/go.mod h1:m7SC0uj88eEERgvF8Kn6+w6XF21BeSr+15f7GoLAw0A=
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:n+6v+yT1B6V4oSGPmXFh7mul1E+RzG9rnqp50Vb7M/w=
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
github.com/coder/flog v1.1.0 h1:kbAes1ai8fIS5OeV+QAnKBQE22ty1jRF/mcAwHpLBa4=
github.com/coder/flog v1.1.0/go.mod h1:UQlQvrkJBvnRGo69Le8E24Tcl5SJleAAR7gYEHzAmdQ=
github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322 h1:m0lPZjlQ7vdVpRBPKfYIFlmgevoTkBxB10wv6l2gOaU=
@@ -813,8 +815,6 @@ github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:5UMY
github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab/go.mod h1:hqlYqR7uPKOKfnNeicUbZp0Ps0GeYFlKYtwh5HGDCx8=
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3 h1:Z9/bo5PSeMutpdiKYNt/TTSfGM1Ll0naj3QzYX9VxTc=
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3/go.mod h1:BUGjjsD+ndS6eX37YgTchSEG+Jg9Jv1GiZs9sqPqztk=
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:fZ0208U3B438fDSHCc/GNioPIyaFqn6eBsQTO61QtrI=
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae h1:xlFZNX4nnxpj9Cf6mTwD3pirXGNtBJ/6COsf9iZmsL0=
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
@@ -68,17 +68,17 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
return xerrors.Errorf("detecting branch: %w", err)
}
// Match standard release branches (release/2.32) and RC
// branches (release/2.32-rc.0).
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)(?:-rc\.(\d+))?$`)
// Match release branches (release/X.Y). RCs are tagged
// from main, not from release branches.
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)$`)
m := branchRe.FindStringSubmatch(currentBranch)
if m == nil {
warnf(w, "Current branch %q is not a release branch (release/X.Y or release/X.Y-rc.N).", currentBranch)
warnf(w, "Current branch %q is not a release branch (release/X.Y).", currentBranch)
branchInput, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Enter the release branch to use (e.g. release/2.21 or release/2.21-rc.0)",
Text: "Enter the release branch to use (e.g. release/2.21)",
Validate: func(s string) error {
if !branchRe.MatchString(s) {
return xerrors.New("must be in format release/X.Y or release/X.Y-rc.N (e.g. release/2.21 or release/2.21-rc.0)")
return xerrors.New("must be in format release/X.Y (e.g. release/2.21)")
}
return nil
},
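The tightened pattern drops the optional `-rc.N` suffix, so RC branch names no longer validate. A quick standalone check of what `^release/(\d+)\.(\d+)$` accepts:

```go
package main

import (
	"fmt"
	"regexp"
)

var branchRe = regexp.MustCompile(`^release/(\d+)\.(\d+)$`)

// isReleaseBranch reports whether a branch name matches the tightened
// release/X.Y pattern; RC suffixes like -rc.0 are now rejected.
func isReleaseBranch(s string) bool {
	return branchRe.MatchString(s)
}

func main() {
	for _, b := range []string{"release/2.32", "release/2.32-rc.0", "main"} {
		fmt.Printf("%s -> %v\n", b, isReleaseBranch(b))
	}
	// release/2.32 -> true
	// release/2.32-rc.0 -> false
	// main -> false
}
```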
@@ -91,10 +91,6 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
branchMajor, _ := strconv.Atoi(m[1])
branchMinor, _ := strconv.Atoi(m[2])
branchRC := -1 // -1 means not an RC branch.
if m[3] != "" {
branchRC, _ = strconv.Atoi(m[3])
}
successf(w, "Using release branch: %s", currentBranch)
// --- Fetch & sync check ---
@@ -138,31 +134,44 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
}
// changelogBaseRef is the git ref used as the starting point
// for release notes generation. When a tag already exists in
// this minor series we use it directly. For the first release
// on a new minor no matching tag exists, so we compute the
// merge-base with the previous minor's release branch instead.
// This works even when that branch has no tags yet (it was
// just cut and pushed). As a last resort we fall back to the
// latest reachable tag from a previous minor.
var changelogBaseRef string
if prevVersion != nil {
changelogBaseRef = prevVersion.String()
} else {
prevReleaseBranch := fmt.Sprintf("release/%d.%d", branchMajor, branchMinor-1)
if err := gitRun("fetch", "--quiet", "origin", prevReleaseBranch); err != nil {
warnf(w, "Could not fetch %s: %v", prevReleaseBranch, err)
}
if mb, mbErr := gitOutput("merge-base", "HEAD", "origin/"+prevReleaseBranch); mbErr == nil && mb != "" {
changelogBaseRef = mb
infof(w, "Using merge-base with %s as changelog base: %s", prevReleaseBranch, mb[:12])
} else {
// No previous release branch found; fall back to
// the latest reachable tag from a previous minor.
for _, t := range mergedTags {
if t.Major == branchMajor && t.Minor < branchMinor {
changelogBaseRef = t.String()
break
}
}
}
}
var suggested version
if prevVersion == nil {
infof(w, "No previous release tag found on this branch.")
suggested = version{Major: branchMajor, Minor: branchMinor, Patch: 0}
if branchRC >= 0 {
suggested.Pre = fmt.Sprintf("rc.%d", branchRC)
}
} else {
infof(w, "Previous release tag: %s", prevVersion.String())
if branchRC >= 0 {
// On an RC branch, suggest the next RC for
// the same base version.
nextRC := 0
if prevVersion.IsRC() {
nextRC = prevVersion.rcNumber() + 1
}
suggested = version{
Major: prevVersion.Major,
Minor: prevVersion.Minor,
Patch: prevVersion.Patch,
Pre: fmt.Sprintf("rc.%d", nextRC),
}
} else {
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
}
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
}
fmt.Fprintln(w)
@@ -366,8 +375,8 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
infof(w, "Generating release notes...")
commitRange := "HEAD"
if prevVersion != nil {
commitRange = prevVersion.String() + "..HEAD"
if changelogBaseRef != "" {
commitRange = changelogBaseRef + "..HEAD"
}
commits, err := commitLog(commitRange)
@@ -473,16 +482,16 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
if !hasContent {
prevStr := "the beginning of time"
if prevVersion != nil {
prevStr = prevVersion.String()
if changelogBaseRef != "" {
prevStr = changelogBaseRef
}
fmt.Fprintf(&notes, "\n_No changes since %s._\n", prevStr)
}
// Compare link.
if prevVersion != nil {
if changelogBaseRef != "" {
fmt.Fprintf(&notes, "\nCompare: [`%s...%s`](https://github.com/%s/%s/compare/%s...%s)\n",
prevVersion, newVersion, owner, repo, prevVersion, newVersion)
changelogBaseRef, newVersion, owner, repo, changelogBaseRef, newVersion)
}
// Container image.
@@ -1,7 +1,6 @@
import type { Interpolation, Theme } from "@emotion/react";
import type { FC, HTMLAttributes } from "react";
import type { LogLevel } from "#/api/typesGenerated";
import { MONOSPACE_FONT_FAMILY } from "#/theme/constants";
import { cn } from "#/utils/cn";
const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
@@ -17,65 +16,40 @@ type LogLineProps = {
level: LogLevel;
} & HTMLAttributes<HTMLPreElement>;
export const LogLine: FC<LogLineProps> = ({ level, ...divProps }) => {
export const LogLine: FC<LogLineProps> = ({ level, className, ...props }) => {
return (
<pre
css={styles.line}
className={`${level} ${divProps.className} logs-line`}
{...divProps}
{...props}
className={cn(
"logs-line",
"m-0 break-all flex items-center h-auto",
"text-[13px] text-content-primary font-mono",
level === "error" &&
"bg-surface-error text-content-error [&_.dashed-line]:bg-border-error",
level === "debug" &&
"bg-surface-sky text-content-sky [&_.dashed-line]:bg-border-sky",
level === "warn" &&
"bg-surface-warning text-content-warning [&_.dashed-line]:bg-border-warning",
className,
)}
style={{
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
}}
/>
);
};
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = (props) => {
return <pre css={styles.prefix} {...props} />;
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = ({
className,
...props
}) => {
return (
<pre
className={cn(
"select-none m-0 inline-block text-content-secondary mr-6",
className,
)}
{...props}
/>
);
};
const styles = {
line: (theme) => ({
margin: 0,
wordBreak: "break-all",
display: "flex",
alignItems: "center",
fontSize: 13,
color: theme.palette.text.primary,
fontFamily: MONOSPACE_FONT_FAMILY,
height: "auto",
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
"&.error": {
backgroundColor: theme.roles.error.background,
color: theme.roles.error.text,
"& .dashed-line": {
backgroundColor: theme.roles.error.outline,
},
},
"&.debug": {
backgroundColor: theme.roles.notice.background,
color: theme.roles.notice.text,
"& .dashed-line": {
backgroundColor: theme.roles.notice.outline,
},
},
"&.warn": {
backgroundColor: theme.roles.warning.background,
color: theme.roles.warning.text,
"& .dashed-line": {
backgroundColor: theme.roles.warning.outline,
},
},
}),
prefix: (theme) => ({
userSelect: "none",
margin: 0,
display: "inline-block",
color: theme.palette.text.secondary,
marginRight: 24,
}),
} satisfies Record<string, Interpolation<Theme>>;
@@ -1,6 +1,6 @@
import type { Interpolation, Theme } from "@emotion/react";
import dayjs from "dayjs";
import type { FC } from "react";
import { cn } from "#/utils/cn";
import { type Line, LogLine, LogLinePrefix } from "./LogLine";
export const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
@@ -17,7 +17,17 @@ export const Logs: FC<LogsProps> = ({
className = "",
}) => {
return (
<div css={styles.root} className={`${className} logs-container`}>
<div
className={cn(
"logs-container",
"min-h-40 py-2 rounded-lg overflow-x-auto bg-surface-primary",
"[&:not(:last-child)]:border-0",
"[&:not(:last-child)]:border-solid",
"[&:not(:last-child)]:border-b-border",
"[&:not(:last-child)]:rounded-none",
className,
)}
>
<div className="min-w-fit">
{lines.map((line) => (
<LogLine key={line.id} level={line.level}>
@@ -33,18 +43,3 @@ export const Logs: FC<LogsProps> = ({
</div>
);
};
const styles = {
root: (theme) => ({
minHeight: 156,
padding: "8px 0",
borderRadius: 8,
overflowX: "auto",
background: theme.palette.background.default,
"&:not(:last-child)": {
borderBottom: `1px solid ${theme.palette.divider}`,
borderRadius: 0,
},
}),
} satisfies Record<string, Interpolation<Theme>>;
@@ -1 +0,0 @@
export { useTabOverflowKebabMenu } from "./useTabOverflowKebabMenu";
@@ -0,0 +1,274 @@
import {
type RefObject,
useCallback,
useEffect,
useRef,
useState,
} from "react";
type TabValue = {
value: string;
};
type UseKebabMenuOptions<T extends TabValue> = {
tabs: readonly T[];
enabled: boolean;
isActive: boolean;
overflowTriggerWidth?: number;
};
type UseKebabMenuResult<T extends TabValue> = {
containerRef: RefObject<HTMLDivElement | null>;
visibleTabs: T[];
overflowTabs: T[];
getTabMeasureProps: (tabValue: string) => Record<string, string>;
};
const ALWAYS_VISIBLE_TABS_COUNT = 1;
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
/**
* Splits tabs into visible and overflow groups based on container width.
*
* Tabs must render with `getTabMeasureProps()` so this hook can measure
* trigger widths from the DOM.
*/
export const useKebabMenu = <T extends TabValue>({
tabs,
enabled,
isActive,
overflowTriggerWidth = 44,
}: UseKebabMenuOptions<T>): UseKebabMenuResult<T> => {
const containerRef = useRef<HTMLDivElement>(null);
const tabsRef = useRef<readonly T[]>(tabs);
tabsRef.current = tabs;
const previousTabsRef = useRef<readonly T[]>(tabs);
const availableWidthRef = useRef<number | null>(null);
// Width cache prevents oscillation when overflow tabs are not mounted.
const tabWidthByValueRef = useRef<Record<string, number>>({});
const [overflowTabValues, setTabValues] = useState<string[]>([]);
const recalculateOverflow = useCallback(
(availableWidth: number) => {
if (!enabled || !isActive) {
// Keep this update idempotent to avoid render loops.
setTabValues((currentValues) => {
if (currentValues.length === 0) {
return currentValues;
}
return [];
});
return;
}
const container = containerRef.current;
if (!container) {
return;
}
const currentTabs = tabsRef.current;
const tabWidthByValue = measureTabWidths({
tabs: currentTabs,
container,
previousTabWidthByValue: tabWidthByValueRef.current,
});
tabWidthByValueRef.current = tabWidthByValue;
const nextOverflowValues = calculateTabValues({
tabs: currentTabs,
availableWidth,
tabWidthByValue,
overflowTriggerWidth,
});
setTabValues((currentValues) => {
// Avoid state updates when the computed overflow did not change.
if (areStringArraysEqual(currentValues, nextOverflowValues)) {
return currentValues;
}
return nextOverflowValues;
});
},
[enabled, isActive, overflowTriggerWidth],
);
useEffect(() => {
if (previousTabsRef.current === tabs) {
// No change in tabs, no need to recalculate.
return;
}
previousTabsRef.current = tabs;
if (availableWidthRef.current === null) {
// First mount, no width available yet.
return;
}
recalculateOverflow(availableWidthRef.current);
}, [recalculateOverflow, tabs]);
useEffect(() => {
const container = containerRef.current;
if (!container) {
return;
}
// Recompute whenever ResizeObserver reports a container width change.
const observer = new ResizeObserver(([entry]) => {
if (!entry) {
return;
}
availableWidthRef.current = entry.contentRect.width;
recalculateOverflow(entry.contentRect.width);
});
observer.observe(container);
return () => observer.disconnect();
}, [recalculateOverflow]);
const overflowTabValuesSet = new Set(overflowTabValues);
const { visibleTabs, overflowTabs } = tabs.reduce<{
visibleTabs: T[];
overflowTabs: T[];
}>(
(tabGroups, tab) => {
if (overflowTabValuesSet.has(tab.value)) {
tabGroups.overflowTabs.push(tab);
} else {
tabGroups.visibleTabs.push(tab);
}
return tabGroups;
},
{ visibleTabs: [], overflowTabs: [] },
);
const getTabMeasureProps = (tabValue: string) => {
return { [DATA_ATTR_TAB_VALUE]: tabValue };
};
return {
containerRef,
visibleTabs,
overflowTabs,
getTabMeasureProps,
};
};
const calculateTabValues = <T extends TabValue>({
tabs,
availableWidth,
tabWidthByValue,
overflowTriggerWidth,
}: {
tabs: readonly T[];
availableWidth: number;
tabWidthByValue: Readonly<Record<string, number>>;
overflowTriggerWidth: number;
}): string[] => {
const tabWidthByValueMap = new Map<string, number>();
for (const tab of tabs) {
tabWidthByValueMap.set(tab.value, tabWidthByValue[tab.value] ?? 0);
}
const firstOptionalTabIndex = Math.min(
ALWAYS_VISIBLE_TABS_COUNT,
tabs.length,
);
if (firstOptionalTabIndex >= tabs.length) {
return [];
}
const alwaysVisibleTabs = tabs.slice(0, firstOptionalTabIndex);
const optionalTabs = tabs.slice(firstOptionalTabIndex);
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
return total + (tabWidthByValueMap.get(tab.value) ?? 0);
}, 0);
const firstTabIndex = findFirstTabIndex({
optionalTabs,
optionalTabWidths: optionalTabs.map((tab) => {
return tabWidthByValueMap.get(tab.value) ?? 0;
}),
startingUsedWidth: alwaysVisibleWidth,
availableWidth,
overflowTriggerWidth,
});
if (firstTabIndex === -1) {
return [];
}
return optionalTabs
.slice(firstTabIndex)
.map((overflowTab) => overflowTab.value);
};
const measureTabWidths = <T extends TabValue>({
tabs,
container,
previousTabWidthByValue,
}: {
tabs: readonly T[];
container: HTMLDivElement;
previousTabWidthByValue: Readonly<Record<string, number>>;
}): Record<string, number> => {
const nextTabWidthByValue = { ...previousTabWidthByValue };
for (const tab of tabs) {
const tabElement = container.querySelector<HTMLElement>(
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
);
if (tabElement) {
nextTabWidthByValue[tab.value] = tabElement.offsetWidth;
}
}
return nextTabWidthByValue;
};
const findFirstTabIndex = ({
optionalTabs,
optionalTabWidths,
startingUsedWidth,
availableWidth,
overflowTriggerWidth,
}: {
optionalTabs: readonly TabValue[];
optionalTabWidths: readonly number[];
startingUsedWidth: number;
availableWidth: number;
overflowTriggerWidth: number;
}): number => {
const result = optionalTabs.reduce(
(acc, _tab, index) => {
if (acc.firstTabIndex !== -1) {
return acc;
}
const tabWidth = optionalTabWidths[index] ?? 0;
const hasMoreTabs = index < optionalTabs.length - 1;
// Reserve kebab trigger width whenever additional tabs remain.
const widthNeeded =
acc.usedWidth + tabWidth + (hasMoreTabs ? overflowTriggerWidth : 0);
if (widthNeeded <= availableWidth) {
return {
usedWidth: acc.usedWidth + tabWidth,
firstTabIndex: -1,
};
}
return {
usedWidth: acc.usedWidth,
firstTabIndex: index,
};
},
{ usedWidth: startingUsedWidth, firstTabIndex: -1 },
);
return result.firstTabIndex;
};
const areStringArraysEqual = (
left: readonly string[],
right: readonly string[],
): boolean => {
return (
left.length === right.length &&
left.every((value, index) => value === right[index])
);
};
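The hook above fills tabs greedily, reserving the kebab trigger's width whenever more tabs remain after the current one. The same rule, sketched as a self-contained Go function for illustration (`firstOverflowIndex` is a hypothetical language-agnostic port of `findFirstTabIndex`, not code from this change):

```go
package main

import "fmt"

// firstOverflowIndex returns the index of the first tab that no longer
// fits in availableWidth, reserving room for the kebab trigger whenever
// more tabs remain after the current one; -1 means everything fits.
func firstOverflowIndex(widths []int, availableWidth, triggerWidth int) int {
	used := 0
	for i, w := range widths {
		reserve := 0
		if i < len(widths)-1 {
			reserve = triggerWidth
		}
		if used+w+reserve > availableWidth {
			return i
		}
		used += w
	}
	return -1
}

func main() {
	// Four 80px tabs, 250px container, 44px kebab trigger:
	// the first two fit (160px + 44px reserved), the third overflows.
	fmt.Println(firstOverflowIndex([]int{80, 80, 80, 80}, 250, 44))
}
```

Reserving the trigger width only when tabs remain is what prevents oscillation: the last tab is allowed to use the space the kebab would otherwise occupy.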
@@ -1,157 +0,0 @@
import {
type RefObject,
useCallback,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
type TabLike = {
value: string;
};
type UseTabOverflowKebabMenuOptions<TTab extends TabLike> = {
tabs: readonly TTab[];
enabled: boolean;
isActive: boolean;
alwaysVisibleTabsCount?: number;
overflowTriggerWidthPx?: number;
};
type UseTabOverflowKebabMenuResult<TTab extends TabLike> = {
containerRef: RefObject<HTMLDivElement | null>;
visibleTabs: TTab[];
overflowTabs: TTab[];
getTabMeasureProps: (tabValue: string) => Record<string, string>;
};
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
export const useTabOverflowKebabMenu = <TTab extends TabLike>({
tabs,
enabled,
isActive,
alwaysVisibleTabsCount = 1,
overflowTriggerWidthPx = 44,
}: UseTabOverflowKebabMenuOptions<TTab>): UseTabOverflowKebabMenuResult<TTab> => {
const containerRef = useRef<HTMLDivElement>(null);
const tabWidthByValueRef = useRef<Record<string, number>>({});
const [overflowTabValues, setOverflowTabValues] = useState<string[]>([]);
const recalculateOverflow = useCallback(() => {
if (!enabled) {
setOverflowTabValues([]);
return;
}
const container = containerRef.current;
if (!container) {
return;
}
for (const tab of tabs) {
const tabElement = container.querySelector<HTMLElement>(
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
);
if (tabElement) {
tabWidthByValueRef.current[tab.value] = tabElement.offsetWidth;
}
}
const alwaysVisibleTabs = tabs.slice(0, alwaysVisibleTabsCount);
const optionalTabs = tabs.slice(alwaysVisibleTabsCount);
if (optionalTabs.length === 0) {
setOverflowTabValues([]);
return;
}
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
return total + (tabWidthByValueRef.current[tab.value] ?? 0);
}, 0);
const availableWidth = container.clientWidth;
let usedWidth = alwaysVisibleWidth;
const nextOverflowValues: string[] = [];
for (let i = 0; i < optionalTabs.length; i++) {
const tab = optionalTabs[i];
const tabWidth = tabWidthByValueRef.current[tab.value] ?? 0;
const hasMoreTabsAfterCurrent = i < optionalTabs.length - 1;
const widthNeeded =
usedWidth +
tabWidth +
(hasMoreTabsAfterCurrent ? overflowTriggerWidthPx : 0);
if (widthNeeded <= availableWidth) {
usedWidth += tabWidth;
continue;
}
nextOverflowValues.push(
...optionalTabs.slice(i).map((overflowTab) => overflowTab.value),
);
break;
}
setOverflowTabValues((currentValues) => {
if (
currentValues.length === nextOverflowValues.length &&
currentValues.every(
(value, index) => value === nextOverflowValues[index],
)
) {
return currentValues;
}
return nextOverflowValues;
});
}, [alwaysVisibleTabsCount, enabled, overflowTriggerWidthPx, tabs]);
useLayoutEffect(() => {
if (!isActive) {
return;
}
recalculateOverflow();
}, [isActive, recalculateOverflow]);
useEffect(() => {
if (!isActive) {
return;
}
const container = containerRef.current;
if (!container) {
return;
}
const observer = new ResizeObserver(() => {
recalculateOverflow();
});
observer.observe(container);
return () => observer.disconnect();
}, [isActive, recalculateOverflow]);
const overflowTabValuesSet = useMemo(
() => new Set(overflowTabValues),
[overflowTabValues],
);
const visibleTabs = useMemo(
() => tabs.filter((tab) => !overflowTabValuesSet.has(tab.value)),
[tabs, overflowTabValuesSet],
);
const overflowTabs = useMemo(
() => tabs.filter((tab) => overflowTabValuesSet.has(tab.value)),
[tabs, overflowTabValuesSet],
);
const getTabMeasureProps = useCallback((tabValue: string) => {
return { [DATA_ATTR_TAB_VALUE]: tabValue };
}, []);
return {
containerRef,
visibleTabs,
overflowTabs,
getTabMeasureProps,
};
};
@@ -2,7 +2,7 @@ import type { Meta, StoryObj } from "@storybook/react-vite";
import { expect, spyOn, userEvent, waitFor, within } from "storybook/test";
import { API } from "#/api/api";
import { workspaceAgentContainersKey } from "#/api/queries/workspaces";
import type * as TypesGen from "#/api/typesGenerated";
import type { WorkspaceAgentLogSource } from "#/api/typesGenerated";
import { getPreferredProxy } from "#/contexts/ProxyContext";
import { chromatic } from "#/testHelpers/chromatic";
import * as M from "#/testHelpers/entities";
@@ -92,7 +92,7 @@ const logs = [
created_at: fixedLogTimestamp,
}));
const installScriptLogSource: TypesGen.WorkspaceAgentLogSource = {
const installScriptLogSource: WorkspaceAgentLogSource = {
...M.MockWorkspaceAgentLogSource,
id: "f2ee4b8d-b09d-4f4e-a1f1-5e4adf7d53bb",
display_name: "Install Script",
@@ -122,42 +122,6 @@ const tabbedLogs = [
},
];
const overflowLogSources: TypesGen.WorkspaceAgentLogSource[] = [
M.MockWorkspaceAgentLogSource,
{
...M.MockWorkspaceAgentLogSource,
id: "58f5db69-5f78-496f-bce1-0686f5525aa1",
display_name: "code-server",
icon: "/icon/code.svg",
},
{
...M.MockWorkspaceAgentLogSource,
id: "f39d758c-bce2-4f41-8d70-58fdb1f0f729",
display_name: "Install and start AgentAPI",
icon: "/icon/claude.svg",
},
{
...M.MockWorkspaceAgentLogSource,
id: "bf7529b8-1787-4a20-b54f-eb894680e48f",
display_name: "Mux",
icon: "/icon/mux.svg",
},
{
...M.MockWorkspaceAgentLogSource,
id: "0d6ebde6-c534-4551-9f91-bfd98bfb04f4",
display_name: "Portable Desktop",
icon: "/icon/portable-desktop.svg",
},
];
const overflowLogs = overflowLogSources.map((source, index) => ({
id: 200 + index,
level: "info",
output: `${source.display_name}: line`,
source_id: source.id,
created_at: fixedLogTimestamp,
}));
const meta: Meta<typeof AgentRow> = {
title: "components/AgentRow",
component: AgentRow,
@@ -440,44 +404,3 @@ export const LogsTabs: Story = {
await expect(canvas.getByText("install: pnpm install")).toBeVisible();
},
};
export const LogsTabsOverflow: Story = {
args: {
agent: {
...M.MockWorkspaceAgentReady,
logs_length: overflowLogs.length,
log_sources: overflowLogSources,
},
},
parameters: {
webSocket: [
{
event: "message",
data: JSON.stringify(overflowLogs),
},
],
},
render: (args) => (
<div className="max-w-[320px]">
<AgentRow {...args} />
</div>
),
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const page = within(canvasElement.ownerDocument.body);
await userEvent.click(canvas.getByRole("button", { name: "Logs" }));
await userEvent.click(
canvas.getByRole("button", { name: "More log tabs" }),
);
const overflowItems = await page.findAllByRole("menuitemradio");
const selectedItem = overflowItems[0];
const selectedSource = selectedItem.textContent;
if (!selectedSource) {
throw new Error("Overflow menu item must have text content.");
}
await userEvent.click(selectedItem);
await waitFor(() =>
expect(canvas.getByText(`${selectedSource}: line`)).toBeVisible(),
);
},
};
@@ -8,10 +8,9 @@ import {
} from "lucide-react";
import {
type FC,
useCallback,
type ReactNode,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
@@ -42,7 +41,7 @@ import {
TabsList,
TabsTrigger,
} from "#/components/Tabs/Tabs";
import { useTabOverflowKebabMenu } from "#/components/Tabs/utils";
import { useKebabMenu } from "#/components/Tabs/utils/useKebabMenu";
import { useProxy } from "#/contexts/ProxyContext";
import { useClipboard } from "#/hooks/useClipboard";
import { useFeatureVisibility } from "#/modules/dashboard/useFeatureVisibility";
@@ -162,7 +161,7 @@ export const AgentRow: FC<AgentRowProps> = ({
// This is a bit of a hack on the react-window API to get the scroll position.
// If we're scrolled to the bottom, we want to keep the list scrolled to the bottom.
// This makes it feel similar to a terminal that auto-scrolls downwards!
const handleLogScroll = useCallback((props: ListOnScrollProps) => {
const handleLogScroll = (props: ListOnScrollProps) => {
if (
props.scrollOffset === 0 ||
props.scrollUpdateWasRequested ||
@@ -179,7 +178,7 @@ export const AgentRow: FC<AgentRowProps> = ({
logListDivRef.current.scrollHeight -
(props.scrollOffset + parent.clientHeight);
setBottomOfLogs(distanceFromBottom < AGENT_LOG_LINE_HEIGHT);
}, []);
};
const devcontainers = useAgentContainers(agent);
@@ -211,59 +210,56 @@ export const AgentRow: FC<AgentRowProps> = ({
);
const [selectedLogTab, setSelectedLogTab] = useState("all");
const logTabs = useMemo(() => {
const sourceLogTabs = agent.log_sources
.filter((logSource) => {
// Remove the logSources that have no entries.
return agentLogs.some(
(log) =>
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
);
})
.map((logSource) => ({
// Show the icon for the log source if it has one.
// In the startup script case, we show a bespoke play icon.
startIcon: logSource.icon ? (
<ExternalImage
src={logSource.icon}
alt=""
className="size-icon-xs shrink-0"
/>
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
<PlayIcon className="size-icon-xs shrink-0" />
) : null,
title: logSource.display_name,
value: logSource.id,
}));
const startupScriptLogTab = sourceLogTabs.find(
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
);
const sortedSourceLogTabs = sourceLogTabs
.filter((tab) => tab !== startupScriptLogTab)
.sort((a, b) => a.title.localeCompare(b.title));
return [
{
title: "All Logs",
value: "all",
},
...(startupScriptLogTab ? [startupScriptLogTab] : []),
...sortedSourceLogTabs,
] as {
startIcon?: React.ReactNode;
title: string;
value: string;
}[];
}, [agent.log_sources, agentLogs]);
const sourceLogTabs = agent.log_sources
.filter((logSource) => {
// Remove the logSources that have no entries.
return agentLogs.some(
(log) =>
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
);
})
.map((logSource) => ({
// Show the icon for the log source if it has one.
// In the startup script case, we show a bespoke play icon.
startIcon: logSource.icon ? (
<ExternalImage
src={logSource.icon}
alt=""
className="size-icon-xs shrink-0"
/>
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
<PlayIcon className="size-icon-xs shrink-0" />
) : null,
title: logSource.display_name,
value: logSource.id,
}));
const startupScriptLogTab = sourceLogTabs.find(
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
);
const sortedSourceLogTabs = sourceLogTabs
.filter((tab) => tab !== startupScriptLogTab)
.sort((a, b) => a.title.localeCompare(b.title));
const logTabs: {
startIcon?: ReactNode;
title: string;
value: string;
}[] = [
{
title: "All Logs",
value: "all",
},
...(startupScriptLogTab ? [startupScriptLogTab] : []),
...sortedSourceLogTabs,
];
const {
containerRef: logTabsListContainerRef,
visibleTabs: visibleLogTabs,
overflowTabs: overflowLogTabs,
getTabMeasureProps,
} = useTabOverflowKebabMenu({
} = useKebabMenu({
tabs: logTabs,
enabled: true,
isActive: showLogs,
alwaysVisibleTabsCount: 1,
});
const overflowLogTabValuesSet = new Set(
overflowLogTabs.map((tab) => tab.value),
@@ -279,16 +275,29 @@ export const AgentRow: FC<AgentRowProps> = ({
level: log.level,
sourceId: log.source_id,
}));
const allLogsText = agentLogs.map((log) => log.output).join("\n");
const selectedLogsText = selectedLogs.map((log) => log.output).join("\n");
const hasSelectedLogs = selectedLogs.length > 0;
const hasAnyLogs = agentLogs.length > 0;
const { showCopiedSuccess, copyToClipboard } = useClipboard();
const selectedLogTabTitle =
logTabs.find((tab) => tab.value === selectedLogTab)?.title ?? "Logs";
const sanitizedTabTitle = selectedLogTabTitle
.toLowerCase()
.replaceAll(/[^a-z0-9]+/g, "-")
.replaceAll(/(^-|-$)/g, "");
const logFilenameSuffix = sanitizedTabTitle || "logs";
const downloadableLogSets = logTabs
.filter((tab) => tab.value !== "all")
.map((tab) => {
const logsText = agentLogs
.filter((log) => log.source_id === tab.value)
.map((log) => log.output)
.join("\n");
const filenameSuffix = tab.title
.toLowerCase()
.replaceAll(/[^a-z0-9]+/g, "-")
.replaceAll(/(^-|-$)/g, "");
return {
label: tab.title,
filenameSuffix: filenameSuffix || tab.value,
logsText,
startIcon: tab.startIcon,
};
});
return (
<div
@@ -547,9 +556,9 @@ export const AgentRow: FC<AgentRowProps> = ({
</Button>
<DownloadSelectedAgentLogsButton
agentName={agent.name}
logSets={downloadableLogSets}
allLogsText={allLogsText}
disabled={!hasAnyLogs}
/>
</div>
</div>
@@ -1,14 +1,27 @@
import { saveAs } from "file-saver";
import { ChevronDownIcon, DownloadIcon, PackageIcon } from "lucide-react";
import { type FC, type ReactNode, useState } from "react";
import { toast } from "sonner";
import { getErrorDetail } from "#/api/errors";
import { Button } from "#/components/Button/Button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "#/components/DropdownMenu/DropdownMenu";
type DownloadableLogSet = {
label: string;
filenameSuffix: string;
logsText: string;
startIcon?: ReactNode;
};
type DownloadSelectedAgentLogsButtonProps = {
agentName: string;
logSets: readonly DownloadableLogSet[];
allLogsText: string;
disabled?: boolean;
download?: (file: Blob, filename: string) => void | Promise<void>;
};
@@ -17,13 +30,13 @@ export const DownloadSelectedAgentLogsButton: FC<
DownloadSelectedAgentLogsButtonProps
> = ({
agentName,
logSets,
allLogsText,
disabled = false,
download = saveAs,
}) => {
const [isDownloading, setIsDownloading] = useState(false);
const downloadLogs = async (logsText: string, filenameSuffix: string) => {
try {
setIsDownloading(true);
const file = new Blob([logsText], { type: "text/plain" });
@@ -37,15 +50,40 @@ export const DownloadSelectedAgentLogsButton: FC<
}
};
const hasAllLogs = allLogsText.length > 0;
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="subtle" size="sm" disabled={disabled || isDownloading}>
<DownloadIcon />
{isDownloading ? "Downloading..." : "Download logs"}
<ChevronDownIcon className="size-icon-sm" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
disabled={!hasAllLogs}
onSelect={() => {
downloadLogs(allLogsText, "all-logs");
}}
>
<PackageIcon />
Download all logs
</DropdownMenuItem>
{logSets.map((logSet) => (
<DropdownMenuItem
key={logSet.filenameSuffix}
disabled={logSet.logsText.length === 0}
onSelect={() => {
downloadLogs(logSet.logsText, logSet.filenameSuffix);
}}
>
{logSet.startIcon}
<span>Download {logSet.label}</span>
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
);
};
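The `downloadLogs` helper above packages the text as a `text/plain` Blob and hands it to a saver callback (file-saver's `saveAs` by default). The core of that pattern can be sketched with an injectable saver; the filename format in this sketch is illustrative, not taken from the source:

```typescript
// Illustrative core of the download flow: build a text/plain Blob and
// pass it to a saver callback. In the browser the saver would be
// file-saver's saveAs; here a stub captures what would be saved.
type Saver = (file: Blob, filename: string) => void | Promise<void>;

async function downloadLogsText(
  logsText: string,
  agentName: string,
  filenameSuffix: string,
  save: Saver,
): Promise<void> {
  const file = new Blob([logsText], { type: "text/plain" });
  // Hypothetical filename convention for the sketch.
  await save(file, `${agentName}-${filenameSuffix}.txt`);
}

let saved: { size: number; name: string } | undefined;
void downloadLogsText("line1\nline2", "main", "all-logs", (f, n) => {
  saved = { size: f.size, name: n };
});
console.log(saved); // → { size: 11, name: 'main-all-logs.txt' }
```

Injecting the saver (as the component's `download = saveAs` default prop does) keeps the handler unit-testable without touching the DOM.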
@@ -0,0 +1,63 @@
import { BookOpenIcon, LoaderIcon, TriangleAlertIcon } from "lucide-react";
import type React from "react";
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "#/components/Tooltip/Tooltip";
import { cn } from "#/utils/cn";
import { Response } from "../Response";
import { ToolCollapsible } from "./ToolCollapsible";
import type { ToolStatus } from "./utils";
export const ReadSkillTool: React.FC<{
label: string;
body: string;
status: ToolStatus;
isError: boolean;
errorMessage?: string;
}> = ({ label, body, status, isError, errorMessage }) => {
const hasContent = body.length > 0;
const isRunning = status === "running";
return (
<ToolCollapsible
className="w-full"
hasContent={hasContent}
header={
<>
<BookOpenIcon className="h-4 w-4 shrink-0 text-content-secondary" />
<span className={cn("text-sm", "text-content-secondary")}>
{isRunning ? `Reading ${label}` : `Read ${label}`}
</span>
{isError && (
<Tooltip>
<TooltipTrigger asChild>
<TriangleAlertIcon className="h-3.5 w-3.5 shrink-0 text-content-secondary" />
</TooltipTrigger>
<TooltipContent>
{errorMessage || "Failed to read skill"}
</TooltipContent>
</Tooltip>
)}
{isRunning && (
<LoaderIcon className="h-3.5 w-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" />
)}
</>
}
>
{body && (
<ScrollArea
className="mt-1.5 rounded-md border border-solid border-border-default"
viewportClassName="max-h-64"
scrollBarClassName="w-1.5"
>
<div className="px-3 py-2">
<Response>{body}</Response>
</div>
</ScrollArea>
)}
</ToolCollapsible>
);
};
@@ -1342,6 +1342,12 @@ export const ReadSkillCompleted: Story = {
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
expect(canvas.getByText(/Read skill deep-review/)).toBeInTheDocument();
// Expand the collapsible to verify markdown body renders.
const toggle = canvas.getByRole("button");
await userEvent.click(toggle);
await waitFor(() => {
expect(canvas.getByText("Deep Review Skill")).toBeInTheDocument();
});
},
};
@@ -1393,6 +1399,12 @@ export const ReadSkillFileCompleted: Story = {
expect(
canvas.getByText(/Read deep-review\/roles\/security-reviewer\.md/),
).toBeInTheDocument();
// Expand the collapsible to verify markdown content renders.
const toggle = canvas.getByRole("button");
await userEvent.click(toggle);
await waitFor(() => {
expect(canvas.getByText("Security Reviewer Role")).toBeInTheDocument();
});
},
};
@@ -23,6 +23,7 @@ import { ListTemplatesTool } from "./ListTemplatesTool";
import { ProcessOutputTool } from "./ProcessOutputTool";
import { ProposePlanTool } from "./ProposePlanTool";
import { ReadFileTool } from "./ReadFileTool";
import { ReadSkillTool } from "./ReadSkillTool";
import { ReadTemplateTool } from "./ReadTemplateTool";
import { SubagentTool } from "./SubagentTool";
import { ToolCollapsible } from "./ToolCollapsible";
@@ -210,6 +211,55 @@ const ReadFileRenderer: FC<ToolRendererProps> = ({
);
};
const ReadSkillRenderer: FC<ToolRendererProps> = ({
status,
args,
result,
isError,
}) => {
const parsedArgs = parseArgs(args);
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
const rec = asRecord(result);
const body = rec ? asString(rec.body) : "";
return (
<ReadSkillTool
label={skillName ? `skill ${skillName}` : "skill"}
body={body}
status={status}
isError={isError}
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
/>
);
};
const ReadSkillFileRenderer: FC<ToolRendererProps> = ({
status,
args,
result,
isError,
}) => {
const parsedArgs = parseArgs(args);
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
const filePath = parsedArgs ? asString(parsedArgs.path) : "";
const label =
skillName && filePath
? `${skillName}/${filePath}`
: skillName || filePath || "skill file";
const rec = asRecord(result);
const content = rec ? asString(rec.content) : "";
return (
<ReadSkillTool
label={label}
body={content}
status={status}
isError={isError}
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
/>
);
};
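The renderers above lean on small narrowing helpers (`parseArgs`, `asRecord`, `asString`) whose definitions are not part of this diff. A plausible minimal sketch, under the assumption that they fail soft (returning `""` or `undefined` rather than throwing), looks like this:

```typescript
// Assumed shape of the narrowing helpers used by the renderers above;
// these definitions are a sketch, not the project's actual code.
function asRecord(value: unknown): Record<string, unknown> | undefined {
  return typeof value === "object" && value !== null && !Array.isArray(value)
    ? (value as Record<string, unknown>)
    : undefined;
}

function asString(value: unknown): string {
  return typeof value === "string" ? value : "";
}

function parseArgs(args: string): Record<string, unknown> | undefined {
  try {
    return asRecord(JSON.parse(args));
  } catch {
    return undefined; // malformed tool args fail soft
  }
}

const rec = asRecord({ body: "# Skill", error: 42 });
console.log(rec && asString(rec.body)); // → # Skill
console.log(rec && asString(rec.error)); // non-string → empty string
console.log(parseArgs("not json")); // → undefined
```

Failing soft matters here because tool args and results stream in from the model and may be absent or malformed mid-call; the renderers degrade to empty labels instead of crashing.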
const WriteFileRenderer: FC<ToolRendererProps> = ({
status,
args,
@@ -667,6 +717,8 @@ const toolRenderers: Record<string, FC<ToolRendererProps>> = {
create_workspace: CreateWorkspaceRenderer,
list_templates: ListTemplatesRenderer,
read_template: ReadTemplateRenderer,
read_skill: ReadSkillRenderer,
read_skill_file: ReadSkillFileRenderer,
spawn_agent: SubagentRenderer,
wait_agent: SubagentRenderer,
message_agent: SubagentRenderer,
@@ -462,6 +462,9 @@ export const CreateAndUpdateProvider: Story = {
await waitFor(() => {
expect(args.onCreateProvider).toHaveBeenCalledTimes(1);
});
await waitFor(() => {
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
});
expect(args.onCreateProvider).toHaveBeenCalledWith(
expect.objectContaining({
provider: "openai",
@@ -12,6 +12,8 @@ import (
"sync"
"time"
"golang.org/x/sync/singleflight"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/tailscale/wireguard-go/tun"
@@ -463,6 +465,8 @@ type Conn struct {
trafficStats *connstats.Statistics
lastNetInfo *tailcfg.NetInfo
awaitReachableGroup singleflight.Group
}
func (c *Conn) GetNetInfo() *tailcfg.NetInfo {
@@ -599,56 +603,60 @@ func (c *Conn) DERPMap() *tailcfg.DERPMap {
// address is reachable. It's the caller's responsibility to provide
// a timeout, otherwise this function will block forever.
func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
	result, _, _ := c.awaitReachableGroup.Do(ip.String(), func() (interface{}, error) {
		ctx, cancel := context.WithCancel(ctx)
		defer cancel() // Cancel all pending pings on exit.
		completedCtx, completed := context.WithCancel(context.Background())
		defer completed()
		run := func() {
			// Safety timeout, initially we'll have around 10-20 goroutines
			// running in parallel. The exponential backoff will converge
			// around ~1 ping / 30s, this means we'll have around 10-20
			// goroutines pending towards the end as well.
			ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
			defer cancel()
			// For reachability, we use TSMP ping, which pings at the IP layer,
			// and therefore requires that wireguard and the netstack are up.
			// If we don't wait for wireguard to be up, we could miss a
			// handshake, and it might take 5 seconds for the handshake to be
			// retried. A 5s initial round trip can set us up for poor TCP
			// performance, since the initial round-trip-time sets the initial
			// retransmit timeout.
			_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
			if err == nil {
				completed()
			}
		}
		eb := backoff.NewExponentialBackOff()
		eb.MaxElapsedTime = 0
		eb.InitialInterval = 50 * time.Millisecond
		eb.MaxInterval = 5 * time.Second
		// Consume the first interval since
		// we'll fire off a ping immediately.
		_ = eb.NextBackOff()
		t := backoff.NewTicker(eb)
		defer t.Stop()
		go run()
		for {
			select {
			case <-completedCtx.Done():
				return true, nil
			case <-t.C:
				// Pings can take a while, so we can run multiple
				// in parallel to return ASAP.
				go run()
			case <-ctx.Done():
				return false, nil
			}
		}
	})
	return result.(bool)
}
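The Go code above uses `golang.org/x/sync/singleflight` to coalesce concurrent `AwaitReachable` calls for the same IP, so a ping storm collapses into one in-flight wait per address. The same per-key deduplication can be sketched in TypeScript as shared-promise coalescing; the class and names here are illustrative, not part of the codebase:

```typescript
// Per-key promise coalescing, mirroring singleflight.Group.Do:
// concurrent callers with the same key share one in-flight execution,
// and the entry is cleared once it settles so later calls run fresh.
class Singleflight<T> {
  private inflight = new Map<string, Promise<T>>();

  do(key: string, fn: () => Promise<T>): Promise<T> {
    const existing = this.inflight.get(key);
    if (existing) return existing;
    const p = fn().finally(() => this.inflight.delete(key));
    this.inflight.set(key, p);
    return p;
  }
}

// Usage: three concurrent "reachability pings" to one address run fn once.
const group = new Singleflight<boolean>();
let calls = 0;
const ping = () =>
  new Promise<boolean>((resolve) => {
    calls++;
    setTimeout(() => resolve(true), 10);
  });

const results = await Promise.all([
  group.do("10.0.0.1", ping),
  group.do("10.0.0.1", ping),
  group.do("10.0.0.1", ping),
]);
console.log(calls, results); // → 1 [ true, true, true ]
```

As in the Go version, the trade-off is that all coalesced callers share one outcome: a caller joining late inherits whatever the first caller's attempt produces.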
// Closed is a channel that ends when the connection has