Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cfd7730194 | |||
| 1937ada0cd | |||
| d64cd6415d | |||
| c1851d9453 | |||
| 8f73453681 | |||
| 165db3d31c | |||
| 1bd1516fd1 | |||
| 81ba35a987 | |||
| 53d63cf8e9 | |||
| 4213a43b53 | |||
| 5453a6c6d6 | |||
| 21c08a37d7 | |||
| 2bd261fbbf | |||
| cffc68df58 | |||
| 6e5335df1e | |||
| 16265e834e | |||
| 565a15bc9b | |||
| 76a2cb1af5 |
@@ -31,8 +31,7 @@ updates:
|
||||
patterns:
|
||||
- "golang.org/x/*"
|
||||
ignore:
|
||||
# Patch updates are handled by the security-patch-prs workflow so this
|
||||
# lane stays focused on broader dependency updates.
|
||||
# Ignore patch updates for all dependencies
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
- version-update:semver-patch
|
||||
@@ -57,7 +56,7 @@ updates:
|
||||
labels: []
|
||||
ignore:
|
||||
# We need to coordinate terraform updates with the version hardcoded in
|
||||
# our Go code. These are handled by the security-patch-prs workflow.
|
||||
# our Go code.
|
||||
- dependency-name: "terraform"
|
||||
|
||||
- package-ecosystem: "npm"
|
||||
@@ -118,11 +117,11 @@ updates:
|
||||
interval: "weekly"
|
||||
commit-message:
|
||||
prefix: "chore"
|
||||
labels: []
|
||||
groups:
|
||||
coder-modules:
|
||||
patterns:
|
||||
- "coder/*/coder"
|
||||
labels: []
|
||||
ignore:
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
|
||||
@@ -1,354 +0,0 @@
|
||||
name: security-backport
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types:
|
||||
- labeled
|
||||
- unlabeled
|
||||
- closed
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pull_request:
|
||||
description: Pull request number to backport.
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
LATEST_BRANCH: release/2.31
|
||||
STABLE_BRANCH: release/2.30
|
||||
STABLE_1_BRANCH: release/2.29
|
||||
|
||||
jobs:
|
||||
label-policy:
|
||||
if: github.event_name == 'pull_request_target'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Apply security backport label policy
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
|
||||
pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
|
||||
|
||||
PR_JSON="${pr_json}" \
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
pr = json.loads(os.environ["PR_JSON"])
|
||||
|
||||
pr_number = pr["number"]
|
||||
labels = [label["name"] for label in pr.get("labels", [])]
|
||||
|
||||
def has(label: str) -> bool:
|
||||
return label in labels
|
||||
|
||||
def ensure_label(label: str) -> None:
|
||||
if not has(label):
|
||||
subprocess.run(
|
||||
["gh", "pr", "edit", str(pr_number), "--add-label", label],
|
||||
check=False,
|
||||
)
|
||||
|
||||
def remove_label(label: str) -> None:
|
||||
if has(label):
|
||||
subprocess.run(
|
||||
["gh", "pr", "edit", str(pr_number), "--remove-label", label],
|
||||
check=False,
|
||||
)
|
||||
|
||||
def comment(body: str) -> None:
|
||||
subprocess.run(
|
||||
["gh", "pr", "comment", str(pr_number), "--body", body],
|
||||
check=True,
|
||||
)
|
||||
|
||||
if not has("security:patch"):
|
||||
remove_label("status:needs-severity")
|
||||
sys.exit(0)
|
||||
|
||||
severity_labels = [
|
||||
label
|
||||
for label in ("severity:medium", "severity:high", "severity:critical")
|
||||
if has(label)
|
||||
]
|
||||
if len(severity_labels) == 0:
|
||||
ensure_label("status:needs-severity")
|
||||
comment(
|
||||
"This PR is labeled `security:patch` but is missing a severity "
|
||||
"label. Add one of `severity:medium`, `severity:high`, or "
|
||||
"`severity:critical` before backport automation can proceed."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
if len(severity_labels) > 1:
|
||||
comment(
|
||||
"This PR has multiple severity labels. Keep exactly one of "
|
||||
"`severity:medium`, `severity:high`, or `severity:critical`."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
remove_label("status:needs-severity")
|
||||
|
||||
target_labels = [
|
||||
label
|
||||
for label in ("backport:stable", "backport:stable-1")
|
||||
if has(label)
|
||||
]
|
||||
has_none = has("backport:none")
|
||||
if has_none and target_labels:
|
||||
comment(
|
||||
"`backport:none` cannot be combined with other backport labels. "
|
||||
"Remove `backport:none` or remove the explicit backport targets."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not has_none and not target_labels:
|
||||
ensure_label("backport:stable")
|
||||
ensure_label("backport:stable-1")
|
||||
comment(
|
||||
"Applied default backport labels `backport:stable` and "
|
||||
"`backport:stable-1` for a qualifying security patch."
|
||||
)
|
||||
PY
|
||||
|
||||
backport:
|
||||
if: >
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(
|
||||
github.event_name == 'pull_request_target' &&
|
||||
github.event.pull_request.merged == true
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Resolve PR metadata
|
||||
id: metadata
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
INPUT_PR_NUMBER: ${{ inputs.pull_request }}
|
||||
LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
|
||||
STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
|
||||
STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
|
||||
pr_number="${INPUT_PR_NUMBER}"
|
||||
else
|
||||
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
|
||||
fi
|
||||
|
||||
case "${pr_number}" in
|
||||
''|*[!0-9]*)
|
||||
echo "A valid pull request number is required."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
|
||||
|
||||
PR_JSON="${pr_json}" \
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
pr = json.loads(os.environ["PR_JSON"])
|
||||
github_output = os.environ["GITHUB_OUTPUT"]
|
||||
|
||||
labels = [label["name"] for label in pr.get("labels", [])]
|
||||
if "security:patch" not in labels:
|
||||
print("Not a security patch PR; skipping.")
|
||||
sys.exit(0)
|
||||
|
||||
severity_labels = [
|
||||
label
|
||||
for label in ("severity:medium", "severity:high", "severity:critical")
|
||||
if label in labels
|
||||
]
|
||||
if len(severity_labels) != 1:
|
||||
raise SystemExit(
|
||||
"Merged security patch PR must have exactly one severity label."
|
||||
)
|
||||
|
||||
if not pr.get("mergedAt"):
|
||||
raise SystemExit(f"PR #{pr['number']} is not merged.")
|
||||
|
||||
if "backport:none" in labels:
|
||||
target_pairs = []
|
||||
else:
|
||||
mapping = {
|
||||
"backport:stable": os.environ["STABLE_BRANCH"],
|
||||
"backport:stable-1": os.environ["STABLE_1_BRANCH"],
|
||||
}
|
||||
target_pairs = []
|
||||
for label_name, branch in mapping.items():
|
||||
if label_name in labels and branch and branch != pr["baseRefName"]:
|
||||
target_pairs.append({"label": label_name, "branch": branch})
|
||||
|
||||
with open(github_output, "a", encoding="utf-8") as f:
|
||||
f.write(f"pr_number={pr['number']}\n")
|
||||
f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
|
||||
f.write(f"title={pr['title']}\n")
|
||||
f.write(f"url={pr['url']}\n")
|
||||
f.write(f"author={pr['author']['login']}\n")
|
||||
f.write(f"severity_label={severity_labels[0]}\n")
|
||||
f.write(f"target_pairs={json.dumps(target_pairs)}\n")
|
||||
PY
|
||||
|
||||
- name: Backport to release branches
|
||||
if: ${{ steps.metadata.outputs.target_pairs != '[]' }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
|
||||
MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
|
||||
PR_TITLE: ${{ steps.metadata.outputs.title }}
|
||||
PR_URL: ${{ steps.metadata.outputs.url }}
|
||||
PR_AUTHOR: ${{ steps.metadata.outputs.author }}
|
||||
SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
|
||||
TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
|
||||
git fetch origin --prune
|
||||
|
||||
merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
|
||||
|
||||
failures=()
|
||||
successes=()
|
||||
|
||||
while IFS=$'\t' read -r backport_label target_branch; do
|
||||
[ -n "${target_branch}" ] || continue
|
||||
|
||||
safe_branch_name="${target_branch//\//-}"
|
||||
head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
|
||||
|
||||
existing_pr="$(gh pr list \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--state all \
|
||||
--json number,url \
|
||||
--jq '.[0]')"
|
||||
if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
|
||||
pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
|
||||
successes+=("${target_branch}:existing:${pr_url}")
|
||||
continue
|
||||
fi
|
||||
|
||||
git checkout -B "${head_branch}" "origin/${target_branch}"
|
||||
|
||||
if [ "${merge_parent_count}" -gt 1 ]; then
|
||||
cherry_pick_args=(-m 1 "${MERGE_SHA}")
|
||||
else
|
||||
cherry_pick_args=("${MERGE_SHA}")
|
||||
fi
|
||||
|
||||
if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
|
||||
git cherry-pick --abort || true
|
||||
gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
|
||||
gh pr comment "${PR_NUMBER}" --body \
|
||||
"Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
|
||||
failures+=("${target_branch}:cherry-pick failed")
|
||||
continue
|
||||
fi
|
||||
|
||||
git push --force-with-lease origin "${head_branch}"
|
||||
|
||||
body_file="$(mktemp)"
|
||||
printf '%s\n' \
|
||||
"Automated backport of [#${PR_NUMBER}](${PR_URL})." \
|
||||
"" \
|
||||
"- Source PR: #${PR_NUMBER}" \
|
||||
"- Source commit: ${MERGE_SHA}" \
|
||||
"- Target branch: ${target_branch}" \
|
||||
"- Severity: ${SEVERITY_LABEL}" \
|
||||
> "${body_file}"
|
||||
|
||||
pr_url="$(gh pr create \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--title "${PR_TITLE} (backport to ${target_branch})" \
|
||||
--body-file "${body_file}")"
|
||||
|
||||
backport_pr_number="$(gh pr list \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--state open \
|
||||
--json number \
|
||||
--jq '.[0].number')"
|
||||
|
||||
gh pr edit "${backport_pr_number}" \
|
||||
--add-label "security:patch" \
|
||||
--add-label "${SEVERITY_LABEL}" \
|
||||
--add-label "${backport_label}" || true
|
||||
|
||||
successes+=("${target_branch}:created:${pr_url}")
|
||||
done < <(
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
for pair in json.loads(os.environ["TARGET_PAIRS"]):
|
||||
print(f"{pair['label']}\t{pair['branch']}")
|
||||
PY
|
||||
)
|
||||
|
||||
summary_file="$(mktemp)"
|
||||
{
|
||||
echo "## Security backport summary"
|
||||
echo
|
||||
if [ "${#successes[@]}" -gt 0 ]; then
|
||||
echo "### Created or existing"
|
||||
for entry in "${successes[@]}"; do
|
||||
echo "- ${entry}"
|
||||
done
|
||||
echo
|
||||
fi
|
||||
if [ "${#failures[@]}" -gt 0 ]; then
|
||||
echo "### Failures"
|
||||
for entry in "${failures[@]}"; do
|
||||
echo "- ${entry}"
|
||||
done
|
||||
fi
|
||||
} | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
|
||||
|
||||
gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
|
||||
|
||||
if [ "${#failures[@]}" -gt 0 ]; then
|
||||
printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,214 +0,0 @@
|
||||
name: security-patch-prs
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1-5"
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
patch:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
lane:
|
||||
- gomod
|
||||
- terraform
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Go
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Patch Go dependencies
|
||||
if: matrix.lane == 'gomod'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
go get -u=patch ./...
|
||||
go mod tidy
|
||||
|
||||
# Guardrail: do not auto-edit replace directives.
|
||||
if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
|
||||
echo "Refusing to auto-edit go.mod replace directives"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Guardrail: only go.mod / go.sum may change.
|
||||
extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
|
||||
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
|
||||
|
||||
- name: Patch bundled Terraform
|
||||
if: matrix.lane == 'terraform'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
current="$(
|
||||
grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
|
||||
provisioner/terraform/install.go \
|
||||
| head -1 \
|
||||
| grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
|
||||
)"
|
||||
|
||||
series="$(echo "$current" | cut -d. -f1,2)"
|
||||
|
||||
latest="$(
|
||||
curl -fsSL https://releases.hashicorp.com/terraform/index.json \
|
||||
| jq -r --arg series "$series" '
|
||||
.versions
|
||||
| keys[]
|
||||
| select(startswith($series + "."))
|
||||
' \
|
||||
| sort -V \
|
||||
| tail -1
|
||||
)"
|
||||
|
||||
test -n "$latest"
|
||||
[ "$latest" != "$current" ] || exit 0
|
||||
|
||||
CURRENT_TERRAFORM_VERSION="$current" \
|
||||
LATEST_TERRAFORM_VERSION="$latest" \
|
||||
python3 - <<'PY'
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
current = os.environ["CURRENT_TERRAFORM_VERSION"]
|
||||
latest = os.environ["LATEST_TERRAFORM_VERSION"]
|
||||
|
||||
updates = {
|
||||
"scripts/Dockerfile.base": (
|
||||
f"terraform/{current}/",
|
||||
f"terraform/{latest}/",
|
||||
),
|
||||
"provisioner/terraform/install.go": (
|
||||
f'NewVersion("{current}")',
|
||||
f'NewVersion("{latest}")',
|
||||
),
|
||||
"install.sh": (
|
||||
f'TERRAFORM_VERSION="{current}"',
|
||||
f'TERRAFORM_VERSION="{latest}"',
|
||||
),
|
||||
}
|
||||
|
||||
for path_str, (before, after) in updates.items():
|
||||
path = Path(path_str)
|
||||
content = path.read_text()
|
||||
if before not in content:
|
||||
raise SystemExit(f"did not find expected text in {path_str}: {before}")
|
||||
path.write_text(content.replace(before, after))
|
||||
PY
|
||||
|
||||
# Guardrail: only the Terraform-version files may change.
|
||||
extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
|
||||
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
|
||||
|
||||
- name: Validate Go dependency patch
|
||||
if: matrix.lane == 'gomod'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
go test ./...
|
||||
|
||||
- name: Validate Terraform patch
|
||||
if: matrix.lane == 'terraform'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
go test ./provisioner/terraform/...
|
||||
docker build -f scripts/Dockerfile.base .
|
||||
|
||||
- name: Skip PR creation when there are no changes
|
||||
id: changes
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if git diff --quiet; then
|
||||
echo "has_changes=false" >> "${GITHUB_OUTPUT}"
|
||||
else
|
||||
echo "has_changes=true" >> "${GITHUB_OUTPUT}"
|
||||
fi
|
||||
|
||||
- name: Commit changes
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git checkout -B "secpatch/${{ matrix.lane }}"
|
||||
git add -A
|
||||
git commit -m "security: patch ${{ matrix.lane }}"
|
||||
|
||||
- name: Push branch
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git push --force-with-lease \
|
||||
"https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
|
||||
"HEAD:refs/heads/secpatch/${{ matrix.lane }}"
|
||||
|
||||
- name: Create or update PR
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
branch="secpatch/${{ matrix.lane }}"
|
||||
title="security: patch ${{ matrix.lane }}"
|
||||
body="$(cat <<'EOF'
|
||||
Automated security patch PR for `${{ matrix.lane }}`.
|
||||
|
||||
Scope:
|
||||
- gomod: patch-level Go dependency updates only
|
||||
- terraform: bundled Terraform patch updates only
|
||||
|
||||
Guardrails:
|
||||
- no application-code edits
|
||||
- no auto-editing of go.mod replace directives
|
||||
- CI must pass
|
||||
EOF
|
||||
)"
|
||||
|
||||
existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
|
||||
if [[ -n "${existing_pr}" ]]; then
|
||||
gh pr edit "${existing_pr}" \
|
||||
--title "${title}" \
|
||||
--body "${body}"
|
||||
pr_number="${existing_pr}"
|
||||
else
|
||||
gh pr create \
|
||||
--base main \
|
||||
--head "${branch}" \
|
||||
--title "${title}" \
|
||||
--body "${body}"
|
||||
pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
|
||||
fi
|
||||
|
||||
for label in security dependencies automated-pr; do
|
||||
if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
|
||||
gh pr edit "${pr_number}" --add-label "${label}"
|
||||
fi
|
||||
done
|
||||
@@ -2155,17 +2155,12 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecret(ctx, id)
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -4128,19 +4123,6 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
|
||||
return q.db.GetUserNotificationPreferences(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
@@ -5524,7 +5506,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
|
||||
return q.db.ListUserChatCompactionThresholds(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
return nil, err
|
||||
@@ -5532,6 +5514,16 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
|
||||
return q.db.ListUserSecrets(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
// This query returns decrypted secret values and must only be called
|
||||
// from system contexts (provisioner, agent manifest). REST API
|
||||
// handlers should use ListUserSecrets (metadata only).
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListUserSecretsWithValues(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
@@ -6632,17 +6624,12 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
|
||||
return q.db.UpdateUserRoles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, arg.ID)
|
||||
if err != nil {
|
||||
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return q.db.UpdateUserSecret(ctx, arg)
|
||||
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
|
||||
|
||||
@@ -5346,19 +5346,20 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns([]database.ListUserSecretsRow{row})
|
||||
}))
|
||||
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceSystem, policy.ActionRead).
|
||||
Returns([]database.UserSecret{secret})
|
||||
}))
|
||||
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
@@ -5370,22 +5371,21 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
|
||||
Returns(ret)
|
||||
}))
|
||||
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
|
||||
arg := database.UpdateUserSecretParams{ID: secret.ID}
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(secret, policy.ActionUpdate).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
|
||||
Returns(updated)
|
||||
}))
|
||||
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -1597,6 +1597,7 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
|
||||
Name: takeFirst(seed.Name, "secret-name"),
|
||||
Description: takeFirst(seed.Description, "secret description"),
|
||||
Value: takeFirst(seed.Value, "secret value"),
|
||||
ValueKeyID: seed.ValueKeyID,
|
||||
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
|
||||
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
|
||||
})
|
||||
|
||||
@@ -712,11 +712,11 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -2624,14 +2624,6 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
|
||||
@@ -3920,7 +3912,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecrets(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
|
||||
@@ -3928,6 +3920,14 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
|
||||
@@ -4696,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
|
||||
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
|
||||
@@ -1199,18 +1199,18 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
|
||||
@@ -4907,21 +4907,6 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserSecret mocks base method.
|
||||
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserSecret indicates an expected call of GetUserSecret.
|
||||
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
|
||||
}
|
||||
|
||||
// GetUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7412,10 +7397,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
|
||||
}
|
||||
|
||||
// ListUserSecrets mocks base method.
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret0, _ := ret[0].([]database.ListUserSecretsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -7426,6 +7411,21 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues mocks base method.
|
||||
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
|
||||
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
|
||||
}
|
||||
|
||||
// ListWorkspaceAgentPortShares mocks base method.
|
||||
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8854,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserSecret mocks base method.
|
||||
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// UpdateUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
|
||||
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserStatus mocks base method.
|
||||
|
||||
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
|
||||
}
|
||||
|
||||
func (q *msgQueue) run() {
|
||||
var batch [maxDrainBatch]msgOrErr
|
||||
for {
|
||||
// wait until there is something on the queue or we are closed
|
||||
q.cond.L.Lock()
|
||||
for q.size == 0 && !q.closed {
|
||||
q.cond.Wait()
|
||||
@@ -91,28 +91,32 @@ func (q *msgQueue) run() {
|
||||
q.cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
item := q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
q.size--
|
||||
// Drain up to maxDrainBatch items while holding the lock.
|
||||
n := min(q.size, maxDrainBatch)
|
||||
for i := range n {
|
||||
batch[i] = q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
}
|
||||
q.size -= n
|
||||
q.cond.L.Unlock()
|
||||
|
||||
// process item without holding lock
|
||||
if item.err == nil {
|
||||
// real message
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
// Dispatch each message individually without holding the lock.
|
||||
for i := range n {
|
||||
item := batch[i]
|
||||
if item.err == nil {
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
q.le(q.ctx, nil, item.err)
|
||||
}
|
||||
// unhittable
|
||||
continue
|
||||
}
|
||||
// if the listener wants errors, send it.
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, nil, item.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,6 +237,12 @@ type PGPubsub struct {
|
||||
// for a subscriber before dropping messages.
|
||||
const BufferSize = 2048
|
||||
|
||||
// maxDrainBatch is the maximum number of messages to drain from the ring
|
||||
// buffer per iteration. Batching amortizes the cost of mutex
|
||||
// acquire/release and cond.Wait across many messages, improving drain
|
||||
// throughput during bursts.
|
||||
const maxDrainBatch = 256
|
||||
|
||||
// Subscribe calls the listener when an event matching the name is received.
|
||||
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
|
||||
|
||||
@@ -152,7 +152,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -598,7 +598,6 @@ type sqlcQuerier interface {
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
|
||||
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
|
||||
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
|
||||
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
// GetUserStatusCounts returns the count of users in each status over time.
|
||||
// The time range is inclusively defined by the start_time and end_time parameters.
|
||||
@@ -818,7 +817,13 @@ type sqlcQuerier interface {
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
|
||||
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
|
||||
@@ -957,7 +962,7 @@ type sqlcQuerier interface {
|
||||
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
|
||||
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
|
||||
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
|
||||
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
|
||||
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
|
||||
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
|
||||
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
|
||||
|
||||
@@ -7339,13 +7339,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, secretID, createdSecret.ID)
|
||||
|
||||
// 2. READ by ID
|
||||
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readSecret.Name)
|
||||
|
||||
// 3. READ by UserID and Name
|
||||
// 2. READ by UserID and Name
|
||||
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
@@ -7353,33 +7347,43 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
|
||||
|
||||
// 4. LIST
|
||||
// 3. LIST (metadata only)
|
||||
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, createdSecret.ID, secrets[0].ID)
|
||||
|
||||
// 5. UPDATE
|
||||
updateParams := database.UpdateUserSecretParams{
|
||||
ID: createdSecret.ID,
|
||||
Description: "Updated workflow description",
|
||||
Value: "updated-workflow-value",
|
||||
EnvName: "UPDATED_WORKFLOW_ENV",
|
||||
FilePath: "/updated/workflow/path",
|
||||
// 4. LIST with values
|
||||
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secretsWithValues, 1)
|
||||
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
|
||||
|
||||
// 5. UPDATE (partial - only description)
|
||||
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
UpdateDescription: true,
|
||||
Description: "Updated workflow description",
|
||||
}
|
||||
|
||||
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
|
||||
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
|
||||
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
|
||||
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecret(ctx, createdSecret.ID)
|
||||
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
_, err = db.GetUserSecret(ctx, createdSecret.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no rows in result set")
|
||||
|
||||
@@ -7449,9 +7453,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
})
|
||||
|
||||
// Verify both secrets exist
|
||||
_, err = db.GetUserSecret(ctx, secret1.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.GetUserSecret(ctx, secret2.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret2.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -7474,14 +7482,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
// Create secrets for users
|
||||
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user1.ID,
|
||||
Name: "user1-secret",
|
||||
Description: "User 1's secret",
|
||||
Value: "user1-value",
|
||||
})
|
||||
|
||||
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user2.ID,
|
||||
Name: "user2-secret",
|
||||
Description: "User 2's secret",
|
||||
@@ -7491,7 +7499,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
subject rbac.Subject
|
||||
secretID uuid.UUID
|
||||
lookupUserID uuid.UUID
|
||||
lookupName string
|
||||
expectedAccess bool
|
||||
}{
|
||||
{
|
||||
@@ -7501,7 +7510,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: true,
|
||||
},
|
||||
{
|
||||
@@ -7511,7 +7521,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user2Secret.ID,
|
||||
lookupUserID: user2.ID,
|
||||
lookupName: "user2-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7521,7 +7532,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7531,7 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
}
|
||||
@@ -7543,8 +7556,10 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
|
||||
authCtx := dbauthz.As(ctx, tc.subject)
|
||||
|
||||
// Test GetUserSecret
|
||||
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
|
||||
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: tc.lookupUserID,
|
||||
Name: tc.lookupName,
|
||||
})
|
||||
|
||||
if tc.expectedAccess {
|
||||
require.NoError(t, err, "expected access to be granted")
|
||||
|
||||
+120
-55
@@ -22639,21 +22639,30 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8
|
||||
) RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
`
|
||||
|
||||
type CreateUserSecretParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretParams) (UserSecret, error) {
|
||||
@@ -22663,6 +22672,7 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
|
||||
arg.Name,
|
||||
arg.Description,
|
||||
arg.Value,
|
||||
arg.ValueKeyID,
|
||||
arg.EnvName,
|
||||
arg.FilePath,
|
||||
)
|
||||
@@ -22682,41 +22692,24 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteUserSecret = `-- name: DeleteUserSecret :exec
|
||||
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE id = $1
|
||||
WHERE user_id = $1 AND name = $2
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserSecret, id)
|
||||
type DeleteUserSecretByUserIDAndNameParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUserSecret = `-- name: GetUserSecret :one
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserSecret, id)
|
||||
var i UserSecret
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.Value,
|
||||
&i.EnvName,
|
||||
&i.FilePath,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ValueKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserSecretByUserIDAndName = `-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2
|
||||
`
|
||||
|
||||
@@ -22744,17 +22737,76 @@ func (q *sqlQuerier) GetUserSecretByUserIDAndName(ctx context.Context, arg GetUs
|
||||
}
|
||||
|
||||
const listUserSecrets = `-- name: ListUserSecrets :many
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
ORDER BY name ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
|
||||
type ListUserSecretsRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listUserSecrets, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListUserSecretsRow
|
||||
for rows.Next() {
|
||||
var i ListUserSecretsRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.EnvName,
|
||||
&i.FilePath,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listUserSecretsWithValues = `-- name: ListUserSecretsWithValues :many
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
ORDER BY name ASC
|
||||
`
|
||||
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
func (q *sqlQuerier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listUserSecretsWithValues, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UserSecret
|
||||
for rows.Next() {
|
||||
var i UserSecret
|
||||
@@ -22783,33 +22835,46 @@ func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]U
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateUserSecret = `-- name: UpdateUserSecret :one
|
||||
const updateUserSecretByUserIDAndName = `-- name: UpdateUserSecretByUserIDAndName :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
value = CASE WHEN $1::bool THEN $2 ELSE value END,
|
||||
value_key_id = CASE WHEN $1::bool THEN $3 ELSE value_key_id END,
|
||||
description = CASE WHEN $4::bool THEN $5 ELSE description END,
|
||||
env_name = CASE WHEN $6::bool THEN $7 ELSE env_name END,
|
||||
file_path = CASE WHEN $8::bool THEN $9 ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = $10 AND name = $11
|
||||
RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
`
|
||||
|
||||
type UpdateUserSecretParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
type UpdateUserSecretByUserIDAndNameParams struct {
|
||||
UpdateValue bool `db:"update_value" json:"update_value"`
|
||||
Value string `db:"value" json:"value"`
|
||||
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
|
||||
UpdateDescription bool `db:"update_description" json:"update_description"`
|
||||
Description string `db:"description" json:"description"`
|
||||
UpdateEnvName bool `db:"update_env_name" json:"update_env_name"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
UpdateFilePath bool `db:"update_file_path" json:"update_file_path"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserSecret,
|
||||
arg.ID,
|
||||
arg.Description,
|
||||
func (q *sqlQuerier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserSecretByUserIDAndName,
|
||||
arg.UpdateValue,
|
||||
arg.Value,
|
||||
arg.ValueKeyID,
|
||||
arg.UpdateDescription,
|
||||
arg.Description,
|
||||
arg.UpdateEnvName,
|
||||
arg.EnvName,
|
||||
arg.UpdateFilePath,
|
||||
arg.FilePath,
|
||||
arg.UserID,
|
||||
arg.Name,
|
||||
)
|
||||
var i UserSecret
|
||||
err := row.Scan(
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2;
|
||||
|
||||
-- name: GetUserSecret :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE id = $1;
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
-- name: ListUserSecrets :many
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
-- Returns metadata only (no value or value_key_id) for the
|
||||
-- REST API list and get endpoints.
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: ListUserSecretsWithValues :many
|
||||
-- Returns all columns including the secret value. Used by the
|
||||
-- provisioner (build-time injection) and the agent manifest
|
||||
-- (runtime injection).
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: CreateUserSecret :one
|
||||
@@ -18,23 +30,32 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
@id,
|
||||
@user_id,
|
||||
@name,
|
||||
@description,
|
||||
@value,
|
||||
@value_key_id,
|
||||
@env_name,
|
||||
@file_path
|
||||
) RETURNING *;
|
||||
|
||||
-- name: UpdateUserSecret :one
|
||||
-- name: UpdateUserSecretByUserIDAndName :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
|
||||
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
|
||||
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
|
||||
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
|
||||
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecret :exec
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE id = $1;
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
+37
-19
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -389,6 +390,7 @@ type MultiAgentController struct {
|
||||
// connections to the destination
|
||||
tickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
coordination *tailnet.BasicCoordination
|
||||
sendGroup singleflight.Group
|
||||
|
||||
cancel context.CancelFunc
|
||||
expireOldAgentsDone chan struct{}
|
||||
@@ -418,28 +420,44 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
|
||||
|
||||
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.connectionTimes[agentID]
|
||||
// If we don't have the agent, subscribe.
|
||||
if !ok {
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
if m.coordination != nil {
|
||||
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
err = xerrors.Errorf("subscribe agent: %w", err)
|
||||
m.coordination.SendErr(err)
|
||||
_ = m.coordination.Client.Close()
|
||||
m.coordination = nil
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
if ok {
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
|
||||
_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
|
||||
m.mu.Lock()
|
||||
coord := m.coordination
|
||||
m.mu.Unlock()
|
||||
if coord == nil {
|
||||
return nil, xerrors.New("no active coordination")
|
||||
}
|
||||
err := coord.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
m.mu.Unlock()
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Error(context.Background(), "ensureAgent send failed",
|
||||
slog.F("agent_id", agentID), slog.Error(err))
|
||||
return xerrors.Errorf("send AddTunnel: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -730,7 +730,10 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log := s.Logger.With(slog.F("agent_id", appToken.AgentID))
|
||||
log := s.Logger.With(
|
||||
slog.F("agent_id", appToken.AgentID),
|
||||
slog.F("workspace_id", appToken.WorkspaceID),
|
||||
)
|
||||
log.Debug(ctx, "resolved PTY request")
|
||||
|
||||
values := r.URL.Query()
|
||||
@@ -765,19 +768,21 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn)
|
||||
|
||||
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, log, cancel, conn)
|
||||
|
||||
dialStart := time.Now()
|
||||
|
||||
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err))
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
log.Debug(ctx, "dialed workspace agent")
|
||||
log.Debug(ctx, "dialed workspace agent", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
// #nosec G115 - Safe conversion for terminal height/width which are expected to be within uint16 range (0-65535)
|
||||
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
|
||||
arp.Container = container
|
||||
@@ -785,12 +790,12 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
arp.BackendType = backendType
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
|
||||
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
|
||||
return
|
||||
}
|
||||
defer ptNetConn.Close()
|
||||
log.Debug(ctx, "obtained PTY")
|
||||
log.Debug(ctx, "obtained PTY", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
|
||||
report := newStatsReportFromSignedToken(*appToken)
|
||||
s.collectStats(report)
|
||||
@@ -800,7 +805,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
|
||||
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
|
||||
log.Debug(ctx, "pty Bicopy finished")
|
||||
log.Debug(ctx, "pty Bicopy finished", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
}
|
||||
|
||||
func (s *Server) collectStats(stats StatsReport) {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
gProto "google.golang.org/protobuf/proto"
|
||||
|
||||
@@ -33,9 +34,9 @@ const (
|
||||
eventReadyForHandshake = "tailnet_ready_for_handshake"
|
||||
HeartbeatPeriod = time.Second * 2
|
||||
MissedHeartbeats = 3
|
||||
numQuerierWorkers = 10
|
||||
numQuerierWorkers = 40
|
||||
numBinderWorkers = 10
|
||||
numTunnelerWorkers = 10
|
||||
numTunnelerWorkers = 20
|
||||
numHandshakerWorkers = 5
|
||||
dbMaxBackoff = 10 * time.Second
|
||||
cleanupPeriod = time.Hour
|
||||
@@ -770,6 +771,9 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
|
||||
|
||||
for k := range m.sent {
|
||||
if _, ok := best[k]; !ok {
|
||||
m.logger.Debug(m.ctx, "peer no longer in best mappings, sending DISCONNECTED",
|
||||
slog.F("peer_id", k),
|
||||
)
|
||||
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
|
||||
Id: agpl.UUIDToByteSlice(k),
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
|
||||
@@ -820,6 +824,8 @@ type querier struct {
|
||||
mu sync.Mutex
|
||||
mappers map[mKey]*mapper
|
||||
healthy bool
|
||||
|
||||
resyncGroup singleflight.Group
|
||||
}
|
||||
|
||||
func newQuerier(ctx context.Context,
|
||||
@@ -958,7 +964,7 @@ func (q *querier) cleanupConn(c *connIO) {
|
||||
|
||||
// maxBatchSize is the maximum number of keys to process in a single batch
|
||||
// query.
|
||||
const maxBatchSize = 50
|
||||
const maxBatchSize = 200
|
||||
|
||||
func (q *querier) peerUpdateWorker() {
|
||||
defer q.wg.Done()
|
||||
@@ -1207,8 +1213,13 @@ func (q *querier) subscribe() {
|
||||
func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
|
||||
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||
q.logger.Warn(q.ctx, "pubsub may have dropped peer updates")
|
||||
// we need to schedule a full resync of peer mappings
|
||||
q.resyncPeerMappings()
|
||||
// Schedule a full resync asynchronously so we don't block the
|
||||
// pubsub drain goroutine. Singleflight coalesces concurrent
|
||||
// resync requests.
|
||||
go q.resyncGroup.Do("resync", func() (any, error) {
|
||||
q.resyncPeerMappings()
|
||||
return nil, nil
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -1234,8 +1245,13 @@ func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
|
||||
func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
|
||||
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||
q.logger.Warn(q.ctx, "pubsub may have dropped tunnel updates")
|
||||
// we need to schedule a full resync of peer mappings
|
||||
q.resyncPeerMappings()
|
||||
// Schedule a full resync asynchronously so we don't block the
|
||||
// pubsub drain goroutine. Singleflight coalesces concurrent
|
||||
// resync requests.
|
||||
go q.resyncGroup.Do("resync", func() (any, error) {
|
||||
q.resyncPeerMappings()
|
||||
return nil, nil
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -1601,6 +1617,10 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
|
||||
// the only mapping available for it. Newer mappings will take
|
||||
// precedence.
|
||||
m.kind = proto.CoordinateResponse_PeerUpdate_LOST
|
||||
h.logger.Debug(h.ctx, "mapping rewritten to LOST due to missed heartbeats",
|
||||
slog.F("peer_id", m.peer),
|
||||
slog.F("coordinator_id", m.coordinator),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,11 +76,11 @@ replace github.com/aquasecurity/trivy => github.com/coder/trivy v0.0.0-202603091
|
||||
// https://github.com/spf13/afero/pull/487
|
||||
replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696
|
||||
|
||||
// Forked from kylecarbs/fantasy (cj/go1.25 branch) which adds:
|
||||
// Forked from coder/fantasy (cj/go1.25 branch) which adds:
|
||||
// 1) Anthropic computer use + thinking effort
|
||||
// 2) Go 1.25 downgrade for Windows CI compat
|
||||
// 3) ibetitsmike/fantasy#4 — skip ephemeral replay items when store=false
|
||||
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8
|
||||
replace charm.land/fantasy => github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8
|
||||
|
||||
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
|
||||
|
||||
|
||||
@@ -322,6 +322,8 @@ github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwu
|
||||
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4=
|
||||
github.com/coder/clistat v1.2.1 h1:P9/10njXMyj5cWzIU5wkRsSy5LVQH49+tcGMsAgWX0w=
|
||||
github.com/coder/clistat v1.2.1/go.mod h1:m7SC0uj88eEERgvF8Kn6+w6XF21BeSr+15f7GoLAw0A=
|
||||
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:n+6v+yT1B6V4oSGPmXFh7mul1E+RzG9rnqp50Vb7M/w=
|
||||
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
|
||||
github.com/coder/flog v1.1.0 h1:kbAes1ai8fIS5OeV+QAnKBQE22ty1jRF/mcAwHpLBa4=
|
||||
github.com/coder/flog v1.1.0/go.mod h1:UQlQvrkJBvnRGo69Le8E24Tcl5SJleAAR7gYEHzAmdQ=
|
||||
github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322 h1:m0lPZjlQ7vdVpRBPKfYIFlmgevoTkBxB10wv6l2gOaU=
|
||||
@@ -813,8 +815,6 @@ github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:5UMY
|
||||
github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab/go.mod h1:hqlYqR7uPKOKfnNeicUbZp0Ps0GeYFlKYtwh5HGDCx8=
|
||||
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3 h1:Z9/bo5PSeMutpdiKYNt/TTSfGM1Ll0naj3QzYX9VxTc=
|
||||
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3/go.mod h1:BUGjjsD+ndS6eX37YgTchSEG+Jg9Jv1GiZs9sqPqztk=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:fZ0208U3B438fDSHCc/GNioPIyaFqn6eBsQTO61QtrI=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
|
||||
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae h1:xlFZNX4nnxpj9Cf6mTwD3pirXGNtBJ/6COsf9iZmsL0=
|
||||
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
|
||||
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
|
||||
|
||||
+44
-35
@@ -68,17 +68,17 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
return xerrors.Errorf("detecting branch: %w", err)
|
||||
}
|
||||
|
||||
// Match standard release branches (release/2.32) and RC
|
||||
// branches (release/2.32-rc.0).
|
||||
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)(?:-rc\.(\d+))?$`)
|
||||
// Match release branches (release/X.Y). RCs are tagged
|
||||
// from main, not from release branches.
|
||||
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)$`)
|
||||
m := branchRe.FindStringSubmatch(currentBranch)
|
||||
if m == nil {
|
||||
warnf(w, "Current branch %q is not a release branch (release/X.Y or release/X.Y-rc.N).", currentBranch)
|
||||
warnf(w, "Current branch %q is not a release branch (release/X.Y).", currentBranch)
|
||||
branchInput, err := cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: "Enter the release branch to use (e.g. release/2.21 or release/2.21-rc.0)",
|
||||
Text: "Enter the release branch to use (e.g. release/2.21)",
|
||||
Validate: func(s string) error {
|
||||
if !branchRe.MatchString(s) {
|
||||
return xerrors.New("must be in format release/X.Y or release/X.Y-rc.N (e.g. release/2.21 or release/2.21-rc.0)")
|
||||
return xerrors.New("must be in format release/X.Y (e.g. release/2.21)")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -91,10 +91,6 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
branchMajor, _ := strconv.Atoi(m[1])
|
||||
branchMinor, _ := strconv.Atoi(m[2])
|
||||
branchRC := -1 // -1 means not an RC branch.
|
||||
if m[3] != "" {
|
||||
branchRC, _ = strconv.Atoi(m[3])
|
||||
}
|
||||
successf(w, "Using release branch: %s", currentBranch)
|
||||
|
||||
// --- Fetch & sync check ---
|
||||
@@ -138,31 +134,44 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
}
|
||||
|
||||
// changelogBaseRef is the git ref used as the starting point
|
||||
// for release notes generation. When a tag already exists in
|
||||
// this minor series we use it directly. For the first release
|
||||
// on a new minor no matching tag exists, so we compute the
|
||||
// merge-base with the previous minor's release branch instead.
|
||||
// This works even when that branch has no tags yet (it was
|
||||
// just cut and pushed). As a last resort we fall back to the
|
||||
// latest reachable tag from a previous minor.
|
||||
var changelogBaseRef string
|
||||
if prevVersion != nil {
|
||||
changelogBaseRef = prevVersion.String()
|
||||
} else {
|
||||
prevReleaseBranch := fmt.Sprintf("release/%d.%d", branchMajor, branchMinor-1)
|
||||
if err := gitRun("fetch", "--quiet", "origin", prevReleaseBranch); err != nil {
|
||||
warnf(w, "Could not fetch %s: %v", prevReleaseBranch, err)
|
||||
}
|
||||
if mb, mbErr := gitOutput("merge-base", "HEAD", "origin/"+prevReleaseBranch); mbErr == nil && mb != "" {
|
||||
changelogBaseRef = mb
|
||||
infof(w, "Using merge-base with %s as changelog base: %s", prevReleaseBranch, mb[:12])
|
||||
} else {
|
||||
// No previous release branch found; fall back to
|
||||
// the latest reachable tag from a previous minor.
|
||||
for _, t := range mergedTags {
|
||||
if t.Major == branchMajor && t.Minor < branchMinor {
|
||||
changelogBaseRef = t.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var suggested version
|
||||
if prevVersion == nil {
|
||||
infof(w, "No previous release tag found on this branch.")
|
||||
suggested = version{Major: branchMajor, Minor: branchMinor, Patch: 0}
|
||||
if branchRC >= 0 {
|
||||
suggested.Pre = fmt.Sprintf("rc.%d", branchRC)
|
||||
}
|
||||
} else {
|
||||
infof(w, "Previous release tag: %s", prevVersion.String())
|
||||
if branchRC >= 0 {
|
||||
// On an RC branch, suggest the next RC for
|
||||
// the same base version.
|
||||
nextRC := 0
|
||||
if prevVersion.IsRC() {
|
||||
nextRC = prevVersion.rcNumber() + 1
|
||||
}
|
||||
suggested = version{
|
||||
Major: prevVersion.Major,
|
||||
Minor: prevVersion.Minor,
|
||||
Patch: prevVersion.Patch,
|
||||
Pre: fmt.Sprintf("rc.%d", nextRC),
|
||||
}
|
||||
} else {
|
||||
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
|
||||
}
|
||||
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
|
||||
}
|
||||
|
||||
fmt.Fprintln(w)
|
||||
@@ -366,8 +375,8 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
infof(w, "Generating release notes...")
|
||||
|
||||
commitRange := "HEAD"
|
||||
if prevVersion != nil {
|
||||
commitRange = prevVersion.String() + "..HEAD"
|
||||
if changelogBaseRef != "" {
|
||||
commitRange = changelogBaseRef + "..HEAD"
|
||||
}
|
||||
|
||||
commits, err := commitLog(commitRange)
|
||||
@@ -473,16 +482,16 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
if !hasContent {
|
||||
prevStr := "the beginning of time"
|
||||
if prevVersion != nil {
|
||||
prevStr = prevVersion.String()
|
||||
if changelogBaseRef != "" {
|
||||
prevStr = changelogBaseRef
|
||||
}
|
||||
fmt.Fprintf(¬es, "\n_No changes since %s._\n", prevStr)
|
||||
}
|
||||
|
||||
// Compare link.
|
||||
if prevVersion != nil {
|
||||
if changelogBaseRef != "" {
|
||||
fmt.Fprintf(¬es, "\nCompare: [`%s...%s`](https://github.com/%s/%s/compare/%s...%s)\n",
|
||||
prevVersion, newVersion, owner, repo, prevVersion, newVersion)
|
||||
changelogBaseRef, newVersion, owner, repo, changelogBaseRef, newVersion)
|
||||
}
|
||||
|
||||
// Container image.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { Interpolation, Theme } from "@emotion/react";
|
||||
import type { FC, HTMLAttributes } from "react";
|
||||
import type { LogLevel } from "#/api/typesGenerated";
|
||||
import { MONOSPACE_FONT_FAMILY } from "#/theme/constants";
|
||||
import { cn } from "#/utils/cn";
|
||||
|
||||
const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
|
||||
|
||||
@@ -17,65 +16,40 @@ type LogLineProps = {
|
||||
level: LogLevel;
|
||||
} & HTMLAttributes<HTMLPreElement>;
|
||||
|
||||
export const LogLine: FC<LogLineProps> = ({ level, ...divProps }) => {
|
||||
export const LogLine: FC<LogLineProps> = ({ level, className, ...props }) => {
|
||||
return (
|
||||
<pre
|
||||
css={styles.line}
|
||||
className={`${level} ${divProps.className} logs-line`}
|
||||
{...divProps}
|
||||
{...props}
|
||||
className={cn(
|
||||
"logs-line",
|
||||
"m-0 break-all flex items-center h-auto",
|
||||
"text-[13px] text-content-primary font-mono",
|
||||
level === "error" &&
|
||||
"bg-surface-error text-content-error [&_.dashed-line]:bg-border-error",
|
||||
level === "debug" &&
|
||||
"bg-surface-sky text-content-sky [&_.dashed-line]:bg-border-sky",
|
||||
level === "warn" &&
|
||||
"bg-surface-warning text-content-warning [&_.dashed-line]:bg-border-warning",
|
||||
className,
|
||||
)}
|
||||
style={{
|
||||
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = (props) => {
|
||||
return <pre css={styles.prefix} {...props} />;
|
||||
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = ({
|
||||
className,
|
||||
...props
|
||||
}) => {
|
||||
return (
|
||||
<pre
|
||||
className={cn(
|
||||
"select-none m-0 inline-block text-content-secondary mr-6",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const styles = {
|
||||
line: (theme) => ({
|
||||
margin: 0,
|
||||
wordBreak: "break-all",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
fontSize: 13,
|
||||
color: theme.palette.text.primary,
|
||||
fontFamily: MONOSPACE_FONT_FAMILY,
|
||||
height: "auto",
|
||||
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
|
||||
|
||||
"&.error": {
|
||||
backgroundColor: theme.roles.error.background,
|
||||
color: theme.roles.error.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.error.outline,
|
||||
},
|
||||
},
|
||||
|
||||
"&.debug": {
|
||||
backgroundColor: theme.roles.notice.background,
|
||||
color: theme.roles.notice.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.notice.outline,
|
||||
},
|
||||
},
|
||||
|
||||
"&.warn": {
|
||||
backgroundColor: theme.roles.warning.background,
|
||||
color: theme.roles.warning.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.warning.outline,
|
||||
},
|
||||
},
|
||||
}),
|
||||
|
||||
prefix: (theme) => ({
|
||||
userSelect: "none",
|
||||
margin: 0,
|
||||
display: "inline-block",
|
||||
color: theme.palette.text.secondary,
|
||||
marginRight: 24,
|
||||
}),
|
||||
} satisfies Record<string, Interpolation<Theme>>;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { Interpolation, Theme } from "@emotion/react";
|
||||
import dayjs from "dayjs";
|
||||
import type { FC } from "react";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { type Line, LogLine, LogLinePrefix } from "./LogLine";
|
||||
|
||||
export const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
|
||||
@@ -17,7 +17,17 @@ export const Logs: FC<LogsProps> = ({
|
||||
className = "",
|
||||
}) => {
|
||||
return (
|
||||
<div css={styles.root} className={`${className} logs-container`}>
|
||||
<div
|
||||
className={cn(
|
||||
"logs-container",
|
||||
"min-h-40 py-2 rounded-lg overflow-x-auto bg-surface-primary",
|
||||
"[&:not(:last-child)]:border-0",
|
||||
"[&:not(:last-child)]:border-solid",
|
||||
"[&:not(:last-child)]:border-b-border",
|
||||
"[&:not(:last-child)]:rounded-none",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="min-w-fit">
|
||||
{lines.map((line) => (
|
||||
<LogLine key={line.id} level={line.level}>
|
||||
@@ -33,18 +43,3 @@ export const Logs: FC<LogsProps> = ({
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const styles = {
|
||||
root: (theme) => ({
|
||||
minHeight: 156,
|
||||
padding: "8px 0",
|
||||
borderRadius: 8,
|
||||
overflowX: "auto",
|
||||
background: theme.palette.background.default,
|
||||
|
||||
"&:not(:last-child)": {
|
||||
borderBottom: `1px solid ${theme.palette.divider}`,
|
||||
borderRadius: 0,
|
||||
},
|
||||
}),
|
||||
} satisfies Record<string, Interpolation<Theme>>;
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
export { useTabOverflowKebabMenu } from "./useTabOverflowKebabMenu";
|
||||
@@ -0,0 +1,274 @@
|
||||
import {
|
||||
type RefObject,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
type TabValue = {
|
||||
value: string;
|
||||
};
|
||||
|
||||
type UseKebabMenuOptions<T extends TabValue> = {
|
||||
tabs: readonly T[];
|
||||
enabled: boolean;
|
||||
isActive: boolean;
|
||||
overflowTriggerWidth?: number;
|
||||
};
|
||||
|
||||
type UseKebabMenuResult<T extends TabValue> = {
|
||||
containerRef: RefObject<HTMLDivElement | null>;
|
||||
visibleTabs: T[];
|
||||
overflowTabs: T[];
|
||||
getTabMeasureProps: (tabValue: string) => Record<string, string>;
|
||||
};
|
||||
|
||||
const ALWAYS_VISIBLE_TABS_COUNT = 1;
|
||||
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
|
||||
|
||||
/**
|
||||
* Splits tabs into visible and overflow groups based on container width.
|
||||
*
|
||||
* Tabs must render with `getTabMeasureProps()` so this hook can measure
|
||||
* trigger widths from the DOM.
|
||||
*/
|
||||
export const useKebabMenu = <T extends TabValue>({
|
||||
tabs,
|
||||
enabled,
|
||||
isActive,
|
||||
overflowTriggerWidth = 44,
|
||||
}: UseKebabMenuOptions<T>): UseKebabMenuResult<T> => {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const tabsRef = useRef<readonly T[]>(tabs);
|
||||
tabsRef.current = tabs;
|
||||
const previousTabsRef = useRef<readonly T[]>(tabs);
|
||||
const availableWidthRef = useRef<number | null>(null);
|
||||
// Width cache prevents oscillation when overflow tabs are not mounted.
|
||||
const tabWidthByValueRef = useRef<Record<string, number>>({});
|
||||
const [overflowTabValues, setTabValues] = useState<string[]>([]);
|
||||
|
||||
const recalculateOverflow = useCallback(
|
||||
(availableWidth: number) => {
|
||||
if (!enabled || !isActive) {
|
||||
// Keep this update idempotent to avoid render loops.
|
||||
setTabValues((currentValues) => {
|
||||
if (currentValues.length === 0) {
|
||||
return currentValues;
|
||||
}
|
||||
return [];
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
const currentTabs = tabsRef.current;
|
||||
|
||||
const tabWidthByValue = measureTabWidths({
|
||||
tabs: currentTabs,
|
||||
container,
|
||||
previousTabWidthByValue: tabWidthByValueRef.current,
|
||||
});
|
||||
tabWidthByValueRef.current = tabWidthByValue;
|
||||
|
||||
const nextOverflowValues = calculateTabValues({
|
||||
tabs: currentTabs,
|
||||
availableWidth,
|
||||
tabWidthByValue,
|
||||
overflowTriggerWidth,
|
||||
});
|
||||
|
||||
setTabValues((currentValues) => {
|
||||
// Avoid state updates when the computed overflow did not change.
|
||||
if (areStringArraysEqual(currentValues, nextOverflowValues)) {
|
||||
return currentValues;
|
||||
}
|
||||
return nextOverflowValues;
|
||||
});
|
||||
},
|
||||
[enabled, isActive, overflowTriggerWidth],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (previousTabsRef.current === tabs) {
|
||||
// No change in tabs, no need to recalculate.
|
||||
return;
|
||||
}
|
||||
previousTabsRef.current = tabs;
|
||||
if (availableWidthRef.current === null) {
|
||||
// First mount, no width available yet.
|
||||
return;
|
||||
}
|
||||
recalculateOverflow(availableWidthRef.current);
|
||||
}, [recalculateOverflow, tabs]);
|
||||
|
||||
useEffect(() => {
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Recompute whenever ResizeObserver reports a container width change.
|
||||
const observer = new ResizeObserver(([entry]) => {
|
||||
if (!entry) {
|
||||
return;
|
||||
}
|
||||
availableWidthRef.current = entry.contentRect.width;
|
||||
recalculateOverflow(entry.contentRect.width);
|
||||
});
|
||||
observer.observe(container);
|
||||
return () => observer.disconnect();
|
||||
}, [recalculateOverflow]);
|
||||
|
||||
const overflowTabValuesSet = new Set(overflowTabValues);
|
||||
const { visibleTabs, overflowTabs } = tabs.reduce<{
|
||||
visibleTabs: T[];
|
||||
overflowTabs: T[];
|
||||
}>(
|
||||
(tabGroups, tab) => {
|
||||
if (overflowTabValuesSet.has(tab.value)) {
|
||||
tabGroups.overflowTabs.push(tab);
|
||||
} else {
|
||||
tabGroups.visibleTabs.push(tab);
|
||||
}
|
||||
return tabGroups;
|
||||
},
|
||||
{ visibleTabs: [], overflowTabs: [] },
|
||||
);
|
||||
|
||||
const getTabMeasureProps = (tabValue: string) => {
|
||||
return { [DATA_ATTR_TAB_VALUE]: tabValue };
|
||||
};
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
visibleTabs,
|
||||
overflowTabs,
|
||||
getTabMeasureProps,
|
||||
};
|
||||
};
|
||||
|
||||
const calculateTabValues = <T extends TabValue>({
|
||||
tabs,
|
||||
availableWidth,
|
||||
tabWidthByValue,
|
||||
overflowTriggerWidth,
|
||||
}: {
|
||||
tabs: readonly T[];
|
||||
availableWidth: number;
|
||||
tabWidthByValue: Readonly<Record<string, number>>;
|
||||
overflowTriggerWidth: number;
|
||||
}): string[] => {
|
||||
const tabWidthByValueMap = new Map<string, number>();
|
||||
for (const tab of tabs) {
|
||||
tabWidthByValueMap.set(tab.value, tabWidthByValue[tab.value] ?? 0);
|
||||
}
|
||||
|
||||
const firstOptionalTabIndex = Math.min(
|
||||
ALWAYS_VISIBLE_TABS_COUNT,
|
||||
tabs.length,
|
||||
);
|
||||
if (firstOptionalTabIndex >= tabs.length) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const alwaysVisibleTabs = tabs.slice(0, firstOptionalTabIndex);
|
||||
const optionalTabs = tabs.slice(firstOptionalTabIndex);
|
||||
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
|
||||
return total + (tabWidthByValueMap.get(tab.value) ?? 0);
|
||||
}, 0);
|
||||
const firstTabIndex = findFirstTabIndex({
|
||||
optionalTabs,
|
||||
optionalTabWidths: optionalTabs.map((tab) => {
|
||||
return tabWidthByValueMap.get(tab.value) ?? 0;
|
||||
}),
|
||||
startingUsedWidth: alwaysVisibleWidth,
|
||||
availableWidth,
|
||||
overflowTriggerWidth,
|
||||
});
|
||||
|
||||
if (firstTabIndex === -1) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return optionalTabs
|
||||
.slice(firstTabIndex)
|
||||
.map((overflowTab) => overflowTab.value);
|
||||
};
|
||||
|
||||
const measureTabWidths = <T extends TabValue>({
|
||||
tabs,
|
||||
container,
|
||||
previousTabWidthByValue,
|
||||
}: {
|
||||
tabs: readonly T[];
|
||||
container: HTMLDivElement;
|
||||
previousTabWidthByValue: Readonly<Record<string, number>>;
|
||||
}): Record<string, number> => {
|
||||
const nextTabWidthByValue = { ...previousTabWidthByValue };
|
||||
for (const tab of tabs) {
|
||||
const tabElement = container.querySelector<HTMLElement>(
|
||||
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
|
||||
);
|
||||
if (tabElement) {
|
||||
nextTabWidthByValue[tab.value] = tabElement.offsetWidth;
|
||||
}
|
||||
}
|
||||
return nextTabWidthByValue;
|
||||
};
|
||||
|
||||
const findFirstTabIndex = ({
|
||||
optionalTabs,
|
||||
optionalTabWidths,
|
||||
startingUsedWidth,
|
||||
availableWidth,
|
||||
overflowTriggerWidth,
|
||||
}: {
|
||||
optionalTabs: readonly TabValue[];
|
||||
optionalTabWidths: readonly number[];
|
||||
startingUsedWidth: number;
|
||||
availableWidth: number;
|
||||
overflowTriggerWidth: number;
|
||||
}): number => {
|
||||
const result = optionalTabs.reduce(
|
||||
(acc, _tab, index) => {
|
||||
if (acc.firstTabIndex !== -1) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
const tabWidth = optionalTabWidths[index] ?? 0;
|
||||
const hasMoreTabs = index < optionalTabs.length - 1;
|
||||
// Reserve kebab trigger width whenever additional tabs remain.
|
||||
const widthNeeded =
|
||||
acc.usedWidth + tabWidth + (hasMoreTabs ? overflowTriggerWidth : 0);
|
||||
|
||||
if (widthNeeded <= availableWidth) {
|
||||
return {
|
||||
usedWidth: acc.usedWidth + tabWidth,
|
||||
firstTabIndex: -1,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
usedWidth: acc.usedWidth,
|
||||
firstTabIndex: index,
|
||||
};
|
||||
},
|
||||
{ usedWidth: startingUsedWidth, firstTabIndex: -1 },
|
||||
);
|
||||
|
||||
return result.firstTabIndex;
|
||||
};
|
||||
|
||||
const areStringArraysEqual = (
|
||||
left: readonly string[],
|
||||
right: readonly string[],
|
||||
): boolean => {
|
||||
return (
|
||||
left.length === right.length &&
|
||||
left.every((value, index) => value === right[index])
|
||||
);
|
||||
};
|
||||
@@ -1,157 +0,0 @@
|
||||
import {
|
||||
type RefObject,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
type TabLike = {
|
||||
value: string;
|
||||
};
|
||||
|
||||
type UseTabOverflowKebabMenuOptions<TTab extends TabLike> = {
|
||||
tabs: readonly TTab[];
|
||||
enabled: boolean;
|
||||
isActive: boolean;
|
||||
alwaysVisibleTabsCount?: number;
|
||||
overflowTriggerWidthPx?: number;
|
||||
};
|
||||
|
||||
type UseTabOverflowKebabMenuResult<TTab extends TabLike> = {
|
||||
containerRef: RefObject<HTMLDivElement | null>;
|
||||
visibleTabs: TTab[];
|
||||
overflowTabs: TTab[];
|
||||
getTabMeasureProps: (tabValue: string) => Record<string, string>;
|
||||
};
|
||||
|
||||
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
|
||||
|
||||
export const useTabOverflowKebabMenu = <TTab extends TabLike>({
|
||||
tabs,
|
||||
enabled,
|
||||
isActive,
|
||||
alwaysVisibleTabsCount = 1,
|
||||
overflowTriggerWidthPx = 44,
|
||||
}: UseTabOverflowKebabMenuOptions<TTab>): UseTabOverflowKebabMenuResult<TTab> => {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const tabWidthByValueRef = useRef<Record<string, number>>({});
|
||||
const [overflowTabValues, setOverflowTabValues] = useState<string[]>([]);
|
||||
|
||||
const recalculateOverflow = useCallback(() => {
|
||||
if (!enabled) {
|
||||
setOverflowTabValues([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const tab of tabs) {
|
||||
const tabElement = container.querySelector<HTMLElement>(
|
||||
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
|
||||
);
|
||||
if (tabElement) {
|
||||
tabWidthByValueRef.current[tab.value] = tabElement.offsetWidth;
|
||||
}
|
||||
}
|
||||
|
||||
const alwaysVisibleTabs = tabs.slice(0, alwaysVisibleTabsCount);
|
||||
const optionalTabs = tabs.slice(alwaysVisibleTabsCount);
|
||||
if (optionalTabs.length === 0) {
|
||||
setOverflowTabValues([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
|
||||
return total + (tabWidthByValueRef.current[tab.value] ?? 0);
|
||||
}, 0);
|
||||
|
||||
const availableWidth = container.clientWidth;
|
||||
let usedWidth = alwaysVisibleWidth;
|
||||
const nextOverflowValues: string[] = [];
|
||||
|
||||
for (let i = 0; i < optionalTabs.length; i++) {
|
||||
const tab = optionalTabs[i];
|
||||
const tabWidth = tabWidthByValueRef.current[tab.value] ?? 0;
|
||||
const hasMoreTabsAfterCurrent = i < optionalTabs.length - 1;
|
||||
const widthNeeded =
|
||||
usedWidth +
|
||||
tabWidth +
|
||||
(hasMoreTabsAfterCurrent ? overflowTriggerWidthPx : 0);
|
||||
|
||||
if (widthNeeded <= availableWidth) {
|
||||
usedWidth += tabWidth;
|
||||
continue;
|
||||
}
|
||||
|
||||
nextOverflowValues.push(
|
||||
...optionalTabs.slice(i).map((overflowTab) => overflowTab.value),
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
setOverflowTabValues((currentValues) => {
|
||||
if (
|
||||
currentValues.length === nextOverflowValues.length &&
|
||||
currentValues.every(
|
||||
(value, index) => value === nextOverflowValues[index],
|
||||
)
|
||||
) {
|
||||
return currentValues;
|
||||
}
|
||||
return nextOverflowValues;
|
||||
});
|
||||
}, [alwaysVisibleTabsCount, enabled, overflowTriggerWidthPx, tabs]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (!isActive) {
|
||||
return;
|
||||
}
|
||||
recalculateOverflow();
|
||||
}, [isActive, recalculateOverflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isActive) {
|
||||
return;
|
||||
}
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
const observer = new ResizeObserver(() => {
|
||||
recalculateOverflow();
|
||||
});
|
||||
observer.observe(container);
|
||||
return () => observer.disconnect();
|
||||
}, [isActive, recalculateOverflow]);
|
||||
|
||||
const overflowTabValuesSet = useMemo(
|
||||
() => new Set(overflowTabValues),
|
||||
[overflowTabValues],
|
||||
);
|
||||
|
||||
const visibleTabs = useMemo(
|
||||
() => tabs.filter((tab) => !overflowTabValuesSet.has(tab.value)),
|
||||
[tabs, overflowTabValuesSet],
|
||||
);
|
||||
const overflowTabs = useMemo(
|
||||
() => tabs.filter((tab) => overflowTabValuesSet.has(tab.value)),
|
||||
[tabs, overflowTabValuesSet],
|
||||
);
|
||||
|
||||
const getTabMeasureProps = useCallback((tabValue: string) => {
|
||||
return { [DATA_ATTR_TAB_VALUE]: tabValue };
|
||||
}, []);
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
visibleTabs,
|
||||
overflowTabs,
|
||||
getTabMeasureProps,
|
||||
};
|
||||
};
|
||||
@@ -2,7 +2,7 @@ import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { expect, spyOn, userEvent, waitFor, within } from "storybook/test";
|
||||
import { API } from "#/api/api";
|
||||
import { workspaceAgentContainersKey } from "#/api/queries/workspaces";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import type { WorkspaceAgentLogSource } from "#/api/typesGenerated";
|
||||
import { getPreferredProxy } from "#/contexts/ProxyContext";
|
||||
import { chromatic } from "#/testHelpers/chromatic";
|
||||
import * as M from "#/testHelpers/entities";
|
||||
@@ -92,7 +92,7 @@ const logs = [
|
||||
created_at: fixedLogTimestamp,
|
||||
}));
|
||||
|
||||
const installScriptLogSource: TypesGen.WorkspaceAgentLogSource = {
|
||||
const installScriptLogSource: WorkspaceAgentLogSource = {
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "f2ee4b8d-b09d-4f4e-a1f1-5e4adf7d53bb",
|
||||
display_name: "Install Script",
|
||||
@@ -122,42 +122,6 @@ const tabbedLogs = [
|
||||
},
|
||||
];
|
||||
|
||||
const overflowLogSources: TypesGen.WorkspaceAgentLogSource[] = [
|
||||
M.MockWorkspaceAgentLogSource,
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "58f5db69-5f78-496f-bce1-0686f5525aa1",
|
||||
display_name: "code-server",
|
||||
icon: "/icon/code.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "f39d758c-bce2-4f41-8d70-58fdb1f0f729",
|
||||
display_name: "Install and start AgentAPI",
|
||||
icon: "/icon/claude.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "bf7529b8-1787-4a20-b54f-eb894680e48f",
|
||||
display_name: "Mux",
|
||||
icon: "/icon/mux.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "0d6ebde6-c534-4551-9f91-bfd98bfb04f4",
|
||||
display_name: "Portable Desktop",
|
||||
icon: "/icon/portable-desktop.svg",
|
||||
},
|
||||
];
|
||||
|
||||
const overflowLogs = overflowLogSources.map((source, index) => ({
|
||||
id: 200 + index,
|
||||
level: "info",
|
||||
output: `${source.display_name}: line`,
|
||||
source_id: source.id,
|
||||
created_at: fixedLogTimestamp,
|
||||
}));
|
||||
|
||||
const meta: Meta<typeof AgentRow> = {
|
||||
title: "components/AgentRow",
|
||||
component: AgentRow,
|
||||
@@ -440,44 +404,3 @@ export const LogsTabs: Story = {
|
||||
await expect(canvas.getByText("install: pnpm install")).toBeVisible();
|
||||
},
|
||||
};
|
||||
|
||||
export const LogsTabsOverflow: Story = {
|
||||
args: {
|
||||
agent: {
|
||||
...M.MockWorkspaceAgentReady,
|
||||
logs_length: overflowLogs.length,
|
||||
log_sources: overflowLogSources,
|
||||
},
|
||||
},
|
||||
parameters: {
|
||||
webSocket: [
|
||||
{
|
||||
event: "message",
|
||||
data: JSON.stringify(overflowLogs),
|
||||
},
|
||||
],
|
||||
},
|
||||
render: (args) => (
|
||||
<div className="max-w-[320px]">
|
||||
<AgentRow {...args} />
|
||||
</div>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const page = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(canvas.getByRole("button", { name: "Logs" }));
|
||||
await userEvent.click(
|
||||
canvas.getByRole("button", { name: "More log tabs" }),
|
||||
);
|
||||
const overflowItems = await page.findAllByRole("menuitemradio");
|
||||
const selectedItem = overflowItems[0];
|
||||
const selectedSource = selectedItem.textContent;
|
||||
if (!selectedSource) {
|
||||
throw new Error("Overflow menu item must have text content.");
|
||||
}
|
||||
await userEvent.click(selectedItem);
|
||||
await waitFor(() =>
|
||||
expect(canvas.getByText(`${selectedSource}: line`)).toBeVisible(),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
@@ -8,10 +8,9 @@ import {
|
||||
} from "lucide-react";
|
||||
import {
|
||||
type FC,
|
||||
useCallback,
|
||||
type ReactNode,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
@@ -42,7 +41,7 @@ import {
|
||||
TabsList,
|
||||
TabsTrigger,
|
||||
} from "#/components/Tabs/Tabs";
|
||||
import { useTabOverflowKebabMenu } from "#/components/Tabs/utils";
|
||||
import { useKebabMenu } from "#/components/Tabs/utils/useKebabMenu";
|
||||
import { useProxy } from "#/contexts/ProxyContext";
|
||||
import { useClipboard } from "#/hooks/useClipboard";
|
||||
import { useFeatureVisibility } from "#/modules/dashboard/useFeatureVisibility";
|
||||
@@ -162,7 +161,7 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
// This is a bit of a hack on the react-window API to get the scroll position.
|
||||
// If we're scrolled to the bottom, we want to keep the list scrolled to the bottom.
|
||||
// This makes it feel similar to a terminal that auto-scrolls downwards!
|
||||
const handleLogScroll = useCallback((props: ListOnScrollProps) => {
|
||||
const handleLogScroll = (props: ListOnScrollProps) => {
|
||||
if (
|
||||
props.scrollOffset === 0 ||
|
||||
props.scrollUpdateWasRequested ||
|
||||
@@ -179,7 +178,7 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
logListDivRef.current.scrollHeight -
|
||||
(props.scrollOffset + parent.clientHeight);
|
||||
setBottomOfLogs(distanceFromBottom < AGENT_LOG_LINE_HEIGHT);
|
||||
}, []);
|
||||
};
|
||||
|
||||
const devcontainers = useAgentContainers(agent);
|
||||
|
||||
@@ -211,59 +210,56 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
);
|
||||
|
||||
const [selectedLogTab, setSelectedLogTab] = useState("all");
|
||||
const logTabs = useMemo(() => {
|
||||
const sourceLogTabs = agent.log_sources
|
||||
.filter((logSource) => {
|
||||
// Remove the logSources that have no entries.
|
||||
return agentLogs.some(
|
||||
(log) =>
|
||||
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
|
||||
);
|
||||
})
|
||||
.map((logSource) => ({
|
||||
// Show the icon for the log source if it has one.
|
||||
// In the startup script case, we show a bespoke play icon.
|
||||
startIcon: logSource.icon ? (
|
||||
<ExternalImage
|
||||
src={logSource.icon}
|
||||
alt=""
|
||||
className="size-icon-xs shrink-0"
|
||||
/>
|
||||
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
|
||||
<PlayIcon className="size-icon-xs shrink-0" />
|
||||
) : null,
|
||||
title: logSource.display_name,
|
||||
value: logSource.id,
|
||||
}));
|
||||
const startupScriptLogTab = sourceLogTabs.find(
|
||||
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
|
||||
);
|
||||
const sortedSourceLogTabs = sourceLogTabs
|
||||
.filter((tab) => tab !== startupScriptLogTab)
|
||||
.sort((a, b) => a.title.localeCompare(b.title));
|
||||
return [
|
||||
{
|
||||
title: "All Logs",
|
||||
value: "all",
|
||||
},
|
||||
...(startupScriptLogTab ? [startupScriptLogTab] : []),
|
||||
...sortedSourceLogTabs,
|
||||
] as {
|
||||
startIcon?: React.ReactNode;
|
||||
title: string;
|
||||
value: string;
|
||||
}[];
|
||||
}, [agent.log_sources, agentLogs]);
|
||||
const sourceLogTabs = agent.log_sources
|
||||
.filter((logSource) => {
|
||||
// Remove the logSources that have no entries.
|
||||
return agentLogs.some(
|
||||
(log) =>
|
||||
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
|
||||
);
|
||||
})
|
||||
.map((logSource) => ({
|
||||
// Show the icon for the log source if it has one.
|
||||
// In the startup script case, we show a bespoke play icon.
|
||||
startIcon: logSource.icon ? (
|
||||
<ExternalImage
|
||||
src={logSource.icon}
|
||||
alt=""
|
||||
className="size-icon-xs shrink-0"
|
||||
/>
|
||||
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
|
||||
<PlayIcon className="size-icon-xs shrink-0" />
|
||||
) : null,
|
||||
title: logSource.display_name,
|
||||
value: logSource.id,
|
||||
}));
|
||||
const startupScriptLogTab = sourceLogTabs.find(
|
||||
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
|
||||
);
|
||||
const sortedSourceLogTabs = sourceLogTabs
|
||||
.filter((tab) => tab !== startupScriptLogTab)
|
||||
.sort((a, b) => a.title.localeCompare(b.title));
|
||||
const logTabs: {
|
||||
startIcon?: ReactNode;
|
||||
title: string;
|
||||
value: string;
|
||||
}[] = [
|
||||
{
|
||||
title: "All Logs",
|
||||
value: "all",
|
||||
},
|
||||
...(startupScriptLogTab ? [startupScriptLogTab] : []),
|
||||
...sortedSourceLogTabs,
|
||||
];
|
||||
const {
|
||||
containerRef: logTabsListContainerRef,
|
||||
visibleTabs: visibleLogTabs,
|
||||
overflowTabs: overflowLogTabs,
|
||||
getTabMeasureProps,
|
||||
} = useTabOverflowKebabMenu({
|
||||
} = useKebabMenu({
|
||||
tabs: logTabs,
|
||||
enabled: true,
|
||||
isActive: showLogs,
|
||||
alwaysVisibleTabsCount: 1,
|
||||
});
|
||||
const overflowLogTabValuesSet = new Set(
|
||||
overflowLogTabs.map((tab) => tab.value),
|
||||
@@ -279,16 +275,29 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
level: log.level,
|
||||
sourceId: log.source_id,
|
||||
}));
|
||||
const allLogsText = agentLogs.map((log) => log.output).join("\n");
|
||||
const selectedLogsText = selectedLogs.map((log) => log.output).join("\n");
|
||||
const hasSelectedLogs = selectedLogs.length > 0;
|
||||
const hasAnyLogs = agentLogs.length > 0;
|
||||
const { showCopiedSuccess, copyToClipboard } = useClipboard();
|
||||
const selectedLogTabTitle =
|
||||
logTabs.find((tab) => tab.value === selectedLogTab)?.title ?? "Logs";
|
||||
const sanitizedTabTitle = selectedLogTabTitle
|
||||
.toLowerCase()
|
||||
.replaceAll(/[^a-z0-9]+/g, "-")
|
||||
.replaceAll(/(^-|-$)/g, "");
|
||||
const logFilenameSuffix = sanitizedTabTitle || "logs";
|
||||
const downloadableLogSets = logTabs
|
||||
.filter((tab) => tab.value !== "all")
|
||||
.map((tab) => {
|
||||
const logsText = agentLogs
|
||||
.filter((log) => log.source_id === tab.value)
|
||||
.map((log) => log.output)
|
||||
.join("\n");
|
||||
const filenameSuffix = tab.title
|
||||
.toLowerCase()
|
||||
.replaceAll(/[^a-z0-9]+/g, "-")
|
||||
.replaceAll(/(^-|-$)/g, "");
|
||||
return {
|
||||
label: tab.title,
|
||||
filenameSuffix: filenameSuffix || tab.value,
|
||||
logsText,
|
||||
startIcon: tab.startIcon,
|
||||
};
|
||||
});
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -547,9 +556,9 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
</Button>
|
||||
<DownloadSelectedAgentLogsButton
|
||||
agentName={agent.name}
|
||||
filenameSuffix={logFilenameSuffix}
|
||||
logsText={selectedLogsText}
|
||||
disabled={!hasSelectedLogs}
|
||||
logSets={downloadableLogSets}
|
||||
allLogsText={allLogsText}
|
||||
disabled={!hasAnyLogs}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,14 +1,27 @@
|
||||
import { saveAs } from "file-saver";
|
||||
import { DownloadIcon } from "lucide-react";
|
||||
import { type FC, useState } from "react";
|
||||
import { ChevronDownIcon, DownloadIcon, PackageIcon } from "lucide-react";
|
||||
import { type FC, type ReactNode, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { getErrorDetail } from "#/api/errors";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "#/components/DropdownMenu/DropdownMenu";
|
||||
|
||||
type DownloadableLogSet = {
|
||||
label: string;
|
||||
filenameSuffix: string;
|
||||
logsText: string;
|
||||
startIcon?: ReactNode;
|
||||
};
|
||||
|
||||
type DownloadSelectedAgentLogsButtonProps = {
|
||||
agentName: string;
|
||||
filenameSuffix: string;
|
||||
logsText: string;
|
||||
logSets: readonly DownloadableLogSet[];
|
||||
allLogsText: string;
|
||||
disabled?: boolean;
|
||||
download?: (file: Blob, filename: string) => void | Promise<void>;
|
||||
};
|
||||
@@ -17,13 +30,13 @@ export const DownloadSelectedAgentLogsButton: FC<
|
||||
DownloadSelectedAgentLogsButtonProps
|
||||
> = ({
|
||||
agentName,
|
||||
filenameSuffix,
|
||||
logsText,
|
||||
logSets,
|
||||
allLogsText,
|
||||
disabled = false,
|
||||
download = saveAs,
|
||||
}) => {
|
||||
const [isDownloading, setIsDownloading] = useState(false);
|
||||
const handleDownload = async () => {
|
||||
const downloadLogs = async (logsText: string, filenameSuffix: string) => {
|
||||
try {
|
||||
setIsDownloading(true);
|
||||
const file = new Blob([logsText], { type: "text/plain" });
|
||||
@@ -37,15 +50,40 @@ export const DownloadSelectedAgentLogsButton: FC<
|
||||
}
|
||||
};
|
||||
|
||||
const hasAllLogs = allLogsText.length > 0;
|
||||
|
||||
return (
|
||||
<Button
|
||||
variant="subtle"
|
||||
size="sm"
|
||||
disabled={disabled || isDownloading}
|
||||
onClick={handleDownload}
|
||||
>
|
||||
<DownloadIcon />
|
||||
{isDownloading ? "Downloading..." : "Download logs"}
|
||||
</Button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="subtle" size="sm" disabled={disabled || isDownloading}>
|
||||
<DownloadIcon />
|
||||
{isDownloading ? "Downloading..." : "Download logs"}
|
||||
<ChevronDownIcon className="size-icon-sm" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
disabled={!hasAllLogs}
|
||||
onSelect={() => {
|
||||
downloadLogs(allLogsText, "all-logs");
|
||||
}}
|
||||
>
|
||||
<PackageIcon />
|
||||
Download all logs
|
||||
</DropdownMenuItem>
|
||||
{logSets.map((logSet) => (
|
||||
<DropdownMenuItem
|
||||
key={logSet.filenameSuffix}
|
||||
disabled={logSet.logsText.length === 0}
|
||||
onSelect={() => {
|
||||
downloadLogs(logSet.logsText, logSet.filenameSuffix);
|
||||
}}
|
||||
>
|
||||
{logSet.startIcon}
|
||||
<span>Download {logSet.label}</span>
|
||||
</DropdownMenuItem>
|
||||
))}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import { BookOpenIcon, LoaderIcon, TriangleAlertIcon } from "lucide-react";
|
||||
import type React from "react";
|
||||
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "#/components/Tooltip/Tooltip";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { Response } from "../Response";
|
||||
import { ToolCollapsible } from "./ToolCollapsible";
|
||||
import type { ToolStatus } from "./utils";
|
||||
|
||||
export const ReadSkillTool: React.FC<{
|
||||
label: string;
|
||||
body: string;
|
||||
status: ToolStatus;
|
||||
isError: boolean;
|
||||
errorMessage?: string;
|
||||
}> = ({ label, body, status, isError, errorMessage }) => {
|
||||
const hasContent = body.length > 0;
|
||||
const isRunning = status === "running";
|
||||
|
||||
return (
|
||||
<ToolCollapsible
|
||||
className="w-full"
|
||||
hasContent={hasContent}
|
||||
header={
|
||||
<>
|
||||
<BookOpenIcon className="h-4 w-4 shrink-0 text-content-secondary" />
|
||||
<span className={cn("text-sm", "text-content-secondary")}>
|
||||
{isRunning ? `Reading ${label}…` : `Read ${label}`}
|
||||
</span>
|
||||
{isError && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<TriangleAlertIcon className="h-3.5 w-3.5 shrink-0 text-content-secondary" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{errorMessage || "Failed to read skill"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
{isRunning && (
|
||||
<LoaderIcon className="h-3.5 w-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" />
|
||||
)}
|
||||
</>
|
||||
}
|
||||
>
|
||||
{body && (
|
||||
<ScrollArea
|
||||
className="mt-1.5 rounded-md border border-solid border-border-default"
|
||||
viewportClassName="max-h-64"
|
||||
scrollBarClassName="w-1.5"
|
||||
>
|
||||
<div className="px-3 py-2">
|
||||
<Response>{body}</Response>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
)}
|
||||
</ToolCollapsible>
|
||||
);
|
||||
};
|
||||
@@ -1342,6 +1342,12 @@ export const ReadSkillCompleted: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
expect(canvas.getByText(/Read skill deep-review/)).toBeInTheDocument();
|
||||
// Expand the collapsible to verify markdown body renders.
|
||||
const toggle = canvas.getByRole("button");
|
||||
await userEvent.click(toggle);
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByText("Deep Review Skill")).toBeInTheDocument();
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1393,6 +1399,12 @@ export const ReadSkillFileCompleted: Story = {
|
||||
expect(
|
||||
canvas.getByText(/Read deep-review\/roles\/security-reviewer\.md/),
|
||||
).toBeInTheDocument();
|
||||
// Expand the collapsible to verify markdown content renders.
|
||||
const toggle = canvas.getByRole("button");
|
||||
await userEvent.click(toggle);
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByText("Security Reviewer Role")).toBeInTheDocument();
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import { ListTemplatesTool } from "./ListTemplatesTool";
|
||||
import { ProcessOutputTool } from "./ProcessOutputTool";
|
||||
import { ProposePlanTool } from "./ProposePlanTool";
|
||||
import { ReadFileTool } from "./ReadFileTool";
|
||||
import { ReadSkillTool } from "./ReadSkillTool";
|
||||
import { ReadTemplateTool } from "./ReadTemplateTool";
|
||||
import { SubagentTool } from "./SubagentTool";
|
||||
import { ToolCollapsible } from "./ToolCollapsible";
|
||||
@@ -210,6 +211,55 @@ const ReadFileRenderer: FC<ToolRendererProps> = ({
|
||||
);
|
||||
};
|
||||
|
||||
const ReadSkillRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
result,
|
||||
isError,
|
||||
}) => {
|
||||
const parsedArgs = parseArgs(args);
|
||||
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
|
||||
const rec = asRecord(result);
|
||||
const body = rec ? asString(rec.body) : "";
|
||||
|
||||
return (
|
||||
<ReadSkillTool
|
||||
label={skillName ? `skill ${skillName}` : "skill"}
|
||||
body={body}
|
||||
status={status}
|
||||
isError={isError}
|
||||
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const ReadSkillFileRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
result,
|
||||
isError,
|
||||
}) => {
|
||||
const parsedArgs = parseArgs(args);
|
||||
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
|
||||
const filePath = parsedArgs ? asString(parsedArgs.path) : "";
|
||||
const label =
|
||||
skillName && filePath
|
||||
? `${skillName}/${filePath}`
|
||||
: skillName || filePath || "skill file";
|
||||
const rec = asRecord(result);
|
||||
const content = rec ? asString(rec.content) : "";
|
||||
|
||||
return (
|
||||
<ReadSkillTool
|
||||
label={label}
|
||||
body={content}
|
||||
status={status}
|
||||
isError={isError}
|
||||
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const WriteFileRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
@@ -667,6 +717,8 @@ const toolRenderers: Record<string, FC<ToolRendererProps>> = {
|
||||
create_workspace: CreateWorkspaceRenderer,
|
||||
list_templates: ListTemplatesRenderer,
|
||||
read_template: ReadTemplateRenderer,
|
||||
read_skill: ReadSkillRenderer,
|
||||
read_skill_file: ReadSkillFileRenderer,
|
||||
spawn_agent: SubagentRenderer,
|
||||
wait_agent: SubagentRenderer,
|
||||
message_agent: SubagentRenderer,
|
||||
|
||||
+3
@@ -462,6 +462,9 @@ export const CreateAndUpdateProvider: Story = {
|
||||
await waitFor(() => {
|
||||
expect(args.onCreateProvider).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
|
||||
});
|
||||
expect(args.onCreateProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: "openai",
|
||||
|
||||
+50
-42
@@ -12,6 +12,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
@@ -463,6 +465,8 @@ type Conn struct {
|
||||
|
||||
trafficStats *connstats.Statistics
|
||||
lastNetInfo *tailcfg.NetInfo
|
||||
|
||||
awaitReachableGroup singleflight.Group
|
||||
}
|
||||
|
||||
func (c *Conn) GetNetInfo() *tailcfg.NetInfo {
|
||||
@@ -599,56 +603,60 @@ func (c *Conn) DERPMap() *tailcfg.DERPMap {
|
||||
// address is reachable. It's the callers responsibility to provide
|
||||
// a timeout, otherwise this function will block forever.
|
||||
func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() // Cancel all pending pings on exit.
|
||||
result, _, _ := c.awaitReachableGroup.Do(ip.String(), func() (interface{}, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() // Cancel all pending pings on exit.
|
||||
|
||||
completedCtx, completed := context.WithCancel(context.Background())
|
||||
defer completed()
|
||||
completedCtx, completed := context.WithCancel(context.Background())
|
||||
defer completed()
|
||||
|
||||
run := func() {
|
||||
// Safety timeout, initially we'll have around 10-20 goroutines
|
||||
// running in parallel. The exponential backoff will converge
|
||||
// around ~1 ping / 30s, this means we'll have around 10-20
|
||||
// goroutines pending towards the end as well.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
run := func() {
|
||||
// Safety timeout, initially we'll have around 10-20 goroutines
|
||||
// running in parallel. The exponential backoff will converge
|
||||
// around ~1 ping / 30s, this means we'll have around 10-20
|
||||
// goroutines pending towards the end as well.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// For reachability, we use TSMP ping, which pings at the IP layer, and
|
||||
// therefore requires that wireguard and the netstack are up. If we
|
||||
// don't wait for wireguard to be up, we could miss a handshake, and it
|
||||
// might take 5 seconds for the handshake to be retried. A 5s initial
|
||||
// round trip can set us up for poor TCP performance, since the initial
|
||||
// round-trip-time sets the initial retransmit timeout.
|
||||
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
|
||||
if err == nil {
|
||||
completed()
|
||||
// For reachability, we use TSMP ping, which pings at the IP layer,
|
||||
// and therefore requires that wireguard and the netstack are up.
|
||||
// If we don't wait for wireguard to be up, we could miss a
|
||||
// handshake, and it might take 5 seconds for the handshake to be
|
||||
// retried. A 5s initial round trip can set us up for poor TCP
|
||||
// performance, since the initial round-trip-time sets the initial
|
||||
// retransmit timeout.
|
||||
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
|
||||
if err == nil {
|
||||
completed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0
|
||||
eb.InitialInterval = 50 * time.Millisecond
|
||||
eb.MaxInterval = 30 * time.Second
|
||||
// Consume the first interval since
|
||||
// we'll fire off a ping immediately.
|
||||
_ = eb.NextBackOff()
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0
|
||||
eb.InitialInterval = 50 * time.Millisecond
|
||||
eb.MaxInterval = 5 * time.Second
|
||||
// Consume the first interval since
|
||||
// we'll fire off a ping immediately.
|
||||
_ = eb.NextBackOff()
|
||||
|
||||
t := backoff.NewTicker(eb)
|
||||
defer t.Stop()
|
||||
t := backoff.NewTicker(eb)
|
||||
defer t.Stop()
|
||||
|
||||
go run()
|
||||
for {
|
||||
select {
|
||||
case <-completedCtx.Done():
|
||||
return true
|
||||
case <-t.C:
|
||||
// Pings can take a while, so we can run multiple
|
||||
// in parallel to return ASAP.
|
||||
go run()
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
go run()
|
||||
for {
|
||||
select {
|
||||
case <-completedCtx.Done():
|
||||
return true, nil
|
||||
case <-t.C:
|
||||
// Pings can take a while, so we can run multiple
|
||||
// in parallel to return ASAP.
|
||||
go run()
|
||||
case <-ctx.Done():
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
return result.(bool)
|
||||
}
|
||||
|
||||
// Closed is a channel that ends when the connection has
|
||||
|
||||
Reference in New Issue
Block a user