Compare commits

..

1 Commits

Author SHA1 Message Date
Kyle Carberry 603e68cc80 feat(site): render image diffs in the files changed panel
When binary image files (png, jpg, gif, svg, webp, etc.) are added,
deleted, or modified, the diff panel now renders them visually instead
of silently dropping them.

- Added ImageDiffView component with support for added (green border),
  deleted (red border), and modified (side-by-side Before/After) images
- Added parseBinaryImageDiffs() to extract binary image entries from
  raw unified diff text that @pierre/diffs skips
- Integrated into FilesChangedPanel rendering loop with a DiffEntry
  union type that branches between FileDiff and ImageDiffView
- File tree sidebar shows A/D/M badges for image files
- Checkerboard transparency background, error handling with fallback
- Images fetched from raw.githubusercontent.com using repo/branch info
- 7 unit tests for the parser covering all edge cases
- Storybook stories for dark/light with mixed text+image diffs
2026-03-06 00:10:57 +00:00
292 changed files with 3911 additions and 20829 deletions
+2 -2
View File
@@ -189,8 +189,8 @@ func (q *sqlQuerier) UpdateUser(ctx context.Context, arg UpdateUserParams) (User
### Common Debug Commands
```bash
# Run tests (starts Postgres automatically if needed)
make test
# Check database connection
make test-postgres
# Run specific database tests
go test ./coderd/database/... -run TestSpecificFunction
+1
View File
@@ -67,6 +67,7 @@ coderd/
| `make test` | Run all Go tests |
| `make test RUN=TestFunctionName` | Run specific test |
| `go test -v ./path/to/package -run TestFunctionName` | Run test with verbose output |
| `make test-postgres` | Run tests with Postgres database |
| `make test-race` | Run tests with Go race detector |
| `make test-e2e` | Run end-to-end tests |
+1
View File
@@ -109,6 +109,7 @@
- Run full test suite: `make test`
- Run specific test: `make test RUN=TestFunctionName`
- Run with Postgres: `make test-postgres`
- Run with race detector: `make test-race`
- Run end-to-end tests: `make test-e2e`
+1
View File
@@ -1,6 +1,7 @@
name: "🐞 Bug"
description: "File a bug report."
title: "bug: "
labels: ["needs-triage"]
type: "Bug"
body:
- type: checkboxes
+5 -1
View File
@@ -70,7 +70,11 @@ runs:
set -euo pipefail
if [[ ${RACE_DETECTION} == true ]]; then
make test-race
gotestsum --junitfile="gotests.xml" --packages="${TEST_PACKAGES}" -- \
-tags=testsmallbatch \
-race \
-parallel "${TEST_NUM_PARALLEL_TESTS}" \
-p "${TEST_NUM_PARALLEL_PACKAGES}"
else
make test
fi
+11 -9
View File
@@ -366,9 +366,9 @@ jobs:
needs: changes
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
# This timeout must be greater than the timeout set by `go test` in
# `make test` to ensure we receive a trace of running goroutines.
# Setting this to the timeout +5m should work quite well even if
# some of the preceding steps are slow.
# `make test-postgres` to ensure we receive a trace of running
# goroutines. Setting this to the timeout +5m should work quite well
# even if some of the preceding steps are slow.
timeout-minutes: 25
strategy:
fail-fast: false
@@ -475,6 +475,11 @@ jobs:
mkdir -p /tmp/tmpfs
sudo mount_tmpfs -o noowners -s 8g /tmp/tmpfs
# Install google-chrome for scaletests.
# As another concern, should we really have this kind of external dependency
# requirement on standard CI?
brew install google-chrome
# macOS will output "The default interactive shell is now zsh" intermittently in CI.
touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile
@@ -569,9 +574,9 @@ jobs:
- changes
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
# This timeout must be greater than the timeout set by `go test` in
# `make test` to ensure we receive a trace of running goroutines.
# Setting this to the timeout +5m should work quite well even if
# some of the preceding steps are slow.
# `make test-postgres` to ensure we receive a trace of running
# goroutines. Setting this to the timeout +5m should work quite well
# even if some of the preceding steps are slow.
timeout-minutes: 25
steps:
- name: Harden Runner
@@ -981,9 +986,6 @@ jobs:
run: |
make build/coder_docs_"$(./scripts/version.sh)".tgz
- name: Check for unstaged files
run: ./scripts/check_unstaged.sh
required:
runs-on: ubuntu-latest
needs:
-65
View File
@@ -1,65 +0,0 @@
name: Linear Release
on:
push:
branches:
- main
# This event reads the workflow from the default branch (main), not the
# release branch. No cherry-pick needed.
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#release
release:
types: [published]
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
sync:
name: Sync issues to Linear release
if: github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Sync issues
id: sync
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: sync
- name: Print release URL
if: steps.sync.outputs.release-url
run: echo "Synced to $RELEASE_URL"
env:
RELEASE_URL: ${{ steps.sync.outputs.release-url }}
complete:
name: Complete Linear release
if: github.event_name == 'release'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Complete release
id: complete
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: complete
version: ${{ github.event.release.tag_name }}
- name: Print release URL
if: steps.complete.outputs.release-url
run: echo "Completed $RELEASE_URL"
env:
RELEASE_URL: ${{ steps.complete.outputs.release-url }}
+3 -3
View File
@@ -16,9 +16,9 @@ jobs:
# when changing runner sizes
runs-on: ${{ matrix.os == 'macos-latest' && github.repository_owner == 'coder' && 'depot-macos-latest' || matrix.os == 'windows-2022' && github.repository_owner == 'coder' && 'depot-windows-2022-16' || matrix.os }}
# This timeout must be greater than the timeout set by `go test` in
# `make test` to ensure we receive a trace of running goroutines.
# Setting this to the timeout +5m should work quite well even if
# some of the preceding steps are slow.
# `make test-postgres` to ensure we receive a trace of running
# goroutines. Setting this to the timeout +5m should work quite well
# even if some of the preceding steps are slow.
timeout-minutes: 25
strategy:
fail-fast: false
-1
View File
@@ -38,7 +38,6 @@ site/.swc
# Make target for updating generated/golden files (any dir).
.gen
/_gen/
.gen-golden
# Build
+13 -45
View File
@@ -37,20 +37,19 @@ Only pause to ask for confirmation when:
## Essential Commands
| Task | Command | Notes |
|-----------------|--------------------------|-------------------------------------|
| **Development** | `./scripts/develop.sh` | ⚠️ Don't use manual build |
| **Build** | `make build` | Fat binaries (includes server) |
| **Build Slim** | `make build-slim` | Slim binaries |
| **Test** | `make test` | Full test suite |
| **Test Single** | `make test RUN=TestName` | Faster than full suite |
| **Test Race** | `make test-race` | Run tests with Go race detector |
| **Lint** | `make lint` | Always run after changes |
| **Generate** | `make gen` | After database changes |
| **Format** | `make fmt` | Auto-format code |
| **Clean** | `make clean` | Clean build artifacts |
| **Pre-commit** | `make pre-commit` | Fast CI checks (gen/fmt/lint/build) |
| **Pre-push** | `make pre-push` | All CI checks including tests |
| Task | Command | Notes |
|-------------------|--------------------------|----------------------------------|
| **Development** | `./scripts/develop.sh` | ⚠️ Don't use manual build |
| **Build** | `make build` | Fat binaries (includes server) |
| **Build Slim** | `make build-slim` | Slim binaries |
| **Test** | `make test` | Full test suite |
| **Test Single** | `make test RUN=TestName` | Faster than full suite |
| **Test Postgres** | `make test-postgres` | Run tests with Postgres database |
| **Test Race** | `make test-race` | Run tests with Go race detector |
| **Lint** | `make lint` | Always run after changes |
| **Generate** | `make gen` | After database changes |
| **Format** | `make fmt` | Auto-format code |
| **Clean** | `make clean` | Clean build artifacts |
### Documentation Commands
@@ -104,37 +103,6 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
### Full workflows available in imported WORKFLOWS.md
### Git Hooks (MANDATORY - DO NOT SKIP)
**You MUST install and use the git hooks. NEVER bypass them with
`--no-verify`. Skipping hooks wastes CI cycles and is unacceptable.**
The first run will be slow as caches warm up. Consecutive runs are
**significantly faster** (often 10x) thanks to Go build cache,
generated file timestamps, and warm node_modules. This is NOT a
reason to skip them. Wait for hooks to complete before proceeding,
no matter how long they take.
```sh
git config core.hooksPath scripts/githooks
```
Two hooks run automatically:
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
Fast checks that catch most CI failures. Allow at least 5 minutes.
- **pre-push**: `make pre-push` (full CI suite including tests).
Runs before pushing to catch everything CI would. Allow at least
15 minutes (race tests are slow without cache).
`git commit` and `git push` will appear to hang while hooks run.
This is normal. Do not interrupt, retry, or reduce the timeout.
NEVER run `git config core.hooksPath` to change or disable hooks.
If a hook fails, fix the issue and retry. Do not work around the
failure by skipping the hook.
### Git Workflow
When working on existing PRs, check out the branch first:
+110 -283
View File
@@ -19,16 +19,6 @@ SHELL := bash
.SHELLFLAGS := -ceu
.ONESHELL:
# When MAKE_TIMED=1, replace SHELL with a wrapper that prints
# elapsed wall-clock time for each recipe. pre-commit and pre-push
# set this on their sub-makes so every parallel job reports its
# duration. Ad-hoc usage: make MAKE_TIMED=1 test
ifdef MAKE_TIMED
SHELL := $(CURDIR)/scripts/lib/timed-shell.sh
.SHELLFLAGS = $@ -ceu
export MAKE_TIMED
endif
# This doesn't work on directories.
# See https://stackoverflow.com/questions/25752543/make-delete-on-error-for-directory-targets
.DELETE_ON_ERROR:
@@ -43,25 +33,6 @@ endif
coderd/database/unique_constraint.go \
coderd/database/dbmetrics/querymetrics.go \
coderd/database/dbauthz/dbauthz.go \
coderd/database/dbmock/dbmock.go \
coderd/database/pubsub/psmock/psmock.go \
agent/agentcontainers/acmock/acmock.go \
coderd/httpmw/loggermw/loggermock/loggermock.go \
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
tailnet/tailnettest/coordinatormock.go \
tailnet/tailnettest/coordinateemock.go \
tailnet/tailnettest/workspaceupdatesprovidermock.go \
tailnet/tailnettest/subscriptionmock.go \
enterprise/aibridged/aibridgedmock/clientmock.go \
enterprise/aibridged/aibridgedmock/poolmock.go \
tailnet/proto/tailnet.pb.go \
agent/proto/agent.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
agent/boundarylogproxy/codec/boundary.pb.go \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
site/src/api/typesGenerated.ts \
site/e2e/provisionerGenerated.ts \
site/src/api/chatModelOptionsGenerated.json \
@@ -79,23 +50,6 @@ endif
codersdk/rbacresources_gen.go \
codersdk/apikey_scopes_gen.go
# atomic_write runs a command, captures stdout into a temp file, and
# atomically replaces $@. An optional second argument is a formatting
# command that receives the temp file path as its argument.
# Usage: $(call atomic_write,GENERATE_CMD[,FORMAT_CMD])
define atomic_write
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
$(1) > "$$tmpfile" && \
$(if $(2),$(2) "$$tmpfile" &&) \
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
endef
# Shared temp directory for atomic writes. Lives at the project root
# so all targets share the same filesystem, and is gitignored.
# Order-only prerequisite: recipes that need it depend on | _gen
_gen:
mkdir -p _gen
# Don't print the commands in the file unless you specify VERBOSE. This is
# essentially the same as putting "@" at the start of each line.
ifndef VERBOSE
@@ -113,11 +67,6 @@ VERSION := $(shell ./scripts/version.sh)
POSTGRES_VERSION ?= 17
POSTGRES_IMAGE ?= us-docker.pkg.dev/coder-v2-images-public/public/postgres:$(POSTGRES_VERSION)
# Limit parallel Make jobs in pre-commit/pre-push. Defaults to
# nproc/4 (min 2) since test and lint targets have internal
# parallelism. Override: make pre-push PARALLEL_JOBS=8
PARALLEL_JOBS ?= $(shell n=$$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 8); echo $$(( n / 4 > 2 ? n / 4 : 2 )))
# Use the highest ZSTD compression level in CI.
ifdef CI
ZSTDFLAGS := -22 --ultra
@@ -131,7 +80,7 @@ endif
# Note, all find statements should be written with `.` or `./path` as
# the search path so that these exclusions match.
FIND_EXCLUSIONS= \
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' \) -prune \)
# Source files used for make targets, evaluated on use.
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
# Same as GO_SRC_FILES but excluding certain files that have problematic
@@ -621,7 +570,7 @@ endif
# GitHub Actions linters are run in a separate CI job (lint-actions) that only
# triggers when workflow files change, so we skip them here when CI=true.
LINT_ACTIONS_TARGETS := $(if $(CI),,lint/actions/actionlint)
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations lint/bootstrap $(LINT_ACTIONS_TARGETS)
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations $(LINT_ACTIONS_TARGETS)
.PHONY: lint
lint/site-icons:
@@ -636,7 +585,7 @@ lint/ts: site/node_modules/.installed
lint/go:
./scripts/check_enterprise_imports.sh
./scripts/check_codersdk_imports.sh
linter_ver=$$(grep -oE 'GOLANGCI_LINT_VERSION=\S+' dogfood/coder/Dockerfile | cut -d '=' -f 2)
linter_ver=$(shell egrep -o 'GOLANGCI_LINT_VERSION=\S+' dogfood/coder/Dockerfile | cut -d '=' -f 2)
go run github.com/golangci/golangci-lint/cmd/golangci-lint@v$$linter_ver run
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
.PHONY: lint/go
@@ -651,11 +600,6 @@ lint/shellcheck: $(SHELL_SRC_FILES)
shellcheck --external-sources $(SHELL_SRC_FILES)
.PHONY: lint/shellcheck
lint/bootstrap:
bash scripts/check_bootstrap_quotes.sh
.PHONY: lint/bootstrap
lint/helm:
cd helm/
make lint
@@ -690,118 +634,6 @@ lint/migrations:
./scripts/check_pg_schema.sh "Fixtures" $(FIXTURE_FILES)
.PHONY: lint/migrations
TYPOS_VERSION := $(shell grep -oP 'crate-ci/typos@\S+\s+\#\s+v\K[0-9.]+' .github/workflows/ci.yaml)
# Map uname values to typos release asset names.
TYPOS_ARCH := $(shell uname -m)
ifeq ($(shell uname -s),Darwin)
TYPOS_OS := apple-darwin
else
TYPOS_OS := unknown-linux-musl
endif
build/typos-$(TYPOS_VERSION):
mkdir -p build/
curl -sSfL "https://github.com/crate-ci/typos/releases/download/v$(TYPOS_VERSION)/typos-v$(TYPOS_VERSION)-$(TYPOS_ARCH)-$(TYPOS_OS).tar.gz" \
| tar -xzf - -C build/ ./typos
mv build/typos "$@"
lint/typos: build/typos-$(TYPOS_VERSION)
build/typos-$(TYPOS_VERSION) --config .github/workflows/typos.toml
.PHONY: lint/typos
# pre-commit and pre-push mirror CI "required" jobs locally.
# See the "required" job's needs list in .github/workflows/ci.yaml.
#
# pre-commit runs checks that don't need external services (Docker,
# Playwright). This is the git pre-commit hook default since test
# and Docker failures in the local environment would otherwise block
# all commits.
#
# pre-push runs the full CI suite including tests. This is the git
# pre-push hook default, catching everything CI would before pushing.
#
# pre-push uses two-phase execution: gen+fmt+test-postgres-docker
# first (writes files, starts Docker), then lint+build+test in
# parallel. pre-commit uses two phases: gen+fmt first, then
# lint+build. This avoids races where gen's `go run` creates
# temporary .go files that lint's find-based checks pick up.
# Within each phase, targets run in parallel via -j. Both fail if
# any tracked files have unstaged changes afterward.
#
# Both pre-commit and pre-push:
# gen, fmt, lint, lint/typos, slim binary (local arch)
#
# pre-push only (need external services or are slow):
# site/out/index.html (pnpm build)
# test-postgres-docker + test (needs Docker)
# test-js, test-e2e (needs Playwright)
# sqlc-vet (needs Docker)
# offlinedocs/check
#
# Omitted:
# test-go-pg-17 (same tests, different PG version)
define check-unstaged
unstaged="$$(git diff --name-only)"
if [[ -n $$unstaged ]]; then
echo "ERROR: unstaged changes in tracked files:"
echo "$$unstaged"
echo
echo "Review each change (git diff), verify correctness, then stage:"
echo " git add -u && git commit"
exit 1
fi
untracked=$$(git ls-files --other --exclude-standard)
if [[ -n $$untracked ]]; then
echo "WARNING: untracked files (not in this commit, won't be in CI):"
echo "$$untracked"
echo
fi
endef
pre-commit:
start=$$(date +%s)
echo "=== Phase 1/2: gen + fmt ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt
$(check-unstaged)
echo "=== Phase 2/2: lint + build ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
lint \
lint/typos \
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
$(check-unstaged)
echo "$(BOLD)$(GREEN)=== pre-commit passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
.PHONY: pre-commit
pre-push:
start=$$(date +%s)
echo "=== Phase 1/2: gen + fmt + postgres ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt test-postgres-docker
$(check-unstaged)
echo "=== Phase 2/2: lint + build + test ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
lint \
lint/typos \
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT) \
site/out/index.html \
test \
test-js \
test-e2e \
test-race \
sqlc-vet \
offlinedocs/check
$(check-unstaged)
echo "$(BOLD)$(GREEN)=== pre-push passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
.PHONY: pre-push
offlinedocs/check: offlinedocs/node_modules/.installed
cd offlinedocs/
pnpm format:check
pnpm lint
pnpm export
.PHONY: offlinedocs/check
# All files generated by the database should be added here, and this can be used
# as a target for jobs that need to run after the database is generated.
DB_GEN_FILES := \
@@ -990,7 +822,7 @@ $(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go
touch "$@"
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
./scripts/atomic_protoc.sh \
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
@@ -998,7 +830,7 @@ tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
./tailnet/proto/tailnet.proto
agent/proto/agent.pb.go: agent/proto/agent.proto
./scripts/atomic_protoc.sh \
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
@@ -1006,7 +838,7 @@ agent/proto/agent.pb.go: agent/proto/agent.proto
./agent/proto/agent.proto
agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto agent/proto/agent.proto
./scripts/atomic_protoc.sh \
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
@@ -1014,7 +846,7 @@ agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.p
./agent/agentsocket/proto/agentsocket.proto
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
./scripts/atomic_protoc.sh \
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
@@ -1022,7 +854,7 @@ provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
./provisionersdk/proto/provisioner.proto
provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto
./scripts/atomic_protoc.sh \
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
@@ -1030,110 +862,132 @@ provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto
./provisionerd/proto/provisionerd.proto
vpn/vpn.pb.go: vpn/vpn.proto
./scripts/atomic_protoc.sh \
protoc \
--go_out=. \
--go_opt=paths=source_relative \
./vpn/vpn.proto
agent/boundarylogproxy/codec/boundary.pb.go: agent/boundarylogproxy/codec/boundary.proto agent/proto/agent.proto
./scripts/atomic_protoc.sh \
protoc \
--go_out=. \
--go_opt=paths=source_relative \
./agent/boundarylogproxy/codec/boundary.proto
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
./scripts/atomic_protoc.sh \
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./enterprise/aibridged/proto/aibridged.proto
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# Generate to a temp file, format it, then atomically move to
# the target so that an interrupt never leaves a partial or
# unformatted file in the working tree.
tmpfile=$$(mktemp -d)/$(notdir $@) && \
go run -C ./scripts/apitypings main.go > "$$tmpfile" && \
./scripts/biome_format.sh "$$tmpfile" && \
mv "$$tmpfile" "$@"
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
(cd site/ && pnpm run gen:provisioner)
touch "$@"
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
tmpfile=$$(mktemp -d)/$(notdir $@) && \
go run ./scripts/gensite/ -icons "$$tmpfile" && \
./scripts/biome_format.sh "$$tmpfile" && \
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
mv "$$tmpfile" "$@"
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
$(call atomic_write,go run ./scripts/examplegen/main.go)
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
go run ./scripts/examplegen/main.go > "$@.tmp" && mv "$@.tmp" "$@"
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
tempdir=$(shell mktemp -d /tmp/typegen_rbac_object.XXXXXX)
go run ./scripts/typegen/main.go rbac object > "$$tempdir/object_gen.go"
mv -v "$$tempdir/object_gen.go" coderd/rbac/object_gen.go
rmdir -v "$$tempdir"
touch "$@"
# NOTE: depends on object_gen.go because `go run` compiles
# coderd/rbac which includes it.
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
coderd/rbac/object_gen.go | _gen
# Write to a temp file first to avoid truncating the package
# during build since the generator imports the rbac package.
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
coderd/rbac/object_gen.go
# Generate typed low-level ScopeName constants from RBACPermissions
# Write to a temp file first to avoid truncating the package during build
# since the generator imports the rbac package.
tempfile=$(shell mktemp /tmp/scopes_constants_gen.XXXXXX)
go run ./scripts/typegen/main.go rbac scopenames > "$$tempfile"
mv -v "$$tempfile" coderd/rbac/scopes_constants_gen.go
touch "$@"
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
# `go run` compiles coderd/rbac which includes both.
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
# Write to a temp file to avoid truncating the target, which
# would break the codersdk package and any parallel build targets.
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
# Do no overwrite codersdk/rbacresources_gen.go directly, as it would make the file empty, breaking
# the `codersdk` package and any parallel build targets.
go run scripts/typegen/main.go rbac codersdk > /tmp/rbacresources_gen.go
mv /tmp/rbacresources_gen.go codersdk/rbacresources_gen.go
touch "$@"
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
# `go run` compiles coderd/rbac which includes both.
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
# Generate SDK constants for external API key scopes.
$(call atomic_write,go run ./scripts/apikeyscopesgen)
go run ./scripts/apikeyscopesgen > /tmp/apikey_scopes_gen.go
mv /tmp/apikey_scopes_gen.go codersdk/apikey_scopes_gen.go
touch "$@"
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
# `go run` compiles coderd/rbac which includes both.
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
tmpfile=$$(mktemp -d)/$(notdir $@) && \
go run scripts/typegen/main.go rbac typescript > "$$tmpfile" && \
./scripts/biome_format.sh "$$tmpfile" && \
mv "$$tmpfile" "$@"
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
tmpfile=$$(mktemp -d)/$(notdir $@) && \
go run scripts/typegen/main.go countries > "$$tmpfile" && \
./scripts/biome_format.sh "$$tmpfile" && \
mv "$$tmpfile" "$@"
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go
tmpfile=$$(mktemp -d)/$(notdir $@) && \
go run ./scripts/modeloptionsgen/main.go | tail -n +2 > "$$tmpfile" && \
./scripts/biome_format.sh "$$tmpfile" && \
mv "$$tmpfile" "$@"
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
go run ./scripts/metricsdocgen/scanner > $@.tmp && mv $@.tmp $@
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
pnpm exec markdown-table-formatter "$$tmpfile" && \
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
mv "$$tmpfile" "$@"
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
tmpdir=$$(mktemp -d -p _gen) && \
tmpdir=$$(realpath "$$tmpdir") && \
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES)
tmpdir=$$(mktemp -d) && \
mkdir -p "$$tmpdir/docs/reference/cli" && \
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
cp "$$tmpdir/docs/reference/cli/"*.md docs/reference/cli/ && \
rm -rf "$$tmpdir"
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
pnpm exec markdown-table-formatter "$$tmpfile" && \
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
mv "$$tmpfile" "$@"
coderd/apidoc/.gen: \
node_modules/.installed \
@@ -1148,27 +1002,25 @@ coderd/apidoc/.gen: \
scripts/apidocgen/generate.sh \
scripts/apidocgen/swaginit/main.go \
$(wildcard scripts/apidocgen/postprocess/*) \
$(wildcard scripts/apidocgen/markdown-template/*) | _gen
tmpdir=$$(mktemp -d -p _gen) && swagtmp=$$(mktemp -d -p _gen) && \
tmpdir=$$(realpath "$$tmpdir") && swagtmp=$$(realpath "$$swagtmp") && \
$(wildcard scripts/apidocgen/markdown-template/*)
tmpdir=$$(mktemp -d) && swagtmp=$$(mktemp -d) && \
mkdir -p "$$tmpdir/reference/api" && \
cp docs/manifest.json "$$tmpdir/manifest.json" && \
SWAG_OUTPUT_DIR="$$swagtmp" APIDOCGEN_DOCS_DIR="$$tmpdir" ./scripts/apidocgen/generate.sh && \
pnpm exec markdownlint-cli2 --fix "$$tmpdir/reference/api/*.md" && \
pnpm exec markdown-table-formatter "$$tmpdir/reference/api/*.md" && \
./scripts/biome_format.sh "$$swagtmp/swagger.json" && \
for f in "$$tmpdir/reference/api/"*.md; do mv "$$f" "docs/reference/api/$$(basename "$$f")"; done && \
mv "$$tmpdir/manifest.json" _gen/manifest-staging.json && \
mv "$$swagtmp/docs.go" coderd/apidoc/docs.go && \
mv "$$swagtmp/swagger.json" coderd/apidoc/swagger.json && \
cp "$$tmpdir/reference/api/"*.md docs/reference/api/ && \
cp "$$tmpdir/manifest.json" docs/manifest.json && \
cp "$$swagtmp/docs.go" coderd/apidoc/docs.go && \
cp "$$swagtmp/swagger.json" coderd/apidoc/swagger.json && \
rm -rf "$$tmpdir" "$$swagtmp"
touch "$@"
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md | _gen
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
cp _gen/manifest-staging.json "$$tmpfile" && \
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
./scripts/biome_format.sh "$$tmpfile" && \
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
mv "$$tmpfile" "$@"
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
touch "$@"
@@ -1255,22 +1107,10 @@ else
GOTESTSUM_RETRY_FLAGS :=
endif
# Default to 8x8 parallelism to avoid overwhelming our workspaces.
# Race detection defaults to 4x4 because the detector adds significant
# CPU overhead. Override via TEST_NUM_PARALLEL_PACKAGES /
# TEST_NUM_PARALLEL_TESTS.
TEST_PARALLEL_PACKAGES := $(or $(TEST_NUM_PARALLEL_PACKAGES),8)
TEST_PARALLEL_TESTS := $(or $(TEST_NUM_PARALLEL_TESTS),8)
RACE_PARALLEL_PACKAGES := $(or $(TEST_NUM_PARALLEL_PACKAGES),4)
RACE_PARALLEL_TESTS := $(or $(TEST_NUM_PARALLEL_TESTS),4)
# Use testsmallbatch tag to reduce wireguard memory allocation in tests
# (from ~18GB to negligible). Recursively expanded so target-specific
# overrides of TEST_PARALLEL_* take effect (e.g. test-race lowers
# parallelism). CI job timeout is 25m (see test-go-pg in ci.yaml),
# keep the Go timeout 5m shorter so tests produce goroutine dumps
# instead of the CI runner killing the process with no output.
GOTEST_FLAGS = -tags=testsmallbatch -v -timeout 20m -p $(TEST_PARALLEL_PACKAGES) -parallel=$(TEST_PARALLEL_TESTS)
# default to 8x8 parallelism to avoid overwhelming our workspaces. Hopefully we can remove these defaults
# when we get our test suite's resource utilization under control.
# Use testsmallbatch tag to reduce wireguard memory allocation in tests (from ~18GB to negligible).
GOTEST_FLAGS := -tags=testsmallbatch -v -p $(or $(TEST_NUM_PARALLEL_PACKAGES),"8") -parallel=$(or $(TEST_NUM_PARALLEL_TESTS),"8")
# The most common use is to set TEST_COUNT=1 to avoid Go's test cache.
ifdef TEST_COUNT
@@ -1296,34 +1136,13 @@ endif
TEST_PACKAGES ?= ./...
test:
$(GIT_FLAGS) gotestsum --format standard-quiet \
$(GOTESTSUM_RETRY_FLAGS) \
--packages="$(TEST_PACKAGES)" \
-- \
$(GOTEST_FLAGS)
$(GIT_FLAGS) gotestsum --format standard-quiet $(GOTESTSUM_RETRY_FLAGS) --packages="$(TEST_PACKAGES)" -- $(GOTEST_FLAGS)
.PHONY: test
test-race: TEST_PARALLEL_PACKAGES := $(RACE_PARALLEL_PACKAGES)
test-race: TEST_PARALLEL_TESTS := $(RACE_PARALLEL_TESTS)
test-race:
$(GIT_FLAGS) gotestsum --format standard-quiet \
--junitfile="gotests.xml" \
$(GOTESTSUM_RETRY_FLAGS) \
--packages="$(TEST_PACKAGES)" \
-- \
-race \
$(GOTEST_FLAGS)
.PHONY: test-race
test-cli:
$(MAKE) test TEST_PACKAGES="./cli..."
.PHONY: test-cli
test-js: site/node_modules/.installed
cd site/
pnpm test:ci
.PHONY: test-js
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
# dependency for any sqlc-cloud related targets.
sqlc-cloud-is-setup:
@@ -1335,22 +1154,37 @@ sqlc-cloud-is-setup:
sqlc-push: sqlc-cloud-is-setup test-postgres-docker
echo "--- sqlc push"
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
sqlc push -f coderd/database/sqlc.yaml && echo "Passed sqlc push"
.PHONY: sqlc-push
sqlc-verify: sqlc-cloud-is-setup test-postgres-docker
echo "--- sqlc verify"
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
sqlc verify -f coderd/database/sqlc.yaml && echo "Passed sqlc verify"
.PHONY: sqlc-verify
sqlc-vet: test-postgres-docker
echo "--- sqlc vet"
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
sqlc vet -f coderd/database/sqlc.yaml && echo "Passed sqlc vet"
.PHONY: sqlc-vet
# When updating -timeout for this test, keep in sync with
# test-go-postgres (.github/workflows/coder.yaml).
# Do add coverage flags so that test caching works.
test-postgres: test-postgres-docker
# The postgres test is prone to failure, so we limit parallelism for
# more consistent execution.
$(GIT_FLAGS) gotestsum \
--junitfile="gotests.xml" \
--jsonfile="gotests.json" \
$(GOTESTSUM_RETRY_FLAGS) \
--packages="./..." -- \
-tags=testsmallbatch \
-timeout=20m \
-count=1
.PHONY: test-postgres
test-migrations: test-postgres-docker
echo "--- test migrations"
@@ -1366,24 +1200,13 @@ test-migrations: test-postgres-docker
# NOTE: we set --memory to the same size as a GitHub runner.
test-postgres-docker:
# If our container is already running, nothing to do.
if docker ps --filter "name=test-postgres-docker-${POSTGRES_VERSION}" --format '{{.Names}}' | grep -q .; then \
echo "test-postgres-docker-${POSTGRES_VERSION} is already running."; \
exit 0; \
fi
# If something else is on 5432, warn but don't fail.
if pg_isready -h 127.0.0.1 -q 2>/dev/null; then \
echo "WARNING: PostgreSQL is already running on 127.0.0.1:5432 (not our container)."; \
echo "Tests will use this instance. To use the Makefile's container, stop it first."; \
exit 0; \
fi
docker rm -f test-postgres-docker-${POSTGRES_VERSION} || true
# Try pulling up to three times to avoid CI flakes.
docker pull ${POSTGRES_IMAGE} || {
retries=2
for try in $$(seq 1 $${retries}); do
echo "Failed to pull image, retrying ($${try}/$${retries})..."
for try in $(seq 1 ${retries}); do
echo "Failed to pull image, retrying (${try}/${retries})..."
sleep 1
if docker pull ${POSTGRES_IMAGE}; then
break
@@ -1424,11 +1247,16 @@ test-postgres-docker:
-c log_statement=all
while ! pg_isready -h 127.0.0.1
do
echo "$$(date) - waiting for database to start"
echo "$(date) - waiting for database to start"
sleep 0.5
done
.PHONY: test-postgres-docker
# Make sure to keep this in sync with test-go-race from .github/workflows/ci.yaml.
test-race:
$(GIT_FLAGS) gotestsum --junitfile="gotests.xml" -- -tags=testsmallbatch -race -count=1 -parallel 4 -p 4 ./...
.PHONY: test-race
test-tailnet-integration:
env \
CODER_TAILNET_TESTS=true \
@@ -1457,7 +1285,6 @@ site/e2e/bin/coder: go.mod go.sum $(GO_SRC_FILES)
test-e2e: site/e2e/bin/coder site/node_modules/.installed site/out/index.html
cd site/
pnpm playwright:install
ifdef CI
DEBUG=pw:api pnpm playwright:test --forbid-only --workers 1
else
+2 -10
View File
@@ -41,7 +41,6 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
@@ -103,7 +102,6 @@ type Options struct {
Execer agentexec.Execer
Devcontainers bool
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
GitAPIOptions []agentgit.Option
Clock quartz.Clock
SocketServerEnabled bool
SocketPath string // Path for the agent socket server socket
@@ -219,7 +217,6 @@ func New(options Options) Agent {
devcontainers: options.Devcontainers,
containerAPIOptions: options.DevcontainerAPIOptions,
gitAPIOptions: options.GitAPIOptions,
socketPath: options.SocketPath,
socketServerEnabled: options.SocketServerEnabled,
boundaryLogProxySocketPath: options.BoundaryLogProxySocketPath,
@@ -305,10 +302,8 @@ type agent struct {
devcontainers bool
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API
gitAPIOptions []agentgit.Option
filesAPI *agentfiles.API
gitAPI *agentgit.API
processAPI *agentproc.API
socketServerEnabled bool
@@ -381,11 +376,8 @@ func (a *agent) init() {
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
pathStore := agentgit.NewPathStore()
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv)
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
+1 -4
View File
@@ -7,21 +7,18 @@ import (
"github.com/spf13/afero"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentgit"
)
// API exposes file-related operations performed through the agent.
type API struct {
logger slog.Logger
filesystem afero.Fs
pathStore *agentgit.PathStore
}
func NewAPI(logger slog.Logger, filesystem afero.Fs, pathStore *agentgit.PathStore) *API {
func NewAPI(logger slog.Logger, filesystem afero.Fs) *API {
api := &API{
logger: logger,
filesystem: filesystem,
pathStore: pathStore,
}
return api
}
-20
View File
@@ -13,12 +13,10 @@ import (
"strings"
"syscall"
"github.com/google/uuid"
"github.com/spf13/afero"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -303,13 +301,6 @@ func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
return
}
// Track edited path for git watch.
if api.pathStore != nil {
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), []string{path})
}
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: fmt.Sprintf("Successfully wrote to %q", path),
})
@@ -389,17 +380,6 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) {
return
}
// Track edited paths for git watch.
if api.pathStore != nil {
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
filePaths := make([]string, 0, len(req.Files))
for _, f := range req.Files {
filePaths = append(filePaths, f.Path)
}
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), filePaths)
}
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: "Successfully edited file(s)",
})
+4 -171
View File
@@ -11,12 +11,9 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
"syscall"
"testing"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
@@ -24,7 +21,6 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
@@ -120,7 +116,7 @@ func TestReadFile(t *testing.T) {
}
return nil
})
api := agentfiles.NewAPI(logger, fs, nil)
api := agentfiles.NewAPI(logger, fs)
dirPath := filepath.Join(tmpdir, "a-directory")
err := fs.MkdirAll(dirPath, 0o755)
@@ -300,7 +296,7 @@ func TestWriteFile(t *testing.T) {
}
return nil
})
api := agentfiles.NewAPI(logger, fs, nil)
api := agentfiles.NewAPI(logger, fs)
dirPath := filepath.Join(tmpdir, "directory")
err := fs.MkdirAll(dirPath, 0o755)
@@ -418,7 +414,7 @@ func TestEditFiles(t *testing.T) {
}
return nil
})
api := agentfiles.NewAPI(logger, fs, nil)
api := agentfiles.NewAPI(logger, fs)
dirPath := filepath.Join(tmpdir, "directory")
err := fs.MkdirAll(dirPath, 0o755)
@@ -842,169 +838,6 @@ func TestEditFiles(t *testing.T) {
}
}
func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
t.Parallel()
pathStore := agentgit.NewPathStore()
logger := slogtest.Make(t, nil)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, pathStore)
testPath := filepath.Join(os.TempDir(), "test.txt")
chatID := uuid.New()
ancestorID := uuid.New()
ancestorJSON, _ := json.Marshal([]string{ancestorID.String()})
body := strings.NewReader("hello world")
req := httptest.NewRequest(http.MethodPost, "/write-file?path="+testPath, body)
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
req.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, string(ancestorJSON))
rr := httptest.NewRecorder()
r := chi.NewRouter()
r.Post("/write-file", api.HandleWriteFile)
r.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Verify PathStore was updated for both chat and ancestor.
paths := pathStore.GetPaths(chatID)
require.Equal(t, []string{testPath}, paths)
ancestorPaths := pathStore.GetPaths(ancestorID)
require.Equal(t, []string{testPath}, ancestorPaths)
}
func TestHandleWriteFile_NoChatHeaders_NoPathStoreUpdate(t *testing.T) {
t.Parallel()
pathStore := agentgit.NewPathStore()
logger := slogtest.Make(t, nil)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, pathStore)
testPath := filepath.Join(os.TempDir(), "test.txt")
body := strings.NewReader("hello world")
req := httptest.NewRequest(http.MethodPost, "/write-file?path="+testPath, body)
rr := httptest.NewRecorder()
r := chi.NewRouter()
r.Post("/write-file", api.HandleWriteFile)
r.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// PathStore should be globally empty since no chat headers were set.
require.Equal(t, 0, pathStore.Len())
}
func TestHandleWriteFile_Failure_NoPathStoreUpdate(t *testing.T) {
t.Parallel()
pathStore := agentgit.NewPathStore()
logger := slogtest.Make(t, nil)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, pathStore)
chatID := uuid.New()
// Write to a relative path (should fail with 400).
body := strings.NewReader("hello world")
req := httptest.NewRequest(http.MethodPost, "/write-file?path=relative/path.txt", body)
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
rr := httptest.NewRecorder()
r := chi.NewRouter()
r.Post("/write-file", api.HandleWriteFile)
r.ServeHTTP(rr, req)
require.Equal(t, http.StatusBadRequest, rr.Code)
// PathStore should NOT be updated on failure.
paths := pathStore.GetPaths(chatID)
require.Empty(t, paths)
}
func TestHandleEditFiles_ChatHeaders_UpdatesPathStore(t *testing.T) {
t.Parallel()
pathStore := agentgit.NewPathStore()
logger := slogtest.Make(t, nil)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, pathStore)
testPath := filepath.Join(os.TempDir(), "test.txt")
// Create the file first.
require.NoError(t, afero.WriteFile(fs, testPath, []byte("hello"), 0o644))
chatID := uuid.New()
editReq := workspacesdk.FileEditRequest{
Files: []workspacesdk.FileEdits{
{
Path: testPath,
Edits: []workspacesdk.FileEdit{
{Search: "hello", Replace: "world"},
},
},
},
}
body, _ := json.Marshal(editReq)
req := httptest.NewRequest(http.MethodPost, "/edit-files", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
rr := httptest.NewRecorder()
r := chi.NewRouter()
r.Post("/edit-files", api.HandleEditFiles)
r.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
paths := pathStore.GetPaths(chatID)
require.Equal(t, []string{testPath}, paths)
}
func TestHandleEditFiles_Failure_NoPathStoreUpdate(t *testing.T) {
t.Parallel()
pathStore := agentgit.NewPathStore()
logger := slogtest.Make(t, nil)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, pathStore)
chatID := uuid.New()
// Edit a non-existent file (should fail with 404).
editReq := workspacesdk.FileEditRequest{
Files: []workspacesdk.FileEdits{
{
Path: "/nonexistent/file.txt",
Edits: []workspacesdk.FileEdit{
{Search: "hello", Replace: "world"},
},
},
},
}
body, _ := json.Marshal(editReq)
req := httptest.NewRequest(http.MethodPost, "/edit-files", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
rr := httptest.NewRecorder()
r := chi.NewRouter()
r.Post("/edit-files", api.HandleEditFiles)
r.ServeHTTP(rr, req)
require.NotEqual(t, http.StatusOK, rr.Code)
// PathStore should NOT be updated on failure.
paths := pathStore.GetPaths(chatID)
require.Empty(t, paths)
}
func TestReadFileLines(t *testing.T) {
t.Parallel()
@@ -1018,7 +851,7 @@ func TestReadFileLines(t *testing.T) {
}
return nil
})
api := agentfiles.NewAPI(logger, fs, nil)
api := agentfiles.NewAPI(logger, fs)
dirPath := filepath.Join(tmpdir, "a-directory-lines")
err := fs.MkdirAll(dirPath, 0o755)
-441
View File
@@ -1,441 +0,0 @@
// Package agentgit provides a WebSocket-based service for watching git
// repository changes on the agent. It is mounted at /api/v0/git/watch
// and allows clients to subscribe to file paths, triggering scans of
// the corresponding git repositories.
package agentgit
import (
"bytes"
"context"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
"github.com/dustin/go-humanize"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
// Option configures the git watch service.
type Option func(*Handler)
// WithClock sets a controllable clock for testing. Defaults to
// quartz.NewReal().
func WithClock(c quartz.Clock) Option {
return func(h *Handler) {
h.clock = c
}
}
// WithGitBinary overrides the git binary path (for testing).
func WithGitBinary(path string) Option {
return func(h *Handler) {
h.gitBin = path
}
}
const (
// scanCooldown is the minimum interval between successive scans.
scanCooldown = 1 * time.Second
// fallbackPollInterval is the safety-net poll period used when no
// filesystem events arrive.
fallbackPollInterval = 30 * time.Second
// maxTotalDiffSize is the maximum size of the combined
// unified diff for an entire repository sent over the wire.
// This must stay under the WebSocket message size limit.
maxTotalDiffSize = 3 * 1024 * 1024 // 3 MiB
)
// Handler manages per-connection git watch state.
type Handler struct {
logger slog.Logger
clock quartz.Clock
gitBin string // path to git binary; empty means "git" (from PATH)
mu sync.Mutex
repoRoots map[string]struct{} // watched repo roots
lastSnapshots map[string]repoSnapshot // last emitted snapshot per repo
lastScanAt time.Time // when the last scan completed
scanTrigger chan struct{} // buffered(1), poked by triggers
}
// repoSnapshot captures the last emitted state for delta comparison.
type repoSnapshot struct {
branch string
remoteOrigin string
unifiedDiff string
}
// NewHandler creates a new git watch handler.
func NewHandler(logger slog.Logger, opts ...Option) *Handler {
h := &Handler{
logger: logger,
clock: quartz.NewReal(),
gitBin: "git",
repoRoots: make(map[string]struct{}),
lastSnapshots: make(map[string]repoSnapshot),
scanTrigger: make(chan struct{}, 1),
}
for _, opt := range opts {
opt(h)
}
// Check if git is available.
if _, err := exec.LookPath(h.gitBin); err != nil {
h.logger.Warn(context.Background(), "git binary not found, git scanning disabled")
}
return h
}
// gitAvailable returns true if the configured git binary can be found
// in PATH.
func (h *Handler) gitAvailable() bool {
_, err := exec.LookPath(h.gitBin)
return err == nil
}
// Subscribe processes a subscribe message, resolving paths to git repo
// roots and adding new repos to the watch set. Returns true if any new
// repo roots were added.
func (h *Handler) Subscribe(paths []string) bool {
if !h.gitAvailable() {
return false
}
h.mu.Lock()
defer h.mu.Unlock()
added := false
for _, p := range paths {
if !filepath.IsAbs(p) {
continue
}
p = filepath.Clean(p)
root, err := findRepoRoot(h.gitBin, p)
if err != nil {
// Not a git path — silently ignore.
continue
}
if _, ok := h.repoRoots[root]; ok {
continue
}
h.repoRoots[root] = struct{}{}
added = true
}
return added
}
// RequestScan pokes the scan trigger so the run loop performs a scan.
func (h *Handler) RequestScan() {
select {
case h.scanTrigger <- struct{}{}:
default:
// Already pending.
}
}
// Scan performs a scan of all subscribed repos and computes deltas
// against the previously emitted snapshots.
func (h *Handler) Scan(ctx context.Context) *codersdk.WorkspaceAgentGitServerMessage {
if !h.gitAvailable() {
return nil
}
h.mu.Lock()
roots := make([]string, 0, len(h.repoRoots))
for r := range h.repoRoots {
roots = append(roots, r)
}
h.mu.Unlock()
if len(roots) == 0 {
return nil
}
now := h.clock.Now().UTC()
var repos []codersdk.WorkspaceAgentRepoChanges
// Perform all I/O outside the lock to avoid blocking
// AddPaths/GetPaths/Subscribe callers during disk-heavy scans.
type scanResult struct {
root string
changes codersdk.WorkspaceAgentRepoChanges
err error
}
results := make([]scanResult, 0, len(roots))
for _, root := range roots {
changes, err := getRepoChanges(ctx, h.logger, h.gitBin, root)
results = append(results, scanResult{root: root, changes: changes, err: err})
}
// Re-acquire the lock only to commit snapshot updates.
h.mu.Lock()
defer h.mu.Unlock()
for _, res := range results {
if res.err != nil {
if isRepoDeleted(h.gitBin, res.root) {
// Repo root or .git directory was removed.
// Emit a removal entry, then evict from watch set.
removal := codersdk.WorkspaceAgentRepoChanges{
RepoRoot: res.root,
Removed: true,
}
delete(h.repoRoots, res.root)
delete(h.lastSnapshots, res.root)
repos = append(repos, removal)
} else {
// Transient error — log and skip without
// removing the repo from the watch set.
h.logger.Warn(ctx, "scan repo failed",
slog.F("root", res.root),
slog.Error(res.err),
)
}
continue
}
prev, hasPrev := h.lastSnapshots[res.root]
if hasPrev &&
prev.branch == res.changes.Branch &&
prev.remoteOrigin == res.changes.RemoteOrigin &&
prev.unifiedDiff == res.changes.UnifiedDiff {
// No change in this repo since last emit.
continue
}
// Update snapshot.
h.lastSnapshots[res.root] = repoSnapshot{
branch: res.changes.Branch,
remoteOrigin: res.changes.RemoteOrigin,
unifiedDiff: res.changes.UnifiedDiff,
}
repos = append(repos, res.changes)
}
h.lastScanAt = now
if len(repos) == 0 {
return nil
}
return &codersdk.WorkspaceAgentGitServerMessage{
Type: codersdk.WorkspaceAgentGitServerMessageTypeChanges,
ScannedAt: &now,
Repositories: repos,
}
}
// RunLoop runs the main event loop that listens for refresh requests
// and fallback poll ticks. It calls scanFn whenever a scan should
// happen (rate-limited to scanCooldown). It blocks until ctx is
// canceled.
func (h *Handler) RunLoop(ctx context.Context, scanFn func()) {
fallbackTicker := h.clock.NewTicker(fallbackPollInterval)
defer fallbackTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-h.scanTrigger:
h.rateLimitedScan(ctx, scanFn)
case <-fallbackTicker.C:
h.rateLimitedScan(ctx, scanFn)
}
}
}
func (h *Handler) rateLimitedScan(ctx context.Context, scanFn func()) {
h.mu.Lock()
elapsed := h.clock.Since(h.lastScanAt)
if elapsed < scanCooldown {
h.mu.Unlock()
// Wait for cooldown then scan.
remaining := scanCooldown - elapsed
timer := h.clock.NewTimer(remaining)
defer timer.Stop()
select {
case <-ctx.Done():
return
case <-timer.C:
}
scanFn()
return
}
h.mu.Unlock()
scanFn()
}
// isRepoDeleted returns true when the repo root directory or its .git
// entry no longer represents a valid git repository. This
// distinguishes a genuine repo deletion from a transient scan error
// (e.g. lock contention).
//
// It handles three deletion cases:
// 1. The repo root directory itself was removed.
// 2. The .git entry (directory or file) was removed.
// 3. The .git entry is a file (worktree/submodule) whose target
// gitdir was removed. In this case .git exists on disk but
// `git rev-parse --git-dir` fails because the referenced
// directory is gone.
func isRepoDeleted(gitBin string, repoRoot string) bool {
if _, err := os.Stat(repoRoot); os.IsNotExist(err) {
return true
}
gitPath := filepath.Join(repoRoot, ".git")
fi, err := os.Stat(gitPath)
if os.IsNotExist(err) {
return true
}
// If .git is a regular file (worktree or submodule), the actual
// git object store lives elsewhere. Validate that the target is
// still reachable by running git rev-parse.
if err == nil && !fi.IsDir() {
cmd := exec.CommandContext(context.Background(), gitBin, "-C", repoRoot, "rev-parse", "--git-dir")
if err := cmd.Run(); err != nil {
return true
}
}
return false
}
// findRepoRoot uses `git rev-parse --show-toplevel` to find the
// repository root for the given path.
func findRepoRoot(gitBin string, p string) (string, error) {
// If p is a file, start from its parent directory.
dir := p
if info, err := os.Stat(dir); err != nil || !info.IsDir() {
dir = filepath.Dir(dir)
}
cmd := exec.CommandContext(context.Background(), gitBin, "rev-parse", "--show-toplevel")
cmd.Dir = dir
out, err := cmd.Output()
if err != nil {
return "", xerrors.Errorf("no git repo found for %s", p)
}
root := filepath.FromSlash(strings.TrimSpace(string(out)))
// Resolve symlinks and short (8.3) names on Windows so the
// returned root matches paths produced by Go's filepath APIs.
if resolved, evalErr := filepath.EvalSymlinks(root); evalErr == nil {
root = resolved
}
return root, nil
}
// getRepoChanges reads the current state of a git repository using
// the git CLI. It returns branch, remote origin, and a unified diff.
func getRepoChanges(ctx context.Context, logger slog.Logger, gitBin string, repoRoot string) (codersdk.WorkspaceAgentRepoChanges, error) {
result := codersdk.WorkspaceAgentRepoChanges{
RepoRoot: repoRoot,
}
// Verify this is still a valid git repository before doing
// anything else. This catches deleted repos early.
verifyCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "rev-parse", "--git-dir")
if err := verifyCmd.Run(); err != nil {
return result, xerrors.Errorf("not a git repository: %w", err)
}
// Read branch name.
branchCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "symbolic-ref", "--short", "HEAD")
if out, err := branchCmd.Output(); err == nil {
result.Branch = strings.TrimSpace(string(out))
} else {
logger.Debug(ctx, "failed to read HEAD", slog.F("root", repoRoot), slog.Error(err))
}
// Read remote origin URL.
remoteCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "config", "--get", "remote.origin.url")
if out, err := remoteCmd.Output(); err == nil {
result.RemoteOrigin = strings.TrimSpace(string(out))
}
// Compute unified diff.
// `git diff HEAD` shows both staged and unstaged changes vs HEAD.
// For repos with no commits yet, fall back to showing untracked
// files only.
diff, err := computeGitDiff(ctx, logger, gitBin, repoRoot)
if err != nil {
return result, xerrors.Errorf("compute diff: %w", err)
}
result.UnifiedDiff = diff
if len(result.UnifiedDiff) > maxTotalDiffSize {
result.UnifiedDiff = "Total diff too large to show. Size: " + humanize.IBytes(uint64(len(result.UnifiedDiff))) + ". Showing branch and remote only."
}
return result, nil
}
// computeGitDiff produces a unified diff string for the repository by
// combining `git diff HEAD` (staged + unstaged changes) with diffs
// for untracked files.
func computeGitDiff(ctx context.Context, logger slog.Logger, gitBin string, repoRoot string) (string, error) {
var diffParts []string
// Check if the repo has any commits.
hasCommits := true
checkCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "rev-parse", "HEAD")
if err := checkCmd.Run(); err != nil {
hasCommits = false
}
if hasCommits {
// `git diff HEAD` captures both staged and unstaged changes
// relative to HEAD in a single unified diff.
cmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "diff", "HEAD")
out, err := cmd.Output()
if err != nil {
return "", xerrors.Errorf("git diff HEAD: %w", err)
}
if len(out) > 0 {
diffParts = append(diffParts, string(out))
}
}
// Show untracked files as diffs too.
// `git ls-files --others --exclude-standard` lists untracked,
// non-ignored files.
lsCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "ls-files", "--others", "--exclude-standard")
lsOut, err := lsCmd.Output()
if err != nil {
logger.Debug(ctx, "failed to list untracked files", slog.F("root", repoRoot), slog.Error(err))
return strings.Join(diffParts, ""), nil
}
untrackedFiles := strings.Split(strings.TrimSpace(string(lsOut)), "\n")
for _, f := range untrackedFiles {
f = strings.TrimSpace(f)
if f == "" {
continue
}
// Use `git diff --no-index /dev/null <file>` to generate
// a unified diff for untracked files.
var stdout bytes.Buffer
untrackedCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "diff", "--no-index", "--", "/dev/null", f)
untrackedCmd.Stdout = &stdout
// git diff --no-index exits with 1 when files differ,
// which is expected. We ignore the error and check for
// output instead.
_ = untrackedCmd.Run()
if stdout.Len() > 0 {
diffParts = append(diffParts, stdout.String())
}
}
return strings.Join(diffParts, ""), nil
}
File diff suppressed because it is too large Load Diff
-147
View File
@@ -1,147 +0,0 @@
package agentgit
import (
"context"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/websocket"
)
// API exposes the git watch HTTP routes for the agent.
type API struct {
logger slog.Logger
opts []Option
pathStore *PathStore
}
// NewAPI creates a new git watch API.
func NewAPI(logger slog.Logger, pathStore *PathStore, opts ...Option) *API {
return &API{
logger: logger,
pathStore: pathStore,
opts: opts,
}
}
// Routes returns the chi router for mounting at /api/v0/git.
func (a *API) Routes() http.Handler {
r := chi.NewRouter()
r.Get("/watch", a.handleWatch)
return r
}
func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionNoContextTakeover,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to accept WebSocket.",
Detail: err.Error(),
})
return
}
// 4 MiB read limit — subscribe messages with many paths can exceed the
// default 32 KB limit. Matches the SDK/proxy side.
conn.SetReadLimit(1 << 22)
stream := wsjson.NewStream[
codersdk.WorkspaceAgentGitClientMessage,
codersdk.WorkspaceAgentGitServerMessage,
](conn, websocket.MessageText, websocket.MessageText, a.logger)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go httpapi.HeartbeatClose(ctx, a.logger, cancel, conn)
handler := NewHandler(a.logger, a.opts...)
// scanAndSend performs a scan and sends results if there are
// changes.
scanAndSend := func() {
msg := handler.Scan(ctx)
if msg != nil {
if err := stream.Send(*msg); err != nil {
a.logger.Debug(ctx, "failed to send changes", slog.Error(err))
cancel()
}
}
}
// If a chat_id query parameter is provided and the PathStore is
// available, subscribe to path updates for this chat.
chatIDStr := r.URL.Query().Get("chat_id")
if chatIDStr != "" && a.pathStore != nil {
chatID, parseErr := uuid.Parse(chatIDStr)
if parseErr == nil {
// Subscribe to future path updates BEFORE reading
// existing paths. This ordering guarantees no
// notification from AddPaths is lost: any call that
// lands before Subscribe is picked up by GetPaths
// below, and any call after Subscribe delivers a
// notification on the channel.
notifyCh, unsubscribe := a.pathStore.Subscribe(chatID)
defer unsubscribe()
// Load any paths that are already tracked for this chat.
existingPaths := a.pathStore.GetPaths(chatID)
if len(existingPaths) > 0 {
handler.Subscribe(existingPaths)
handler.RequestScan()
}
go func() {
for {
select {
case <-ctx.Done():
return
case <-notifyCh:
paths := a.pathStore.GetPaths(chatID)
handler.Subscribe(paths)
handler.RequestScan()
}
}
}()
}
}
// Start the main run loop in a goroutine.
go handler.RunLoop(ctx, scanAndSend)
// Read client messages.
updates := stream.Chan()
for {
select {
case <-ctx.Done():
_ = stream.Close(websocket.StatusGoingAway)
return
case msg, ok := <-updates:
if !ok {
return
}
switch msg.Type {
case codersdk.WorkspaceAgentGitClientMessageTypeRefresh:
handler.RequestScan()
default:
if err := stream.Send(codersdk.WorkspaceAgentGitServerMessage{
Type: codersdk.WorkspaceAgentGitServerMessageTypeError,
Message: "unknown message type",
}); err != nil {
return
}
}
}
}
}
-35
View File
@@ -1,35 +0,0 @@
package agentgit
import (
"encoding/json"
"net/http"
"github.com/google/uuid"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// ExtractChatContext reads chat identity headers from the request.
// Returns zero values if headers are absent (non-chat request).
func ExtractChatContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) {
raw := r.Header.Get(workspacesdk.CoderChatIDHeader)
if raw == "" {
return uuid.Nil, nil, false
}
chatID, err := uuid.Parse(raw)
if err != nil {
return uuid.Nil, nil, false
}
rawAncestors := r.Header.Get(workspacesdk.CoderAncestorChatIDsHeader)
if rawAncestors != "" {
var ids []string
if err := json.Unmarshal([]byte(rawAncestors), &ids); err == nil {
for _, s := range ids {
if id, err := uuid.Parse(s); err == nil {
ancestorIDs = append(ancestorIDs, id)
}
}
}
}
return chatID, ancestorIDs, true
}
-148
View File
@@ -1,148 +0,0 @@
package agentgit_test
import (
"encoding/json"
"net/http/httptest"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
func TestExtractChatContext(t *testing.T) {
t.Parallel()
validID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
ancestor1 := uuid.MustParse("11111111-2222-3333-4444-555555555555")
ancestor2 := uuid.MustParse("66666666-7777-8888-9999-aaaaaaaaaaaa")
tests := []struct {
name string
chatID string // empty means header not set
setChatID bool // whether to set the chat ID header at all
ancestors string // empty means header not set
setAncestors bool // whether to set the ancestor header at all
wantChatID uuid.UUID
wantAncestorIDs []uuid.UUID
wantOK bool
}{
{
name: "NoHeadersPresent",
setChatID: false,
setAncestors: false,
wantChatID: uuid.Nil,
wantAncestorIDs: nil,
wantOK: false,
},
{
name: "ValidChatID_NoAncestors",
chatID: validID.String(),
setChatID: true,
setAncestors: false,
wantChatID: validID,
wantAncestorIDs: nil,
wantOK: true,
},
{
name: "ValidChatID_ValidAncestors",
chatID: validID.String(),
setChatID: true,
ancestors: mustMarshalJSON(t, []string{
ancestor1.String(),
ancestor2.String(),
}),
setAncestors: true,
wantChatID: validID,
wantAncestorIDs: []uuid.UUID{ancestor1, ancestor2},
wantOK: true,
},
{
name: "MalformedChatID",
chatID: "not-a-uuid",
setChatID: true,
setAncestors: false,
wantChatID: uuid.Nil,
wantAncestorIDs: nil,
wantOK: false,
},
{
name: "ValidChatID_MalformedAncestorJSON",
chatID: validID.String(),
setChatID: true,
ancestors: `{this is not json}`,
setAncestors: true,
wantChatID: validID,
wantAncestorIDs: nil,
wantOK: true,
},
{
// Only valid UUIDs in the array are returned; invalid
// entries are silently skipped.
name: "ValidChatID_PartialValidAncestorUUIDs",
chatID: validID.String(),
setChatID: true,
ancestors: mustMarshalJSON(t, []string{
ancestor1.String(),
"bad-uuid",
ancestor2.String(),
}),
setAncestors: true,
wantChatID: validID,
wantAncestorIDs: []uuid.UUID{ancestor1, ancestor2},
wantOK: true,
},
{
// Header is explicitly set to an empty string, which
// Header.Get returns as "".
name: "EmptyChatIDHeader",
chatID: "",
setChatID: true,
setAncestors: false,
wantChatID: uuid.Nil,
wantAncestorIDs: nil,
wantOK: false,
},
{
name: "ValidChatID_EmptyAncestorHeader",
chatID: validID.String(),
setChatID: true,
ancestors: "",
setAncestors: true,
wantChatID: validID,
wantAncestorIDs: nil,
wantOK: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest("GET", "/", nil)
if tt.setChatID {
r.Header.Set(workspacesdk.CoderChatIDHeader, tt.chatID)
}
if tt.setAncestors {
r.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, tt.ancestors)
}
chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r)
require.Equal(t, tt.wantOK, ok, "ok mismatch")
require.Equal(t, tt.wantChatID, chatID, "chatID mismatch")
require.Equal(t, tt.wantAncestorIDs, ancestorIDs, "ancestorIDs mismatch")
})
}
}
// mustMarshalJSON marshals v to a JSON string, failing the test on error.
func mustMarshalJSON(t *testing.T, v any) string {
t.Helper()
b, err := json.Marshal(v)
require.NoError(t, err)
return string(b)
}
-136
View File
@@ -1,136 +0,0 @@
package agentgit
import (
"sort"
"sync"
"github.com/google/uuid"
)
// PathStore tracks which file paths each chat has touched.
// It is safe for concurrent use.
type PathStore struct {
mu sync.RWMutex
chatPaths map[uuid.UUID]map[string]struct{}
subscribers map[uuid.UUID][]chan<- struct{}
}
// NewPathStore creates a new PathStore.
func NewPathStore() *PathStore {
return &PathStore{
chatPaths: make(map[uuid.UUID]map[string]struct{}),
subscribers: make(map[uuid.UUID][]chan<- struct{}),
}
}
// AddPaths adds paths to every chat in chatIDs and notifies
// their subscribers. Zero-value UUIDs are silently skipped.
func (ps *PathStore) AddPaths(chatIDs []uuid.UUID, paths []string) {
affected := make([]uuid.UUID, 0, len(chatIDs))
for _, id := range chatIDs {
if id != uuid.Nil {
affected = append(affected, id)
}
}
if len(affected) == 0 {
return
}
ps.mu.Lock()
for _, id := range affected {
m, ok := ps.chatPaths[id]
if !ok {
m = make(map[string]struct{})
ps.chatPaths[id] = m
}
for _, p := range paths {
m[p] = struct{}{}
}
}
ps.mu.Unlock()
ps.notifySubscribers(affected)
}
// Notify sends a signal to all subscribers of the given chat IDs
// without adding any paths. Zero-value UUIDs are silently skipped.
func (ps *PathStore) Notify(chatIDs []uuid.UUID) {
affected := make([]uuid.UUID, 0, len(chatIDs))
for _, id := range chatIDs {
if id != uuid.Nil {
affected = append(affected, id)
}
}
if len(affected) == 0 {
return
}
ps.notifySubscribers(affected)
}
// notifySubscribers sends a non-blocking signal to all subscriber
// channels for the given chat IDs.
func (ps *PathStore) notifySubscribers(chatIDs []uuid.UUID) {
ps.mu.RLock()
toNotify := make([]chan<- struct{}, 0)
for _, id := range chatIDs {
toNotify = append(toNotify, ps.subscribers[id]...)
}
ps.mu.RUnlock()
for _, ch := range toNotify {
select {
case ch <- struct{}{}:
default:
}
}
}
// GetPaths returns all paths tracked for a chat, deduplicated
// and sorted lexicographically.
func (ps *PathStore) GetPaths(chatID uuid.UUID) []string {
ps.mu.RLock()
defer ps.mu.RUnlock()
m := ps.chatPaths[chatID]
if len(m) == 0 {
return nil
}
out := make([]string, 0, len(m))
for p := range m {
out = append(out, p)
}
sort.Strings(out)
return out
}
// Len returns the number of chat IDs that have tracked paths.
func (ps *PathStore) Len() int {
ps.mu.RLock()
defer ps.mu.RUnlock()
return len(ps.chatPaths)
}
// Subscribe returns a channel that receives a signal whenever
// paths change for chatID, along with an unsubscribe function
// that removes the channel.
func (ps *PathStore) Subscribe(chatID uuid.UUID) (<-chan struct{}, func()) {
ch := make(chan struct{}, 1)
ps.mu.Lock()
ps.subscribers[chatID] = append(ps.subscribers[chatID], ch)
ps.mu.Unlock()
unsub := func() {
ps.mu.Lock()
defer ps.mu.Unlock()
subs := ps.subscribers[chatID]
for i, s := range subs {
if s == ch {
ps.subscribers[chatID] = append(subs[:i], subs[i+1:]...)
break
}
}
}
return ch, unsub
}
-268
View File
@@ -1,268 +0,0 @@
package agentgit_test
import (
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/testutil"
)
func TestPathStore_AddPaths_StoresForChatAndAncestors(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ancestor1 := uuid.New()
ancestor2 := uuid.New()
ps.AddPaths([]uuid.UUID{chatID, ancestor1, ancestor2}, []string{"/a", "/b"})
// All three IDs should see the paths.
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(chatID))
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(ancestor1))
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(ancestor2))
// An unrelated chat should see nothing.
require.Nil(t, ps.GetPaths(uuid.New()))
}
func TestPathStore_AddPaths_SkipsNilUUIDs(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
// A nil chatID should be a no-op.
ps.AddPaths([]uuid.UUID{uuid.Nil}, []string{"/x"})
require.Nil(t, ps.GetPaths(uuid.Nil))
// A nil ancestor should be silently skipped.
chatID := uuid.New()
ps.AddPaths([]uuid.UUID{chatID, uuid.Nil}, []string{"/y"})
require.Equal(t, []string{"/y"}, ps.GetPaths(chatID))
require.Nil(t, ps.GetPaths(uuid.Nil))
}
func TestPathStore_GetPaths_DeduplicatedSorted(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ps.AddPaths([]uuid.UUID{chatID}, []string{"/z", "/a", "/m", "/a", "/z"})
ps.AddPaths([]uuid.UUID{chatID}, []string{"/a", "/b"})
got := ps.GetPaths(chatID)
require.Equal(t, []string{"/a", "/b", "/m", "/z"}, got)
}
func TestPathStore_Subscribe_ReceivesNotification(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ch, unsub := ps.Subscribe(chatID)
defer unsub()
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
ctx := testutil.Context(t, testutil.WaitShort)
select {
case <-ch:
// Success.
case <-ctx.Done():
t.Fatal("timed out waiting for notification")
}
}
func TestPathStore_Subscribe_MultipleSubscribers(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ch1, unsub1 := ps.Subscribe(chatID)
defer unsub1()
ch2, unsub2 := ps.Subscribe(chatID)
defer unsub2()
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
ctx := testutil.Context(t, testutil.WaitShort)
for i, ch := range []<-chan struct{}{ch1, ch2} {
select {
case <-ch:
// OK
case <-ctx.Done():
t.Fatalf("subscriber %d did not receive notification", i)
}
}
}
func TestPathStore_Unsubscribe_StopsNotifications(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ch, unsub := ps.Subscribe(chatID)
unsub()
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
// AddPaths sends synchronously via a non-blocking send to the
// buffered channel, so if a notification were going to arrive
// it would already be in the channel by now.
select {
case <-ch:
t.Fatal("received notification after unsubscribe")
default:
// Expected: no notification.
}
}
func TestPathStore_Subscribe_AncestorNotification(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ancestor := uuid.New()
// Subscribe to the ancestor, then add paths via the child.
ch, unsub := ps.Subscribe(ancestor)
defer unsub()
ps.AddPaths([]uuid.UUID{chatID, ancestor}, []string{"/file"})
ctx := testutil.Context(t, testutil.WaitShort)
select {
case <-ch:
// Success.
case <-ctx.Done():
t.Fatal("ancestor subscriber did not receive notification")
}
}
func TestPathStore_Notify_NotifiesWithoutAddingPaths(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ch, unsub := ps.Subscribe(chatID)
defer unsub()
ps.Notify([]uuid.UUID{chatID})
ctx := testutil.Context(t, testutil.WaitShort)
select {
case <-ch:
// Success.
case <-ctx.Done():
t.Fatal("timed out waiting for notification")
}
require.Nil(t, ps.GetPaths(chatID))
}
func TestPathStore_Notify_SkipsNilUUIDs(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ch, unsub := ps.Subscribe(chatID)
defer unsub()
ps.Notify([]uuid.UUID{uuid.Nil})
// Notify sends synchronously via a non-blocking send to the
// buffered channel, so if a notification were going to arrive
// it would already be in the channel by now.
select {
case <-ch:
t.Fatal("received notification for nil UUID")
default:
// Expected: no notification.
}
require.Nil(t, ps.GetPaths(chatID))
}
func TestPathStore_Notify_AncestorNotification(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
chatID := uuid.New()
ancestorID := uuid.New()
// Subscribe to the ancestor, then notify via the child.
ch, unsub := ps.Subscribe(ancestorID)
defer unsub()
ps.Notify([]uuid.UUID{chatID, ancestorID})
ctx := testutil.Context(t, testutil.WaitShort)
select {
case <-ch:
// Success.
case <-ctx.Done():
t.Fatal("ancestor subscriber did not receive notification")
}
require.Nil(t, ps.GetPaths(ancestorID))
}
func TestPathStore_ConcurrentSafety(t *testing.T) {
t.Parallel()
ps := agentgit.NewPathStore()
const goroutines = 20
const iterations = 50
chatIDs := make([]uuid.UUID, goroutines)
for i := range chatIDs {
chatIDs[i] = uuid.New()
}
var wg sync.WaitGroup
wg.Add(goroutines * 2) // writers + readers
// Writers.
for i := range goroutines {
go func(idx int) {
defer wg.Done()
for j := range iterations {
ancestors := []uuid.UUID{chatIDs[(idx+1)%goroutines]}
path := []string{
"/file-" + chatIDs[idx].String() + "-" + time.Now().Format(time.RFC3339Nano),
"/iter-" + string(rune('0'+j%10)),
}
ps.AddPaths(append([]uuid.UUID{chatIDs[idx]}, ancestors...), path)
}
}(i)
}
// Readers.
for i := range goroutines {
go func(idx int) {
defer wg.Done()
for range iterations {
_ = ps.GetPaths(chatIDs[idx])
}
}(i)
}
wg.Wait()
// Verify every chat has at least the paths it wrote.
for _, id := range chatIDs {
paths := ps.GetPaths(id)
require.NotEmpty(t, paths, "chat %s should have paths", id)
}
}
+5 -26
View File
@@ -7,11 +7,9 @@ import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -19,17 +17,15 @@ import (
// API exposes process-related operations through the agent.
type API struct {
logger slog.Logger
manager *manager
pathStore *agentgit.PathStore
logger slog.Logger
manager *manager
}
// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *API {
return &API{
logger: logger,
manager: newManager(logger, execer, updateEnv),
pathStore: pathStore,
logger: logger,
manager: newManager(logger, execer, updateEnv),
}
}
@@ -78,23 +74,6 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
return
}
// Notify git watchers after the process finishes so that
// file changes made by the command are visible in the scan.
// If a workdir is provided, track it as a path as well.
if api.pathStore != nil {
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
allIDs := append([]uuid.UUID{chatID}, ancestorIDs...)
go func() {
<-proc.done
if req.WorkDir != "" {
api.pathStore.AddPaths(allIDs, []string{req.WorkDir})
} else {
api.pathStore.Notify(allIDs)
}
}()
}
}
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.StartProcessResponse{
ID: proc.id,
Started: true,
+1 -43
View File
@@ -12,14 +12,12 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -101,7 +99,7 @@ func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, e
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv)
t.Cleanup(func() {
_ = api.Close()
})
@@ -572,46 +570,6 @@ func TestSignalProcess(t *testing.T) {
})
}
func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T) {
t.Parallel()
pathStore := agentgit.NewPathStore()
chatID := uuid.New()
ch, unsub := pathStore.Subscribe(chatID)
defer unsub()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
return current, nil
}, pathStore)
defer api.Close()
routes := api.Routes()
body, err := json.Marshal(workspacesdk.StartProcessRequest{
Command: "echo hello",
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/start", bytes.NewReader(body))
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
rw := httptest.NewRecorder()
routes.ServeHTTP(rw, req)
require.Equal(t, http.StatusOK, rw.Code)
// The subscriber should be notified even though no paths
// were added.
select {
case <-ch:
case <-time.After(testutil.WaitShort):
t.Fatal("timed out waiting for path store notification")
}
// No paths should have been stored for this chat.
require.Nil(t, pathStore.GetPaths(chatID))
}
func TestProcessLifecycle(t *testing.T) {
t.Parallel()
-10
View File
@@ -110,11 +110,6 @@ type Config struct {
// X11DisplayOffset is the offset to add to the X11 display number.
// Default is 10.
X11DisplayOffset *int
// X11MaxPort overrides the highest port used for X11 forwarding
// listeners. Defaults to X11MaxPort (6200). Useful in tests
// to shrink the port range and reduce the number of sessions
// required.
X11MaxPort *int
// BlockFileTransfer restricts use of file transfer applications.
BlockFileTransfer bool
// ReportConnection.
@@ -163,10 +158,6 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
offset := X11DefaultDisplayOffset
config.X11DisplayOffset = &offset
}
if config.X11MaxPort == nil {
maxPort := X11MaxPort
config.X11MaxPort = &maxPort
}
if config.UpdateEnv == nil {
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
}
@@ -210,7 +201,6 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
x11HandlerErrors: metrics.x11HandlerErrors,
fs: fs,
displayOffset: *config.X11DisplayOffset,
maxPort: *config.X11MaxPort,
sessions: make(map[*x11Session]struct{}),
connections: make(map[net.Conn]struct{}),
network: func() X11Network {
+1 -2
View File
@@ -57,7 +57,6 @@ type x11Forwarder struct {
x11HandlerErrors *prometheus.CounterVec
fs afero.Fs
displayOffset int
maxPort int
// network creates X11 listener sockets. Defaults to osNet{}.
network X11Network
@@ -315,7 +314,7 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
// the next available port starting from X11StartPort and displayOffset.
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
// Look for an open port to listen on.
for port := X11StartPort + x.displayOffset; port <= x.maxPort; port++ {
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
if ctx.Err() != nil {
return nil, -1, ctx.Err()
}
+2 -7
View File
@@ -142,13 +142,8 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
// Use in-process networking for X11 forwarding.
inproc := testutil.NewInProcNet()
// Limit port range so we only need a handful of sessions to fill it
// (the default 190 ports may easily timeout or conflict with other
// ports on the system).
maxPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset + 5
cfg := &agentssh.Config{
X11Net: inproc,
X11MaxPort: &maxPort,
X11Net: inproc,
}
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
@@ -177,7 +172,7 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
// configured port range.
startPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset
maxSessions := maxPort - startPort + 1 - 1 // -1 for the blocked port
maxSessions := agentssh.X11MaxPort - startPort + 1 - 1 // -1 for the blocked port
require.Greater(t, maxSessions, 0, "expected a positive maxSessions value")
// shellSession holds references to the session and its standard streams so
-1
View File
@@ -28,7 +28,6 @@ func (a *agent) apiHandler() http.Handler {
})
r.Mount("/api/v0", a.filesAPI.Routes())
r.Mount("/api/v0/git", a.gitAPI.Routes())
r.Mount("/api/v0/processes", a.processAPI.Routes())
if a.devcontainers {
-11
View File
@@ -42,20 +42,9 @@ func WithLogger(logger slog.Logger) Option {
}
}
// WithDone sets a channel that, when closed, stops the reaper
// goroutine. Callers that invoke ForkReap more than once in the
// same process (e.g. tests) should use this to prevent goroutine
// accumulation.
func WithDone(ch chan struct{}) Option {
return func(o *options) {
o.Done = ch
}
}
type options struct {
ExecArgs []string
PIDs reap.PidCh
CatchSignals []os.Signal
Logger slog.Logger
Done chan struct{}
}
-23
View File
@@ -18,15 +18,6 @@ import (
"github.com/coder/coder/v2/testutil"
)
// withDone returns an option that stops the reaper goroutine when t
// completes, preventing goroutine accumulation across subtests.
func withDone(t *testing.T) reaper.Option {
t.Helper()
done := make(chan struct{})
t.Cleanup(func() { close(done) })
return reaper.WithDone(done)
}
// TestReap checks that's the reaper is successfully reaping
// exited processes and passing the PIDs through the shared
// channel.
@@ -45,7 +36,6 @@ func TestReap(t *testing.T) {
reaper.WithPIDCallback(pids),
// Provide some argument that immediately exits.
reaper.WithExecArgs("/bin/sh", "-c", "exit 0"),
withDone(t),
)
require.NoError(t, err)
require.Equal(t, 0, exitCode)
@@ -99,7 +89,6 @@ func TestForkReapExitCodes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
exitCode, err := reaper.ForkReap(
reaper.WithExecArgs("/bin/sh", "-c", tt.command),
withDone(t),
)
require.NoError(t, err)
require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command)
@@ -129,7 +118,6 @@ func TestReapInterrupt(t *testing.T) {
exitCode, err := reaper.ForkReap(
reaper.WithPIDCallback(pids),
reaper.WithCatchSignals(os.Interrupt),
withDone(t),
// Signal propagation does not extend to children of children, so
// we create a little bash script to ensure sleep is interrupted.
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
@@ -141,17 +129,6 @@ func TestReapInterrupt(t *testing.T) {
}()
require.Equal(t, <-usrSig, syscall.SIGUSR1)
// Prevent SIGINT from terminating the test process. Under the
// race detector, the catchSignals goroutine in ForkReap may not
// have called signal.Notify yet, so the default Go handler
// could kill us. Registering our own Notify disables the
// default behavior. Both this channel and the one inside
// catchSignals receive independent copies of the signal.
intC := make(chan os.Signal, 1)
signal.Notify(intC, os.Interrupt)
defer signal.Stop(intC)
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
require.NoError(t, err)
require.Equal(t, <-usrSig, syscall.SIGUSR2)
+16 -21
View File
@@ -6,7 +6,6 @@ import (
"context"
"os"
"os/signal"
"sync"
"syscall"
"github.com/hashicorp/go-reap"
@@ -20,7 +19,20 @@ func IsInitProcess() bool {
return os.Getpid() == 1
}
func catchSignals(logger slog.Logger, pid int, sc <-chan os.Signal) {
func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
if len(sigs) == 0 {
return
}
sc := make(chan os.Signal, 1)
signal.Notify(sc, sigs...)
defer signal.Stop(sc)
logger.Info(context.Background(), "reaper catching signals",
slog.F("signals", sigs),
slog.F("child_pid", pid),
)
for {
s := <-sc
sig, ok := s.(syscall.Signal)
@@ -52,17 +64,10 @@ func ForkReap(opt ...Option) (int, error) {
o(opts)
}
// Use the reapLock to prevent the reaper's Wait4(-1) from
// stealing the direct child's exit status. The reaper takes
// a write lock; we hold a read lock during our own Wait4.
var reapLock sync.RWMutex
reapLock.RLock()
go reap.ReapChildren(opts.PIDs, nil, opts.Done, &reapLock)
go reap.ReapChildren(opts.PIDs, nil, nil, nil)
pwd, err := os.Getwd()
if err != nil {
reapLock.RUnlock()
return 1, xerrors.Errorf("get wd: %w", err)
}
@@ -82,26 +87,16 @@ func ForkReap(opt ...Option) (int, error) {
//#nosec G204
pid, err := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
if err != nil {
reapLock.RUnlock()
return 1, xerrors.Errorf("fork exec: %w", err)
}
// Register the signal handler before spawning the goroutine
// so it is active by the time the child process starts. This
// avoids a race where a signal arrives before the goroutine
// has called signal.Notify.
if len(opts.CatchSignals) > 0 {
sc := make(chan os.Signal, 1)
signal.Notify(sc, opts.CatchSignals...)
go catchSignals(opts.Logger, pid, sc)
}
go catchSignals(opts.Logger, pid, opts.CatchSignals)
var wstatus syscall.WaitStatus
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
for xerrors.Is(err, syscall.EINTR) {
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
}
reapLock.RUnlock()
// Convert wait status to exit code using standard Unix conventions:
// - Normal exit: use the exit code
-4
View File
@@ -123,10 +123,6 @@ func Select(inv *serpent.Invocation, opts SelectOptions) (string, error) {
initialModel.height = defaultSelectModelHeight
}
if idx := slices.Index(opts.Options, opts.Default); idx >= 0 {
initialModel.cursor = idx
}
initialModel.search.Prompt = ""
initialModel.search.Focus()
+3 -3
View File
@@ -109,13 +109,13 @@ func (RootCmd) promptExample() *serpent.Command {
Options: []string{
"Blue", "Green", "Yellow", "Red", "Something else",
},
Default: "Green",
Default: "",
Message: "Select your favorite color:",
Size: 5,
HideSearch: !useSearch,
})
if value == "Something else" {
_, _ = fmt.Fprint(inv.Stdout, "I would have picked green.\n")
_, _ = fmt.Fprint(inv.Stdout, "I would have picked blue.\n")
} else {
_, _ = fmt.Fprintf(inv.Stdout, "%s is a nice color.\n", value)
}
@@ -128,7 +128,7 @@ func (RootCmd) promptExample() *serpent.Command {
Options: []string{
"Car", "Bike", "Plane", "Boat", "Train",
},
Default: "Bike",
Default: "Car",
})
if err != nil {
return err
+12 -12
View File
@@ -41,11 +41,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
inv, root := clitest.New(t, "task", "logs", setup.task.Name, "--output", "json")
inv, root := clitest.New(t, "task", "logs", task.Name, "--output", "json")
output := clitest.Capture(inv)
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
@@ -64,11 +64,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String(), "--output", "json")
inv, root := clitest.New(t, "task", "logs", task.ID.String(), "--output", "json")
output := clitest.Capture(inv)
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
@@ -87,11 +87,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String())
inv, root := clitest.New(t, "task", "logs", task.ID.String())
output := clitest.Capture(inv)
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
@@ -141,10 +141,10 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String())
clitest.SetupConfig(t, setup.userClient, root)
inv, root := clitest.New(t, "task", "logs", task.ID.String())
clitest.SetupConfig(t, userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
+26 -22
View File
@@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
@@ -20,12 +21,12 @@ func TestExpTaskPause(t *testing.T) {
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", setup.task.Name, "--yes")
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
// Then: Expect the task to be paused
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -33,7 +34,7 @@ func TestExpTaskPause(t *testing.T) {
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been paused")
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
@@ -45,13 +46,13 @@ func TestExpTaskPause(t *testing.T) {
// Given: A different user's running task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
adminClient, _, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause their task
identifier := fmt.Sprintf("%s/%s", setup.task.OwnerName, setup.task.Name)
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
inv, root := clitest.New(t, "task", "pause", identifier, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, setup.ownerClient, root)
clitest.SetupConfig(t, adminClient, root)
// Then: We expect the task to be paused
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -59,7 +60,7 @@ func TestExpTaskPause(t *testing.T) {
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been paused")
updated, err := setup.ownerClient.TaskByIdentifier(ctx, identifier)
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
@@ -69,11 +70,11 @@ func TestExpTaskPause(t *testing.T) {
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", setup.task.Name)
clitest.SetupConfig(t, setup.userClient, root)
inv, root := clitest.New(t, "task", "pause", task.Name)
clitest.SetupConfig(t, userClient, root)
// And: We confirm we want to pause the task
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -87,7 +88,7 @@ func TestExpTaskPause(t *testing.T) {
pty.ExpectMatchContext(ctx, "has been paused")
require.NoError(t, w.Wait())
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
@@ -97,11 +98,11 @@ func TestExpTaskPause(t *testing.T) {
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", setup.task.Name)
clitest.SetupConfig(t, setup.userClient, root)
inv, root := clitest.New(t, "task", "pause", task.Name)
clitest.SetupConfig(t, userClient, root)
// But: We say no at the confirmation screen
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -113,7 +114,7 @@ func TestExpTaskPause(t *testing.T) {
require.Error(t, w.Wait())
// Then: We expect the task to not be paused
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.NotEqual(t, codersdk.TaskStatusPaused, updated.Status)
})
@@ -123,18 +124,21 @@ func TestExpTaskPause(t *testing.T) {
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// And: We paused the running task
pauseTask(setupCtx, t, setup.userClient, setup.task)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := userClient.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, resp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, resp.WorkspaceBuild.ID)
// When: We attempt to pause the task again
inv, root := clitest.New(t, "task", "pause", setup.task.Name, "--yes")
clitest.SetupConfig(t, setup.userClient, root)
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
clitest.SetupConfig(t, userClient, root)
// Then: We expect to get an error that the task is already paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
err = inv.WithContext(ctx).Run()
require.ErrorContains(t, err, "is already paused")
})
}
+43 -31
View File
@@ -1,6 +1,7 @@
package cli_test
import (
"context"
"fmt"
"testing"
@@ -16,18 +17,29 @@ import (
func TestExpTaskResume(t *testing.T) {
t.Parallel()
// pauseTask is a helper that pauses a task and waits for the stop
// build to complete.
pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
t.Helper()
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, pauseResp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
}
t.Run("WithYesFlag", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, setup.userClient, setup.task)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes")
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
// Then: We expect the task to be resumed
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -35,7 +47,7 @@ func TestExpTaskResume(t *testing.T) {
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been resumed")
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
@@ -47,14 +59,14 @@ func TestExpTaskResume(t *testing.T) {
// Given: A different user's paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, setup.userClient, setup.task)
adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume their task
identifier := fmt.Sprintf("%s/%s", setup.task.OwnerName, setup.task.Name)
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, setup.ownerClient, root)
clitest.SetupConfig(t, adminClient, root)
// Then: We expect the task to be resumed
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -62,7 +74,7 @@ func TestExpTaskResume(t *testing.T) {
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been resumed")
updated, err := setup.ownerClient.TaskByIdentifier(ctx, identifier)
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
@@ -72,13 +84,13 @@ func TestExpTaskResume(t *testing.T) {
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, setup.userClient, setup.task)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task (and specify no wait)
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes", "--no-wait")
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
output := clitest.Capture(inv)
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
// Then: We expect the task to be resumed in the background
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -87,11 +99,11 @@ func TestExpTaskResume(t *testing.T) {
require.Contains(t, output.Stdout(), "in the background")
// And: The task to eventually be resumed
require.True(t, setup.task.WorkspaceID.Valid, "task should have a workspace ID")
ws := coderdtest.MustWorkspace(t, setup.userClient, setup.task.WorkspaceID.UUID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, setup.userClient, ws.LatestBuild.ID)
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
@@ -101,12 +113,12 @@ func TestExpTaskResume(t *testing.T) {
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, setup.userClient, setup.task)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", setup.task.Name)
clitest.SetupConfig(t, setup.userClient, root)
inv, root := clitest.New(t, "task", "resume", task.Name)
clitest.SetupConfig(t, userClient, root)
// And: We confirm we want to resume the task
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -120,7 +132,7 @@ func TestExpTaskResume(t *testing.T) {
pty.ExpectMatchContext(ctx, "has been resumed")
require.NoError(t, w.Wait())
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
@@ -130,12 +142,12 @@ func TestExpTaskResume(t *testing.T) {
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, setup.userClient, setup.task)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", setup.task.Name)
clitest.SetupConfig(t, setup.userClient, root)
inv, root := clitest.New(t, "task", "resume", task.Name)
clitest.SetupConfig(t, userClient, root)
// But: Say no at the confirmation screen
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -147,7 +159,7 @@ func TestExpTaskResume(t *testing.T) {
require.Error(t, w.Wait())
// Then: We expect the task to still be paused
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
@@ -157,11 +169,11 @@ func TestExpTaskResume(t *testing.T) {
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to resume the task that is not paused
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes")
clitest.SetupConfig(t, setup.userClient, root)
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
clitest.SetupConfig(t, userClient, root)
// Then: We expect to get an error that the task is not paused
ctx := testutil.Context(t, testutil.WaitMedium)
+9 -154
View File
@@ -1,15 +1,10 @@
package cli
import (
"context"
"fmt"
"io"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
@@ -20,15 +15,13 @@ func (r *RootCmd) taskSend() *serpent.Command {
cmd := &serpent.Command{
Use: "send <task> [<input> | --stdin]",
Short: "Send input to a task",
Long: `Send input to a task. If the task is paused, it will be automatically resumed before input is sent. If the task is initializing, it will wait for the task to become ready.
` +
FormatExamples(Example{
Description: "Send direct input to a task",
Command: `coder task send task1 "Please also add unit tests"`,
}, Example{
Description: "Send input from stdin to a task",
Command: `echo "Please also add unit tests" | coder task send task1 --stdin`,
}),
Long: FormatExamples(Example{
Description: "Send direct input to a task.",
Command: "coder task send task1 \"Please also add unit tests\"",
}, Example{
Description: "Send input from stdin to a task.",
Command: "echo \"Please also add unit tests\" | coder task send task1 --stdin",
}),
Middleware: serpent.RequireRangeArgs(1, 2),
Options: serpent.OptionSet{
{
@@ -71,48 +64,8 @@ func (r *RootCmd) taskSend() *serpent.Command {
return xerrors.Errorf("resolve task: %w", err)
}
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
// Before attempting to send, check the task status and
// handle non-active states.
var workspaceBuildID uuid.UUID
switch task.Status {
case codersdk.TaskStatusActive:
// Already active, no build to watch.
case codersdk.TaskStatusPaused:
resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
if err != nil {
return xerrors.Errorf("resume task %q: %w", display, err)
} else if resp.WorkspaceBuild == nil {
return xerrors.Errorf("resume task %q", display)
}
workspaceBuildID = resp.WorkspaceBuild.ID
case codersdk.TaskStatusInitializing:
if !task.WorkspaceID.Valid {
return xerrors.Errorf("send input to task %q: task has no backing workspace", display)
}
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
if err != nil {
return xerrors.Errorf("get workspace for task %q: %w", display, err)
}
workspaceBuildID = workspace.LatestBuild.ID
default:
return xerrors.Errorf("task %q has status %s and cannot be sent input", display, task.Status)
}
if err := waitForTaskIdle(ctx, inv, client, task, workspaceBuildID); err != nil {
return xerrors.Errorf("wait for task %q to be idle: %w", display, err)
}
if err := client.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
return xerrors.Errorf("send input to task %q: %w", display, err)
if err = client.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
return xerrors.Errorf("send input to task: %w", err)
}
return nil
@@ -121,101 +74,3 @@ func (r *RootCmd) taskSend() *serpent.Command {
return cmd
}
// waitForTaskIdle optionally watches a workspace build to completion,
// then polls until the task becomes active and its app state is idle.
// This merges build-watching and idle-polling into a single loop so
// that status changes (e.g. paused) are never missed between phases.
func waitForTaskIdle(ctx context.Context, inv *serpent.Invocation, client *codersdk.Client, task codersdk.Task, workspaceBuildID uuid.UUID) error {
if workspaceBuildID != uuid.Nil {
if err := cliui.WorkspaceBuild(ctx, inv.Stdout, client, workspaceBuildID); err != nil {
return xerrors.Errorf("watch workspace build: %w", err)
}
}
cliui.Infof(inv.Stdout, "Waiting for task to become idle...")
// NOTE(DanielleMaywood):
// It has been observed that the `TaskStatusError` state has
// appeared during a typical healthy startup [^0]. To combat
// this, we allow a 5 minute grace period where we allow
// `TaskStatusError` to surface without immediately failing.
//
// TODO(DanielleMaywood):
// Remove this grace period once the upstream agentapi health
// check no longer reports transient error states during normal
// startup.
//
// [0]: https://github.com/coder/coder/pull/22203#discussion_r2858002569
const errorGracePeriod = 5 * time.Minute
gracePeriodDeadline := time.Now().Add(errorGracePeriod)
// NOTE(DanielleMaywood):
// On resume the MCP may not report an initial app status,
// leaving CurrentState nil indefinitely. To avoid hanging
// forever we treat Active with nil CurrentState as idle
// after a grace period, giving the MCP time to report
// during normal startup.
const nilStateGracePeriod = 30 * time.Second
var nilStateDeadline time.Time
// TODO(DanielleMaywood):
// When we have a streaming Task API, this should be converted
// away from polling.
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
task, err := client.TaskByID(ctx, task.ID)
if err != nil {
return xerrors.Errorf("get task by id: %w", err)
}
switch task.Status {
case codersdk.TaskStatusInitializing,
codersdk.TaskStatusPending:
// Not yet active, keep polling.
continue
case codersdk.TaskStatusActive:
// Task is active; check app state.
if task.CurrentState == nil {
// The MCP may not have reported state yet.
// Start a grace period on first observation
// and treat as idle once it expires.
if nilStateDeadline.IsZero() {
nilStateDeadline = time.Now().Add(nilStateGracePeriod)
}
if time.Now().After(nilStateDeadline) {
return nil
}
continue
}
// Reset nil-state deadline since we got a real
// state report.
nilStateDeadline = time.Time{}
switch task.CurrentState.State {
case codersdk.TaskStateIdle,
codersdk.TaskStateComplete,
codersdk.TaskStateFailed:
return nil
default:
// Still working, keep polling.
continue
}
case codersdk.TaskStatusError:
if time.Now().After(gracePeriodDeadline) {
return xerrors.Errorf("task entered %s state while waiting for it to become idle", task.Status)
}
case codersdk.TaskStatusPaused:
return xerrors.Errorf("task was paused while waiting for it to become idle")
case codersdk.TaskStatusUnknown:
return xerrors.Errorf("task entered %s state while waiting for it to become idle", task.Status)
default:
return xerrors.Errorf("task entered unexpected state (%s) while waiting for it to become idle", task.Status)
}
}
}
}
+13 -224
View File
@@ -12,14 +12,9 @@ import (
"github.com/stretchr/testify/require"
agentapisdk "github.com/coder/agentapi-sdk-go"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@@ -30,12 +25,12 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", setup.task.Name, "carry on with the task")
inv, root := clitest.New(t, "task", "send", task.Name, "carry on with the task")
inv.Stdout = &stdout
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
@@ -46,12 +41,12 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", setup.task.ID.String(), "carry on with the task")
inv, root := clitest.New(t, "task", "send", task.ID.String(), "carry on with the task")
inv.Stdout = &stdout
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
@@ -62,13 +57,13 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", setup.task.Name, "--stdin")
inv, root := clitest.New(t, "task", "send", task.Name, "--stdin")
inv.Stdout = &stdout
inv.Stdin = strings.NewReader("carry on with the task")
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
@@ -115,223 +110,17 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(assert.AnError))
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
inv, root := clitest.New(t, "task", "send", task.Name, "some task input")
inv.Stdout = &stdout
clitest.SetupConfig(t, setup.userClient, root)
clitest.SetupConfig(t, userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
require.ErrorContains(t, err, assert.AnError.Error())
})
t.Run("WaitsForInitializingTask", func(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
// Close the first agent, pause, then resume the task so the
// workspace is started but no agent is connected.
// This puts the task in "initializing" state.
require.NoError(t, setup.agent.Close())
pauseTask(setupCtx, t, setup.userClient, setup.task)
resumeTask(setupCtx, t, setup.userClient, setup.task)
// When: We attempt to send input to the initializing task.
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
clitest.SetupConfig(t, setup.userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
inv = inv.WithContext(ctx)
// Use a pty so we can wait for the command to produce build
// output, confirming it has entered the initializing code
// path before we connect the agent.
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
// Wait for the command to observe the initializing state and
// start watching the workspace build. This ensures the command
// has entered the waiting code path.
pty.ExpectMatchContext(ctx, "Queued")
// Connect a new agent so the task can transition to active.
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
setup.agent = agenttest.New(t, setup.userClient.URL, setup.agentToken, func(o *agent.Options) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, setup.userClient, setup.task.WorkspaceID.UUID).
WaitFor(coderdtest.AgentsReady)
// Report the task app as idle so waitForTaskIdle can proceed.
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "task-sidebar",
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "ready",
}))
// Then: The command should complete successfully.
require.NoError(t, w.Wait())
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusActive, updated.Status)
})
t.Run("ResumesPausedTask", func(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
// Close the first agent before pausing so it does not conflict
// with the agent we reconnect after the workspace is resumed.
require.NoError(t, setup.agent.Close())
pauseTask(setupCtx, t, setup.userClient, setup.task)
// When: We attempt to send input to the paused task.
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
clitest.SetupConfig(t, setup.userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
inv = inv.WithContext(ctx)
// Use a pty so we can wait for the command to produce build
// output, confirming it has entered the paused code path and
// triggered a resume before we connect the agent.
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
// Wait for the command to observe the paused state, trigger
// a resume, and start watching the workspace build.
pty.ExpectMatchContext(ctx, "Queued")
// Connect a new agent so the task can transition to active.
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
setup.agent = agenttest.New(t, setup.userClient.URL, setup.agentToken, func(o *agent.Options) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, setup.userClient, setup.task.WorkspaceID.UUID).
WaitFor(coderdtest.AgentsReady)
// Report the task app as idle so waitForTaskIdle can proceed.
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "task-sidebar",
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "ready",
}))
// Then: The command should complete successfully.
require.NoError(t, w.Wait())
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusActive, updated.Status)
})
t.Run("PausedDuringWaitForReady", func(t *testing.T) {
t.Parallel()
// Given: An initializing task (workspace running, no agent
// connected).
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, nil)
require.NoError(t, setup.agent.Close())
pauseTask(setupCtx, t, setup.userClient, setup.task)
resumeTask(setupCtx, t, setup.userClient, setup.task)
// When: We attempt to send input to the initializing task.
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
clitest.SetupConfig(t, setup.userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
// Wait for the command to enter the build-watching phase
// of waitForTaskReady.
pty.ExpectMatchContext(ctx, "Queued")
// Pause the task while waitForTaskReady is polling. Since
// no agent is connected, the task stays initializing until
// we pause it, at which point the status becomes paused.
pauseTask(ctx, t, setup.userClient, setup.task)
// Then: The command should fail because the task was paused.
err := w.Wait()
require.Error(t, err)
require.ErrorContains(t, err, "was paused while waiting for it to become idle")
})
t.Run("WaitsForWorkingAppState", func(t *testing.T) {
t.Parallel()
// Given: An active task whose app is in "working" state.
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
// Move the app into "working" state before running the command.
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
require.NoError(t, agentClient.PatchAppStatus(setupCtx, agentsdk.PatchAppStatus{
AppSlug: "task-sidebar",
State: codersdk.WorkspaceAppStatusStateWorking,
Message: "busy",
}))
// When: We send input while the app is working.
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
clitest.SetupConfig(t, setup.userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
inv = inv.WithContext(ctx)
w := clitest.StartWithWaiter(t, inv)
// Transition the app back to idle so waitForTaskIdle proceeds.
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "task-sidebar",
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "ready",
}))
// Then: The command should complete successfully.
require.NoError(t, w.Wait())
})
t.Run("SendToNonIdleAppState", func(t *testing.T) {
t.Parallel()
for _, appState := range []codersdk.WorkspaceAppStatusState{
codersdk.WorkspaceAppStatusStateComplete,
codersdk.WorkspaceAppStatusStateFailure,
} {
t.Run(string(appState), func(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some input", "some response"))
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
require.NoError(t, agentClient.PatchAppStatus(setupCtx, agentsdk.PatchAppStatus{
AppSlug: "task-sidebar",
State: appState,
Message: "done",
}))
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some input")
clitest.SetupConfig(t, setup.userClient, root)
ctx := testutil.Context(t, testutil.WaitLong)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
})
}
})
}
func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) map[string]http.HandlerFunc {
@@ -362,7 +151,7 @@ func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) m
}
}
func fakeAgentAPITaskSendErr(returnErr error) map[string]http.HandlerFunc {
func fakeAgentAPITaskSendErr(t *testing.T, returnErr error) map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/status": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
+5 -56
View File
@@ -88,13 +88,6 @@ func Test_Tasks(t *testing.T) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, userClient, tasks[0].WorkspaceID.UUID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
// Report the task app as idle so that waitForTaskIdle
// can proceed during the "send task message" step.
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "task-sidebar",
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "ready",
}))
},
},
{
@@ -279,19 +272,10 @@ func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Mes
// setupCLITaskTest creates a test workspace with an AI task template and agent,
// with a fake agent API configured with the provided set of handlers.
// Returns the user client and workspace.
// setupCLITaskTestResult holds the return values from setupCLITaskTest.
type setupCLITaskTestResult struct {
ownerClient *codersdk.Client
userClient *codersdk.Client
task codersdk.Task
agentToken string
agent agent.Agent
}
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) setupCLITaskTestResult {
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (ownerClient *codersdk.Client, memberClient *codersdk.Client, task codersdk.Task) {
t.Helper()
ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
ownerClient = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, ownerClient)
userClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)
@@ -308,56 +292,21 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
})
require.NoError(t, err)
// Wait for the task's underlying workspace to be built.
// Wait for the task's underlying workspace to be built
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, workspace.LatestBuild.ID)
agentClient := agentsdk.New(userClient.URL, agentsdk.WithFixedToken(authToken))
agt := agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
_ = agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, userClient, workspace.ID).
WaitFor(coderdtest.AgentsReady)
// Report the task app as idle so that waitForTaskIdle can proceed.
err = agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
AppSlug: "task-sidebar",
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "ready",
})
require.NoError(t, err)
return setupCLITaskTestResult{
ownerClient: ownerClient,
userClient: userClient,
task: task,
agentToken: authToken,
agent: agt,
}
}
// pauseTask pauses the task and waits for the stop build to complete.
func pauseTask(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
t.Helper()
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, pauseResp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
}
// resumeTask resumes the task waits for the start build to complete. The task
// will be in "initializing" state after this returns because no agent is connected.
func resumeTask(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
t.Helper()
resumeResp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, resumeResp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, resumeResp.WorkspaceBuild.ID)
return ownerClient, userClient, task
}
// setupCLITaskTestWithSnapshot creates a task in the specified status with a log snapshot.
+2 -5
View File
@@ -5,14 +5,11 @@ USAGE:
Send input to a task
Send input to a task. If the task is paused, it will be automatically resumed
before input is sent. If the task is initializing, it will wait for the task
to become ready.
- Send direct input to a task:
- Send direct input to a task.:
$ coder task send task1 "Please also add unit tests"
- Send input from stdin to a task:
- Send input from stdin to a task.:
$ echo "Please also add unit tests" | coder task send task1 --stdin
+2 -13
View File
@@ -16,7 +16,6 @@ import (
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
@@ -185,22 +184,12 @@ func TestTokens(t *testing.T) {
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
// Precondition: validate token is not expired before expiring
var expiredAtBefore time.Time
token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two")
require.NoError(t, err)
now := dbtime.Now()
require.True(t, token.ExpiresAt.After(now), "token should not be expired yet (expiresAt=%s, now=%s)", token.ExpiresAt.UTC(), now)
expiredAtBefore = token.ExpiresAt
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
// Validate that token was expired
if token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two"); assert.NoError(t, err) {
now := dbtime.Now()
require.NotEqual(t, token.ExpiresAt, expiredAtBefore, "token expiresAt is the same as before expiring, but should have been updated")
require.False(t, token.ExpiresAt.After(now), "token expiresAt should not be in the future after expiring, but was %s (now=%s)", token.ExpiresAt.UTC(), now)
now := time.Now()
require.False(t, token.ExpiresAt.After(now), "token expiresAt should not be in the future, but was %s (now=%s)", token.ExpiresAt, now)
}
// Delete by ID (explicit delete flag)
@@ -387,9 +387,9 @@ func (b *Batcher) flush(ctx context.Context, reason string) {
b.Metrics.BatchSize.Observe(float64(count))
b.Metrics.MetadataTotal.Add(float64(count))
b.Metrics.BatchesTotal.WithLabelValues(reason).Inc()
elapsed = b.clock.Since(start)
b.Metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
b.Metrics.FlushDuration.WithLabelValues(reason).Observe(time.Since(start).Seconds())
elapsed = time.Since(start)
b.log.Debug(ctx, "flush complete",
slog.F("count", count),
slog.F("elapsed", elapsed),
+1 -13
View File
@@ -315,18 +315,6 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
}
}
// appStatusStateToTaskState converts a WorkspaceAppStatusState to a
// TaskState. The two enums mostly share values but "failure" in the
// app status maps to "failed" in the public task API.
func appStatusStateToTaskState(s codersdk.WorkspaceAppStatusState) codersdk.TaskState {
switch s {
case codersdk.WorkspaceAppStatusStateFailure:
return codersdk.TaskStateFailed
default:
return codersdk.TaskState(s)
}
}
// deriveTaskCurrentState determines the current state of a task based on the
// workspace's latest app status and initialization phase.
// Returns nil if no valid state can be determined.
@@ -346,7 +334,7 @@ func deriveTaskCurrentState(
if ws.LatestBuild.Transition != codersdk.WorkspaceTransitionStart || ws.LatestAppStatus.CreatedAt.After(ws.LatestBuild.CreatedAt) {
currentState = &codersdk.TaskStateEntry{
Timestamp: ws.LatestAppStatus.CreatedAt,
State: appStatusStateToTaskState(ws.LatestAppStatus.State),
State: codersdk.TaskState(ws.LatestAppStatus.State),
Message: ws.LatestAppStatus.Message,
URI: ws.LatestAppStatus.URI,
}
-160
View File
@@ -481,128 +481,6 @@ const docTemplate = `{
}
}
},
"/chats/files": {
"post": {
"security": [
{
"CoderSessionToken": []
}
],
"consumes": [
"application/octet-stream"
],
"produces": [
"application/json"
],
"tags": [
"Chats"
],
"summary": "Upload a chat file",
"operationId": "upload-chat-file",
"parameters": [
{
"type": "string",
"description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)",
"name": "Content-Type",
"in": "header",
"required": true
},
{
"type": "string",
"format": "uuid",
"description": "Organization ID",
"name": "organization",
"in": "query",
"required": true
}
],
"responses": {
"201": {
"description": "Created",
"schema": {
"$ref": "#/definitions/codersdk.UploadChatFileResponse"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"401": {
"description": "Unauthorized",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"413": {
"description": "Request Entity Too Large",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
}
}
}
},
"/chats/files/{file}": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": [
"Chats"
],
"summary": "Get a chat file",
"operationId": "get-chat-file",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "File ID",
"name": "file",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK"
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"401": {
"description": "Unauthorized",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
}
}
}
},
"/chats/{chat}/archive": {
"post": {
"tags": [
@@ -617,35 +495,6 @@ const docTemplate = `{
}
}
},
"/chats/{chat}/git/watch": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": [
"Chats"
],
"summary": "Watch git changes for a chat.",
"operationId": "watch-chat-git",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Chat ID",
"name": "chat",
"in": "path",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols"
}
}
}
},
"/chats/{chat}/unarchive": {
"post": {
"tags": [
@@ -20456,15 +20305,6 @@ const docTemplate = `{
}
}
},
"codersdk.UploadChatFileResponse": {
"type": "object",
"properties": {
"id": {
"type": "string",
"format": "uuid"
}
}
},
"codersdk.UploadResponse": {
"type": "object",
"properties": {
-150
View File
@@ -410,120 +410,6 @@
}
}
},
"/chats/files": {
"post": {
"security": [
{
"CoderSessionToken": []
}
],
"consumes": ["application/octet-stream"],
"produces": ["application/json"],
"tags": ["Chats"],
"summary": "Upload a chat file",
"operationId": "upload-chat-file",
"parameters": [
{
"type": "string",
"description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)",
"name": "Content-Type",
"in": "header",
"required": true
},
{
"type": "string",
"format": "uuid",
"description": "Organization ID",
"name": "organization",
"in": "query",
"required": true
}
],
"responses": {
"201": {
"description": "Created",
"schema": {
"$ref": "#/definitions/codersdk.UploadChatFileResponse"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"401": {
"description": "Unauthorized",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"413": {
"description": "Request Entity Too Large",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
}
}
}
},
"/chats/files/{file}": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": ["Chats"],
"summary": "Get a chat file",
"operationId": "get-chat-file",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "File ID",
"name": "file",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK"
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"401": {
"description": "Unauthorized",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
}
}
}
},
"/chats/{chat}/archive": {
"post": {
"tags": ["Chats"],
@@ -536,33 +422,6 @@
}
}
},
"/chats/{chat}/git/watch": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": ["Chats"],
"summary": "Watch git changes for a chat.",
"operationId": "watch-chat-git",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Chat ID",
"name": "chat",
"in": "path",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols"
}
}
}
},
"/chats/{chat}/unarchive": {
"post": {
"tags": ["Chats"],
@@ -18764,15 +18623,6 @@
}
}
},
"codersdk.UploadChatFileResponse": {
"type": "object",
"properties": {
"id": {
"type": "string",
"format": "uuid"
}
}
},
"codersdk.UploadResponse": {
"type": "object",
"properties": {
+11 -11
View File
@@ -48,8 +48,8 @@ func TestTokenCRUD(t *testing.T) {
require.EqualValues(t, len(keys), 1)
require.Contains(t, res.Key, keys[0].ID)
// expires_at should default to 30 days
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*6))
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*8))
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*6))
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*8))
require.Equal(t, codersdk.APIKeyScopeAll, keys[0].Scope)
require.Len(t, keys[0].AllowList, 1)
require.Equal(t, "*:*", keys[0].AllowList[0].String())
@@ -194,8 +194,8 @@ func TestUserSetTokenDuration(t *testing.T) {
require.NoError(t, err)
keys, err := client.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
require.NoError(t, err)
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*6*24))
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*8*24))
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*6*24))
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*8*24))
}
func TestDefaultTokenDuration(t *testing.T) {
@@ -210,8 +210,8 @@ func TestDefaultTokenDuration(t *testing.T) {
require.NoError(t, err)
keys, err := client.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
require.NoError(t, err)
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*6))
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*8))
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*6))
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*8))
}
func TestTokenUserSetMaxLifetime(t *testing.T) {
@@ -518,7 +518,7 @@ func TestExpireAPIKey(t *testing.T) {
// Verify the token is not expired.
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.True(t, key.ExpiresAt.After(dbtime.Now()))
require.True(t, key.ExpiresAt.After(time.Now()))
auditor.ResetLogs()
@@ -529,7 +529,7 @@ func TestExpireAPIKey(t *testing.T) {
// Verify the token is expired.
key, err = adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
require.True(t, key.ExpiresAt.Before(time.Now()))
// Verify audit log.
als := auditor.AuditLogs()
@@ -556,7 +556,7 @@ func TestExpireAPIKey(t *testing.T) {
// Verify the token is expired.
key, err := memberClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
require.True(t, key.ExpiresAt.Before(time.Now()))
})
t.Run("MemberCannotExpireOtherUsersToken", func(t *testing.T) {
@@ -607,7 +607,7 @@ func TestExpireAPIKey(t *testing.T) {
// Invariant: make sure it's actually expired
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.LessOrEqual(t, key.ExpiresAt, dbtime.Now(), "key should be expired")
require.LessOrEqual(t, key.ExpiresAt, time.Now(), "key should be expired")
// Expire it again - should succeed (idempotent).
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
@@ -636,7 +636,7 @@ func TestExpireAPIKey(t *testing.T) {
// Verify it's expired.
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
require.True(t, key.ExpiresAt.Before(time.Now()))
// Delete the expired token - should succeed.
err = adminClient.DeleteAPIKey(ctx, codersdk.Me, keyID)
+1 -4
View File
@@ -599,11 +599,8 @@ func TestExecutorAutostopAIAgentActivity(t *testing.T) {
require.NoError(t, err)
// When: the autobuild executor ticks after the bumped deadline.
// Use time.Now() to account for elapsed time since the test's
// "now" variable, because the activity bump uses the database
// NOW() which advances with wall clock time.
go func() {
tickTime := time.Now().Add(time.Hour).Add(time.Minute)
tickTime := now.Add(time.Hour).Add(time.Minute)
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
tickCh <- tickTime
close(tickCh)
+253 -532
View File
File diff suppressed because it is too large Load Diff
+21 -106
View File
@@ -73,11 +73,10 @@ func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
return event.Status.Status == codersdk.ChatStatusWaiting
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
return false
}
t.Logf("skipping unexpected event: type=%s", event.Type)
return false
return event.Status.Status == codersdk.ChatStatusWaiting
default:
return false
}
@@ -367,7 +366,7 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
require.Len(t, messages, 1)
}
func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
@@ -399,31 +398,26 @@ func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
})
require.NoError(t, err)
// The message should be queued, not inserted directly.
require.True(t, result.Queued)
require.NotNil(t, result.QueuedMessage)
// The chat should transition to waiting (interrupt signal),
// not pending.
require.Equal(t, database.ChatStatusWaiting, result.Chat.Status)
require.False(t, result.Queued)
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
require.False(t, result.Chat.WorkerID.Valid)
fromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
require.Equal(t, database.ChatStatusPending, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
// The message should be in the queue, not in chat_messages.
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, queued, 1)
require.Len(t, queued, 0)
// Only the initial user message should be in chat_messages.
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 1)
require.Len(t, messages, 2)
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
}
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
@@ -871,15 +865,15 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
// events — the snapshot already contained everything. Before
// the fix, localSnapshot was replayed into the channel,
// causing duplicates.
require.Never(t, func() bool {
select {
case <-events:
return true
default:
return false
select {
case event, ok := <-events:
if ok {
t.Fatalf("unexpected event from channel (would be a duplicate): type=%s", event.Type)
}
}, 200*time.Millisecond, testutil.IntervalFast,
"expected no duplicate events after snapshot")
// Channel closed without events is fine.
case <-time.After(200 * time.Millisecond):
// No events — correct behavior.
}
}
func TestSubscribeAfterMessageID(t *testing.T) {
@@ -1468,17 +1462,10 @@ func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
// mockWebpushDispatcher implements webpush.Dispatcher and records Dispatch calls.
type mockWebpushDispatcher struct {
dispatchCount atomic.Int32
mu sync.Mutex
lastMessage codersdk.WebpushMessage
lastUserID uuid.UUID
}
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error {
m.dispatchCount.Add(1)
m.mu.Lock()
m.lastMessage = msg
m.lastUserID = userID
m.mu.Unlock()
return nil
}
@@ -1490,78 +1477,6 @@ func (*mockWebpushDispatcher) PublicKey() string {
return "test-vapid-public-key"
}
func TestSuccessfulChatSendsWebPushWithNavigationData(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// Set up a mock OpenAI that returns a simple successful response.
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("title")
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("done")...,
)
})
// Mock webpush dispatcher that captures the dispatched message.
mockPush := &mockWebpushDispatcher{}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: uuid.New(),
Pubsub: ps,
PendingChatAcquireInterval: 10 * time.Millisecond,
InFlightChatStaleAfter: testutil.WaitSuperLong,
WebpushDispatcher: mockPush,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
user, model := seedChatDependencies(ctx, t, db)
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "push-nav-test",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
// Wait for the chat to complete and return to waiting status.
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid && mockPush.dispatchCount.Load() == 1
}, testutil.IntervalFast)
// Verify a web push notification was dispatched exactly once.
require.Equal(t, int32(1), mockPush.dispatchCount.Load(),
"expected exactly one web push dispatch for a completed chat")
// Verify the notification was sent to the correct user.
mockPush.mu.Lock()
capturedMsg := mockPush.lastMessage
capturedUserID := mockPush.lastUserID
mockPush.mu.Unlock()
require.Equal(t, user.ID, capturedUserID,
"web push should be dispatched to the chat owner")
// Verify the Data field contains the correct navigation URL.
expectedURL := fmt.Sprintf("/agents/%s", chat.ID)
require.Equal(t, expectedURL, capturedMsg.Data["url"],
"web push Data should contain the chat navigation URL")
}
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
t.Parallel()
+11 -16
View File
@@ -73,9 +73,7 @@ type RunOptions struct {
// OnRetry is called before each retry attempt when the LLM
// stream fails with a retryable error. It provides the attempt
// number, error, and backoff delay so callers can publish status
// events to connected clients. Callers should also clear any
// buffered stream state from the failed attempt in this callback
// to avoid sending duplicated content.
// events to connected clients.
OnRetry chatretry.OnRetryFn
OnInterruptedPersistError func(error)
@@ -211,10 +209,6 @@ func Run(ctx context.Context, opts RunOptions) error {
var lastUsage fantasy.Usage
var lastProviderMetadata fantasy.ProviderMetadata
totalSteps := 0
// When totalSteps reaches MaxSteps the inner loop exits immediately
// (its condition is false), stoppedByModel stays false, and the
// post-loop guard breaks the outer compaction loop.
for compactionAttempt := 0; ; compactionAttempt++ {
alreadyCompacted := false
// stoppedByModel is true when the inner step loop
@@ -228,8 +222,7 @@ func Run(ctx context.Context, opts RunOptions) error {
// agent never had a chance to use the compacted context.
compactedOnFinalStep := false
for step := 0; totalSteps < opts.MaxSteps; step++ {
totalSteps++
for step := 0; step < opts.MaxSteps; step++ {
// Copy messages so that provider-specific caching
// mutations don't leak back to the caller's slice.
// copy copies Message structs by value, so field
@@ -328,12 +321,6 @@ func Run(ctx context.Context, opts RunOptions) error {
lastUsage = result.usage
lastProviderMetadata = result.providerMetadata
// Append the step's response messages so that both
// inline and post-loop compaction see the full
// conversation including the latest assistant reply.
stepMessages := result.toResponseMessages()
messages = append(messages, stepMessages...)
// Inline compaction.
if opts.Compaction != nil && opts.ReloadMessages != nil {
did, compactErr := tryCompact(
@@ -367,11 +354,17 @@ func Run(ctx context.Context, opts RunOptions) error {
// The agent is continuing with tool calls, so any
// prior compaction has already been consumed.
compactedOnFinalStep = false
// Build messages from the step for the next iteration.
// toResponseMessages produces assistant-role content
// (text, reasoning, tool calls) and tool-result content.
stepMessages := result.toResponseMessages()
messages = append(messages, stepMessages...)
}
// Post-run compaction safety net: if we never compacted
// during the loop, try once at the end.
if !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil {
if !alreadyCompacted && opts.Compaction != nil {
did, err := tryCompact(
ctx,
opts.Model,
@@ -390,6 +383,7 @@ func Run(ctx context.Context, opts RunOptions) error {
compactedOnFinalStep = true
}
}
// Re-enter the step loop when compaction fired on the
// model's final step. This lets the agent continue
// working with fresh summarized context instead of
@@ -520,6 +514,7 @@ func processStepStream(
})
}
}
case fantasy.StreamPartTypeToolInputStart:
activeToolCalls[part.ID] = &fantasy.ToolCallContent{
ToolCallID: part.ID,
+1 -2
View File
@@ -123,8 +123,7 @@ func tryCompact(
config.SystemSummaryPrefix + "\n\n" + summary,
)
persistCtx := context.WithoutCancel(ctx)
err = config.Persist(persistCtx, CompactionResult{
err = config.Persist(ctx, CompactionResult{
SystemSummary: systemSummary,
SummaryReport: summary,
ThresholdPercent: config.ThresholdPercent,
+1 -142
View File
@@ -76,20 +76,9 @@ func TestRun_Compaction(t *testing.T) {
return nil
},
},
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
return []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
}, nil
},
})
require.NoError(t, err)
// Compaction fires twice: once inline when the threshold is
// reached on step 0 (the only step, since MaxSteps=1), and
// once from the post-run safety net during the re-entry
// iteration (where totalSteps already equals MaxSteps so the
// inner loop doesn't execute, but lastUsage still exceeds
// the threshold).
require.Equal(t, 2, persistCompactionCalls)
require.Equal(t, 1, persistCompactionCalls)
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
@@ -162,25 +151,13 @@ func TestRun_Compaction(t *testing.T) {
return nil
},
},
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
return []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
}, nil
},
})
require.NoError(t, err)
// Compaction fires twice (see PersistsWhenThresholdReached
// for the full explanation). Each cycle follows the order:
// publish_tool_call → generate → persist → publish_tool_result.
require.Equal(t, []string{
"publish_tool_call",
"generate",
"persist",
"publish_tool_result",
"publish_tool_call",
"generate",
"persist",
"publish_tool_result",
}, callOrder)
})
@@ -480,11 +457,6 @@ func TestRun_Compaction(t *testing.T) {
compactionErr = err
},
},
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
return []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
}, nil
},
})
require.NoError(t, err)
require.Error(t, compactionErr)
@@ -600,117 +572,4 @@ func TestRun_Compaction(t *testing.T) {
// Two stream calls: one before compaction, one after re-entry.
require.Equal(t, 2, streamCallCount)
})
t.Run("PostRunCompactionReEntryIncludesUserSummary", func(t *testing.T) {
t.Parallel()
// After compaction the summary is stored as a user-role
// message. When the loop re-enters, the reloaded prompt
// must contain this user message so the LLM provider
// receives a valid prompt (providers like Anthropic
// require at least one non-system message).
var mu sync.Mutex
var streamCallCount int
var reEntryPrompt []fantasy.Message
persistCompactionCalls := 0
const summaryText = "post-run compacted summary"
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCallCount
streamCallCount++
mu.Unlock()
switch step {
case 0:
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "initial response"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
default:
mu.Lock()
reEntryPrompt = append([]fantasy.Message(nil), call.Prompt...)
mu.Unlock()
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-2"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-2", Delta: "continued"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-2"},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 20,
TotalTokens: 25,
},
},
}), nil
}
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
// Simulate real post-compaction DB state: the summary is
// a user-role message (the only non-system content).
compactedMessages := []fantasy.Message{
textMessage(fantasy.MessageRoleSystem, "system prompt"),
textMessage(fantasy.MessageRoleUser, "Summary of earlier chat context:\n\ncompacted summary"),
}
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 5,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
Persist: func(_ context.Context, _ CompactionResult) error {
persistCompactionCalls++
return nil
},
},
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
return compactedMessages, nil
},
})
require.NoError(t, err)
require.GreaterOrEqual(t, persistCompactionCalls, 1)
// Re-entry happened: stream was called at least twice.
require.Equal(t, 2, streamCallCount)
// The re-entry prompt must contain the user summary.
require.NotEmpty(t, reEntryPrompt)
hasUser := false
for _, msg := range reEntryPrompt {
if msg.Role == fantasy.MessageRoleUser {
hasUser = true
break
}
}
require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)")
})
}
+2 -185
View File
@@ -1,14 +1,12 @@
package chatprompt
import (
"context"
"encoding/json"
"regexp"
"strings"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
@@ -18,156 +16,12 @@ import (
var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
// FileData holds resolved file content for LLM prompt building.
type FileData struct {
Data []byte
MediaType string
}
// FileResolver fetches file content by ID for LLM prompt building.
type FileResolver func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]FileData, error)
// ExtractFileID parses the file_id from a serialized file content
// block envelope. Returns uuid.Nil and an error when the block is
// not a file-type block or has no file_id.
func ExtractFileID(raw json.RawMessage) (uuid.UUID, error) {
var envelope struct {
Type string `json:"type"`
Data struct {
FileID string `json:"file_id"`
} `json:"data"`
}
if err := json.Unmarshal(raw, &envelope); err != nil {
return uuid.Nil, xerrors.Errorf("unmarshal content block: %w", err)
}
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeFile)) {
return uuid.Nil, xerrors.Errorf("not a file content block: %s", envelope.Type)
}
if envelope.Data.FileID == "" {
return uuid.Nil, xerrors.New("no file_id")
}
return uuid.Parse(envelope.Data.FileID)
}
// extractFileIDs scans raw message content for file_id references.
// Returns a map of block index to file ID. Returns nil for
// non-array content or content with no file references.
func extractFileIDs(raw pqtype.NullRawMessage) map[int]uuid.UUID {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil
}
var rawBlocks []json.RawMessage
if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
return nil
}
var result map[int]uuid.UUID
for i, block := range rawBlocks {
fid, err := ExtractFileID(block)
if err == nil {
if result == nil {
result = make(map[int]uuid.UUID)
}
result[i] = fid
}
}
return result
}
// patchFileContent fills in empty Data on FileContent blocks from
// resolved file data. Blocks that already have inline data (backward
// compat) or have no resolved data are left unchanged.
func patchFileContent(
content []fantasy.Content,
fileIDs map[int]uuid.UUID,
resolved map[uuid.UUID]FileData,
) {
for blockIdx, fid := range fileIDs {
if blockIdx >= len(content) {
continue
}
switch fc := content[blockIdx].(type) {
case fantasy.FileContent:
if len(fc.Data) > 0 {
continue
}
if data, found := resolved[fid]; found {
fc.Data = data.Data
content[blockIdx] = fc
}
case *fantasy.FileContent:
if len(fc.Data) > 0 {
continue
}
if data, found := resolved[fid]; found {
fc.Data = data.Data
}
}
}
}
// ConvertMessages converts persisted chat messages into LLM prompt
// messages without resolving file references from storage. Inline
// file data is preserved when present (backward compat).
func ConvertMessages(
messages []database.ChatMessage,
) ([]fantasy.Message, error) {
return ConvertMessagesWithFiles(context.Background(), messages, nil)
}
// ConvertMessagesWithFiles converts persisted chat messages into LLM
// prompt messages, resolving file references via the provided
// resolver. When resolver is nil, file blocks without inline data
// are passed through as-is (same behavior as ConvertMessages).
func ConvertMessagesWithFiles(
ctx context.Context,
messages []database.ChatMessage,
resolver FileResolver,
) ([]fantasy.Message, error) {
// Phase 1: Pre-scan user messages for file_id references.
var allFileIDs []uuid.UUID
seenFileIDs := make(map[uuid.UUID]struct{})
fileIDsByMsg := make(map[int]map[int]uuid.UUID)
if resolver != nil {
for i, msg := range messages {
visibility := msg.Visibility
if visibility == "" {
visibility = database.ChatMessageVisibilityBoth
}
if visibility != database.ChatMessageVisibilityModel &&
visibility != database.ChatMessageVisibilityBoth {
continue
}
if msg.Role != string(fantasy.MessageRoleUser) {
continue
}
fids := extractFileIDs(msg.Content)
if len(fids) > 0 {
fileIDsByMsg[i] = fids
for _, fid := range fids {
if _, seen := seenFileIDs[fid]; !seen {
seenFileIDs[fid] = struct{}{}
allFileIDs = append(allFileIDs, fid)
}
}
}
}
}
// Phase 2: Batch resolve file data.
var resolved map[uuid.UUID]FileData
if len(allFileIDs) > 0 {
var err error
resolved, err = resolver(ctx, allFileIDs)
if err != nil {
return nil, xerrors.Errorf("resolve chat files: %w", err)
}
}
// Phase 3: Convert messages, patching file content as needed.
prompt := make([]fantasy.Message, 0, len(messages))
toolNameByCallID := make(map[string]string)
for i, message := range messages {
for _, message := range messages {
visibility := message.Visibility
if visibility == "" {
visibility = database.ChatMessageVisibilityBoth
@@ -197,9 +51,6 @@ func ConvertMessagesWithFiles(
if err != nil {
return nil, err
}
if fids, ok := fileIDsByMsg[i]; ok {
patchFileContent(content, fids, resolved)
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: ToMessageParts(content),
@@ -549,10 +400,7 @@ func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
}
// MarshalContent encodes message content blocks for persistence.
// fileIDs optionally maps block indices to chat_files IDs, which
// are injected into the JSON envelope for file-type blocks so
// the reference survives round-trips through storage.
func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype.NullRawMessage, error) {
func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
if len(blocks) == 0 {
return pqtype.NullRawMessage{}, nil
}
@@ -567,16 +415,6 @@ func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype
err,
)
}
if fid, ok := fileIDs[i]; ok {
encoded, err = injectFileID(encoded, fid)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf(
"inject file_id into content block %d: %w",
i,
err,
)
}
}
encodedBlocks = append(encodedBlocks, encoded)
}
@@ -587,27 +425,6 @@ func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
}
// injectFileID adds a file_id field into the data sub-object of a
// serialized content block envelope. This follows the same pattern
// as the reasoning title injection in marshalContentBlock.
func injectFileID(encoded json.RawMessage, fileID uuid.UUID) (json.RawMessage, error) {
var envelope struct {
Type string `json:"type"`
Data struct {
MediaType string `json:"media_type"`
Data json.RawMessage `json:"data,omitempty"`
FileID string `json:"file_id,omitempty"`
ProviderMetadata *json.RawMessage `json:"provider_metadata,omitempty"`
} `json:"data"`
}
if err := json.Unmarshal(encoded, &envelope); err != nil {
return encoded, err
}
envelope.Data.FileID = fileID.String()
envelope.Data.Data = nil // Strip inline data; resolved at LLM dispatch time.
return json.Marshal(envelope)
}
// MarshalToolResult encodes a single tool result for persistence as
// an opaque JSON blob. The stored shape is
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
+1 -140
View File
@@ -1,13 +1,10 @@
package chatprompt_test
import (
"context"
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
@@ -55,7 +52,7 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
ToolName: "execute",
Input: tc.input,
},
}, nil)
})
require.NoError(t, err)
toolContent, err := chatprompt.MarshalToolResult(
@@ -92,139 +89,3 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
})
}
}
func TestConvertMessagesWithFiles_ResolvesFileData(t *testing.T) {
t.Parallel()
fileID := uuid.New()
fileData := []byte("fake-image-bytes")
// Build a user message with file_id but no inline data, as
// would be stored after injectFileID strips the data.
rawContent := mustJSON(t, []json.RawMessage{
mustJSON(t, map[string]any{
"type": "file",
"data": map[string]any{
"media_type": "image/png",
"file_id": fileID.String(),
},
}),
})
resolver := func(_ context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) {
result := make(map[uuid.UUID]chatprompt.FileData)
for _, id := range ids {
if id == fileID {
result[id] = chatprompt.FileData{
Data: fileData,
MediaType: "image/png",
}
}
}
return result, nil
}
prompt, err := chatprompt.ConvertMessagesWithFiles(
context.Background(),
[]database.ChatMessage{
{
Role: string(fantasy.MessageRoleUser),
Visibility: database.ChatMessageVisibilityBoth,
Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true},
},
},
resolver,
)
require.NoError(t, err)
require.Len(t, prompt, 1)
require.Equal(t, fantasy.MessageRoleUser, prompt[0].Role)
require.Len(t, prompt[0].Content, 1)
filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0])
require.True(t, ok, "expected FilePart")
require.Equal(t, fileData, filePart.Data)
require.Equal(t, "image/png", filePart.MediaType)
}
func TestConvertMessagesWithFiles_BackwardCompat(t *testing.T) {
t.Parallel()
// A message with inline data and a file_id should use the
// inline data even when the resolver returns nothing.
fileID := uuid.New()
inlineData := []byte("inline-image-data")
rawContent := mustJSON(t, []json.RawMessage{
mustJSON(t, map[string]any{
"type": "file",
"data": map[string]any{
"media_type": "image/png",
"data": inlineData,
"file_id": fileID.String(),
},
}),
})
prompt, err := chatprompt.ConvertMessagesWithFiles(
context.Background(),
[]database.ChatMessage{
{
Role: string(fantasy.MessageRoleUser),
Visibility: database.ChatMessageVisibilityBoth,
Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true},
},
},
nil, // No resolver.
)
require.NoError(t, err)
require.Len(t, prompt, 1)
require.Len(t, prompt[0].Content, 1)
filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0])
require.True(t, ok, "expected FilePart")
require.Equal(t, inlineData, filePart.Data)
}
func TestInjectFileID_StripsInlineData(t *testing.T) {
t.Parallel()
fileID := uuid.New()
imageData := []byte("raw-image-bytes")
// Marshal a file content block with inline data, then inject
// a file_id. The result should have file_id but no data.
content, err := chatprompt.MarshalContent([]fantasy.Content{
fantasy.FileContent{
MediaType: "image/png",
Data: imageData,
},
}, map[int]uuid.UUID{0: fileID})
require.NoError(t, err)
// Parse the stored content to verify shape.
var blocks []json.RawMessage
require.NoError(t, json.Unmarshal(content.RawMessage, &blocks))
require.Len(t, blocks, 1)
var envelope struct {
Type string `json:"type"`
Data struct {
MediaType string `json:"media_type"`
Data *json.RawMessage `json:"data,omitempty"`
FileID string `json:"file_id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(blocks[0], &envelope))
require.Equal(t, "file", envelope.Type)
require.Equal(t, "image/png", envelope.Data.MediaType)
require.Equal(t, fileID.String(), envelope.Data.FileID)
// Data should be nil (omitted) since injectFileID strips it.
require.Nil(t, envelope.Data.Data, "inline data should be stripped")
}
func mustJSON(t *testing.T, v any) json.RawMessage {
t.Helper()
data, err := json.Marshal(v)
require.NoError(t, err)
return data
}
+7 -17
View File
@@ -8,8 +8,6 @@ import (
"errors"
"strings"
"time"
"golang.org/x/xerrors"
)
const (
@@ -20,12 +18,6 @@ const (
// MaxDelay is the upper bound for the exponential backoff
// duration. Matches the cap used in coder/mux.
MaxDelay = 60 * time.Second
// MaxAttempts is the upper bound on retry attempts before
// giving up. With a 60s max backoff this allows roughly
// 25 minutes of retries, which is reasonable for transient
// LLM provider issues.
MaxAttempts = 25
)
// nonRetryablePatterns are substrings that indicate a permanent error
@@ -139,8 +131,9 @@ type RetryFn func(ctx context.Context) error
type OnRetryFn func(attempt int, err error, delay time.Duration)
// Retry calls fn repeatedly until it succeeds, returns a
// non-retryable error, ctx is canceled, or MaxAttempts is reached.
// Retries use exponential backoff capped at MaxDelay.
// non-retryable error, or ctx is canceled. There is no max attempt
// limit — retries continue indefinitely with exponential backoff
// (capped at 60s), matching the behavior of coder/mux.
//
// The onRetry callback (if non-nil) is called before each retry
// attempt, giving the caller a chance to reset state, log, or
@@ -163,15 +156,10 @@ func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
return ctx.Err()
}
attempt++
if attempt >= MaxAttempts {
return xerrors.Errorf("max retry attempts (%d) exceeded: %w", MaxAttempts, err)
}
delay := Delay(attempt - 1)
delay := Delay(attempt)
if onRetry != nil {
onRetry(attempt, err, delay)
onRetry(attempt+1, err, delay)
}
timer := time.NewTimer(delay)
@@ -181,5 +169,7 @@ func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
return ctx.Err()
case <-timer.C:
}
attempt++
}
}
+1 -42
View File
@@ -11,7 +11,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/openai/openai-go/v3/responses"
)
// OpenAIHandler handles OpenAI API requests and returns a response.
@@ -307,17 +306,6 @@ func writeChatCompletionsStreaming(w http.ResponseWriter, r *http.Request, chunk
}
}
// writeSSEEvent marshals v as JSON and writes it as an SSE data
// frame. Returns any write error.
func writeSSEEvent(w http.ResponseWriter, v interface{}) error {
data, err := json.Marshal(v)
if err != nil {
return err
}
_, err = fmt.Fprintf(w, "data: %s\n\n", data)
return err
}
func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
@@ -341,23 +329,7 @@ func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <
return
case chunk, ok = <-chunks:
if !ok {
// Emit Responses API lifecycle events so
// the fantasy client closes open text
// blocks and persists the step content.
for outputIndex, itemID := range itemIDs {
_ = writeSSEEvent(w, responses.ResponseTextDoneEvent{
ItemID: itemID,
OutputIndex: int64(outputIndex),
})
_ = writeSSEEvent(w, responses.ResponseOutputItemDoneEvent{
OutputIndex: int64(outputIndex),
Item: responses.ResponseOutputItemUnion{
ID: itemID,
Type: "message",
},
})
}
_ = writeSSEEvent(w, responses.ResponseCompletedEvent{})
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
return
}
@@ -372,19 +344,6 @@ func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <
if !found {
itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8])
itemIDs[outputIndex] = itemID
// Emit response.output_item.added so the
// fantasy client triggers TextStart.
if err := writeSSEEvent(w, responses.ResponseOutputItemAddedEvent{
OutputIndex: int64(outputIndex),
Item: responses.ResponseOutputItemUnion{
ID: itemID,
Type: "message",
},
}); err != nil {
return
}
flusher.Flush()
}
chunkData := map[string]interface{}{
+229 -502
View File
@@ -1,22 +1,17 @@
package coderd
import (
"bufio"
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strconv"
"strings"
"sync"
"time"
"charm.land/fantasy"
@@ -40,8 +35,6 @@ import (
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/websocket"
)
const (
@@ -251,7 +244,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
return
}
contentBlocks, contentFileIDs, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req)
contentBlocks, titleSource, inputError := createChatInputFromRequest(req)
if inputError != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError)
return
@@ -286,7 +279,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
ModelConfigID: modelConfigID,
SystemPrompt: defaultChatSystemPrompt(),
InitialUserContent: contentBlocks,
ContentFileIDs: contentFileIDs,
})
if err != nil {
if database.IsForeignKeyViolation(
@@ -421,162 +413,6 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
})
}
// @Summary Watch git changes for a chat.
// @ID watch-chat-git
// @Security CoderSessionToken
// @Tags Chats
// @Param chat path string true "Chat ID" format(uuid)
// @Success 101
// @Router /chats/{chat}/git/watch [get]
//
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) watchChatGit(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
chat = httpmw.ChatParam(r)
logger = api.Logger.Named("chat_git_watcher").With(slog.F("chat_id", chat.ID))
)
if !chat.WorkspaceID.Valid {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat has no workspace to watch.",
})
return
}
agents, err := api.Database.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, chat.WorkspaceID.UUID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace agents.",
Detail: err.Error(),
})
return
}
if len(agents) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat workspace has no agents.",
})
return
}
apiAgent, err := db2sdk.WorkspaceAgent(
api.DERPMap(),
*api.TailnetCoordinator.Load(),
agents[0],
nil,
nil,
nil,
api.AgentInactiveDisconnectTimeout,
api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
Detail: err.Error(),
})
return
}
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected),
})
return
}
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
defer dialCancel()
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agents[0].ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error dialing workspace agent.",
Detail: err.Error(),
})
return
}
defer release()
agentStream, err := agentConn.WatchGit(ctx, logger, chat.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error watching agent's git state.",
Detail: err.Error(),
})
return
}
defer agentStream.Close(websocket.StatusGoingAway)
clientConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionNoContextTakeover,
})
if err != nil {
logger.Error(ctx, "failed to accept websocket", slog.Error(err))
return
}
clientStream := wsjson.NewStream[
codersdk.WorkspaceAgentGitClientMessage,
codersdk.WorkspaceAgentGitServerMessage,
](clientConn, websocket.MessageText, websocket.MessageText, logger)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
go httpapi.HeartbeatClose(ctx, logger, cancel, clientConn)
// Proxy agent → client.
agentCh := agentStream.Chan()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-api.ctx.Done():
return
case <-ctx.Done():
return
case msg, ok := <-agentCh:
if !ok {
cancel()
return
}
if err := clientStream.Send(msg); err != nil {
logger.Debug(ctx, "failed to forward agent message to client", slog.Error(err))
cancel()
return
}
}
}
}()
// Proxy client → agent.
clientCh := clientStream.Chan()
proxyLoop:
for {
select {
case <-api.ctx.Done():
break proxyLoop
case <-ctx.Done():
break proxyLoop
case msg, ok := <-clientCh:
if !ok {
break proxyLoop
}
if err := agentStream.Send(msg); err != nil {
logger.Debug(ctx, "failed to forward client message to agent", slog.Error(err))
break proxyLoop
}
}
}
cancel()
wg.Wait()
_ = clientStream.Close(websocket.StatusGoingAway)
}
// @Summary Archive a chat
// @ID archive-chat
// @Tags Chats
@@ -593,15 +429,7 @@ func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
return
}
var err error
// Use chatDaemon when available so it can notify
// active subscribers. Fall back to direct DB for the
// simple archive flag — no streaming state is involved.
if api.chatDaemon != nil {
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
} else {
err = api.Database.ArchiveChatByID(ctx, chat.ID)
}
err := api.Database.ArchiveChatByID(ctx, chat.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to archive chat.",
@@ -629,15 +457,7 @@ func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
return
}
var err error
// Use chatDaemon when available so it can notify
// active subscribers. Fall back to direct DB for the
// simple unarchive flag — no streaming state is involved.
if api.chatDaemon != nil {
err = api.chatDaemon.UnarchiveChat(ctx, chat.ID)
} else {
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
}
err := api.Database.UnarchiveChatByID(ctx, chat.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to unarchive chat.",
@@ -668,7 +488,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
return
}
contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content")
if inputError != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: inputError.Message,
@@ -680,11 +500,10 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
sendResult, sendErr := api.chatDaemon.SendMessage(
ctx,
chatd.SendMessageOptions{
ChatID: chatID,
Content: contentBlocks,
ContentFileIDs: contentFileIDs,
ModelConfigID: req.ModelConfigID,
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
ChatID: chatID,
Content: contentBlocks,
ModelConfigID: req.ModelConfigID,
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
},
)
if sendErr != nil {
@@ -743,7 +562,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
return
}
contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content")
if inputError != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: inputError.Message,
@@ -756,7 +575,6 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
ChatID: chat.ID,
EditedMessageID: messageID,
Content: contentBlocks,
ContentFileIDs: contentFileIDs,
})
if editErr != nil {
switch {
@@ -871,6 +689,18 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
return
}
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to open chat stream.",
Detail: err.Error(),
})
return
}
defer func() {
<-senderClosed
}()
var afterMessageID int64
if v := r.URL.Query().Get("after_id"); v != "" {
var err error
@@ -884,31 +714,14 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
}
}
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to open chat stream.",
Detail: err.Error(),
})
return
}
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
if !ok {
_ = sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Chat streaming is not available.",
Detail: "Chat stream state is not configured.",
},
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Chat streaming is not available.",
Detail: "Chat stream state is not configured.",
})
// Ensure the WebSocket is closed so senderClosed
// completes and the handler can return.
<-senderClosed
return
}
defer func() {
<-senderClosed
}()
defer cancel()
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
@@ -1001,13 +814,9 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
if updateErr != nil {
api.Logger.Error(ctx, "failed to mark chat as waiting",
slog.F("chat_id", chatID), slog.Error(updateErr))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to interrupt chat.",
Detail: updateErr.Error(),
})
return
} else {
chat = updatedChat
}
chat = updatedChat
}
httpapi.Write(ctx, rw, http.StatusOK, convertChat(chat, nil))
@@ -1052,6 +861,198 @@ func (api *API) getChatDiffContents(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, diff)
}
// @Summary Get file content from a chat's linked GitHub repository
// @ID get-chat-diff-file-content
// @Security CoderSessionToken
// @Produce application/octet-stream
// @Tags Chats
// @Param chat path string true "Chat ID" format(uuid)
// @Param path query string true "Repo-relative file path"
// @Param ref query string true "Git ref (SHA or branch name)"
// @Success 200
// @Router /chats/{chat}/diff/file-content [get]
//
// getChatDiffFileContent proxies a single file's raw content from
// the chat's linked GitHub repository. The frontend uses this to
// render image diffs for binary files that cannot be shown as text.
// It resolves the repository owner/name from the chat's cached diff
// reference and streams the raw bytes back from the GitHub Contents
// API with a 10 MiB body limit.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) getChatDiffFileContent(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
filePath := r.URL.Query().Get("path")
if filePath == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing required query parameter: path.",
})
return
}
// Reject absolute paths and path traversal attempts. The path
// must be a clean relative path within the repository tree.
if strings.HasPrefix(filePath, "/") {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameter 'path' must be a relative file path.",
})
return
}
if strings.Contains(filePath, "..") {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameter 'path' must not contain path traversal segments.",
})
return
}
ref := r.URL.Query().Get("ref")
if ref == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing required query parameter: ref.",
})
return
}
owner, repo, token, err := api.resolveGitHubRepoForChat(ctx, chat)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to resolve repository for chat.",
})
return
}
if owner == "" || repo == "" {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Chat does not have a GitHub repository reference.",
})
return
}
contentType, body, err := api.fetchGitHubFileContent(ctx, owner, repo, filePath, ref, token)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadGateway, codersdk.Response{
Message: "Failed to fetch file content from GitHub.",
})
return
}
rw.Header().Set("Content-Type", contentType)
rw.Header().Set("Cache-Control", "private, max-age=300")
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(body)
}
// resolveGitHubRepoForChat resolves the GitHub owner, repo, and
// access token for a chat by looking up its cached diff reference.
// Returns empty owner/repo when the chat is not linked to a GitHub
// repository.
func (api *API) resolveGitHubRepoForChat(
ctx context.Context,
chat database.Chat,
) (owner, repo, token string, err error) {
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
if err != nil {
return "", "", "", xerrors.Errorf("get cached diff status: %w", err)
}
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
if err != nil {
return "", "", "", xerrors.Errorf("resolve diff reference: %w", err)
}
if reference.RepositoryRef == nil ||
!strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
return "", "", "", nil
}
token = api.resolveChatGitHubAccessToken(ctx, chat.OwnerID)
owner = reference.RepositoryRef.Owner
repo = reference.RepositoryRef.Repo
if owner == "" || repo == "" {
if reference.PullRequestURL != "" {
prRef, ok := parseGitHubPullRequestURL(reference.PullRequestURL)
if ok {
owner = prRef.Owner
repo = prRef.Repo
}
}
}
return owner, repo, token, nil
}
// fetchGitHubFileContent fetches raw file content from the GitHub
// Contents API. Each path segment is individually URL-escaped to
// prevent path traversal. The response body is limited to 10 MiB.
func (api *API) fetchGitHubFileContent(
ctx context.Context,
owner, repo, filePath, ref, token string,
) (contentType string, body []byte, err error) {
// Escape each path segment individually so directory separators
// are preserved but special characters cannot break out.
segments := strings.Split(filePath, "/")
for i, seg := range segments {
segments[i] = url.PathEscape(seg)
}
safePath := strings.Join(segments, "/")
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/contents/%s?ref=%s",
githubAPIBaseURL, url.PathEscape(owner), url.PathEscape(repo), safePath, url.QueryEscape(ref),
)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return "", nil, xerrors.Errorf("create github file content request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github.raw+json")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
httpClient := api.HTTPClient
if httpClient == nil {
httpClient = http.DefaultClient
}
resp, err := httpClient.Do(req)
if err != nil {
return "", nil, xerrors.Errorf("execute github file content request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return "", nil, xerrors.Errorf(
"github file content request failed with status %d",
resp.StatusCode,
)
}
return "", nil, xerrors.Errorf(
"github file content request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(errBody)),
)
}
contentType = resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/octet-stream"
}
const maxFileSize = 10 << 20 // 10 MiB
body, err = io.ReadAll(io.LimitReader(resp.Body, maxFileSize))
if err != nil {
return "", nil, xerrors.Errorf("read github file content response: %w", err)
}
return contentType, body, nil
}
// chatCreateWorkspace provides workspace creation for the chat
// processor. RBAC authorization uses context-based checks via
// dbauthz.As rather than fake *http.Request objects.
@@ -2228,317 +2229,45 @@ func normalizeChatCompressionThreshold(
return threshold, nil
}
const (
// maxChatFileSize is the maximum size of a chat file upload (10 MB).
maxChatFileSize = 10 << 20
// maxChatFileName is the maximum length of an uploaded file name.
maxChatFileName = 255
)
// allowedChatFileMIMETypes lists the content types accepted for chat
// file uploads. SVG is explicitly excluded because it can contain scripts.
var allowedChatFileMIMETypes = map[string]bool{
"image/png": true,
"image/jpeg": true,
"image/gif": true,
"image/webp": true,
"image/svg+xml": false, // SVG can contain scripts.
}
var (
webpMagicRIFF = []byte("RIFF")
webpMagicWEBP = []byte("WEBP")
)
// detectChatFileType detects the MIME type of the given data.
// It extends http.DetectContentType with support for WebP, which
// Go's standard sniffer does not recognize.
func detectChatFileType(data []byte) string {
if len(data) >= 12 &&
bytes.Equal(data[0:4], webpMagicRIFF) &&
bytes.Equal(data[8:12], webpMagicWEBP) {
return "image/webp"
}
return http.DetectContentType(data)
}
func defaultChatSystemPrompt() string {
return chatd.DefaultSystemPrompt
}
// @Summary Upload a chat file
// @ID upload-chat-file
// @Security CoderSessionToken
// @Accept application/octet-stream
// @Produce json
// @Tags Chats
// @Param Content-Type header string true "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)"
// @Param organization query string true "Organization ID" format(uuid)
// @Success 201 {object} codersdk.UploadChatFileResponse
// @Failure 400 {object} codersdk.Response
// @Failure 401 {object} codersdk.Response
// @Failure 413 {object} codersdk.Response
// @Failure 500 {object} codersdk.Response
// @Router /chats/files [post]
func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String())) {
httpapi.Forbidden(rw)
return
}
orgIDStr := r.URL.Query().Get("organization")
if orgIDStr == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing organization query parameter.",
})
return
}
orgID, err := uuid.Parse(orgIDStr)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid organization ID.",
})
return
}
contentType := r.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/octet-stream"
}
// Strip parameters (e.g. "image/png; charset=utf-8" → "image/png")
// so the allowlist check matches the base media type.
if mediaType, _, err := mime.ParseMediaType(contentType); err == nil {
contentType = mediaType
}
if allowed, ok := allowedChatFileMIMETypes[contentType]; !ok || !allowed {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Unsupported file type.",
Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.",
})
return
}
r.Body = http.MaxBytesReader(rw, r.Body, maxChatFileSize)
br := bufio.NewReader(r.Body)
// Peek at the leading bytes to sniff the real content type
// before reading the entire body.
peek, peekErr := br.Peek(512)
if peekErr != nil && !errors.Is(peekErr, io.EOF) && !errors.Is(peekErr, bufio.ErrBufferFull) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to read file from request.",
Detail: peekErr.Error(),
})
return
}
// Verify the actual content matches a safe image type so that
// a client cannot spoof Content-Type to serve active content.
detected := detectChatFileType(peek)
if allowed, ok := allowedChatFileMIMETypes[detected]; !ok || !allowed {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Unsupported file type.",
Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.",
})
return
}
// Read the full body now that we know the type is valid.
data, err := io.ReadAll(br)
if err != nil {
var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
Message: "File too large.",
Detail: fmt.Sprintf("Maximum file size is %d bytes.", maxChatFileSize),
})
return
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to read file from request.",
Detail: err.Error(),
})
return
}
// Extract filename from Content-Disposition header if provided.
var filename string
if cd := r.Header.Get("Content-Disposition"); cd != "" {
if _, params, err := mime.ParseMediaType(cd); err == nil {
filename = params["filename"]
if len(filename) > maxChatFileName {
// Truncate at rune boundary to avoid splitting
// multi-byte UTF-8 characters.
var truncated []byte
for _, r := range filename {
encoded := []byte(string(r))
if len(truncated)+len(encoded) > maxChatFileName {
break
}
truncated = append(truncated, encoded...)
}
filename = string(truncated)
}
}
}
chatFile, err := api.Database.InsertChatFile(ctx, database.InsertChatFileParams{
OwnerID: apiKey.UserID,
OrganizationID: orgID,
Name: filename,
Mimetype: detected,
Data: data,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to save chat file.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.UploadChatFileResponse{
ID: chatFile.ID,
})
}
// @Summary Get a chat file
// @ID get-chat-file
// @Security CoderSessionToken
// @Tags Chats
// @Param file path string true "File ID" format(uuid)
// @Success 200
// @Failure 400 {object} codersdk.Response
// @Failure 401 {object} codersdk.Response
// @Failure 404 {object} codersdk.Response
// @Failure 500 {object} codersdk.Response
// @Router /chats/files/{file} [get]
func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
fileIDStr := chi.URLParam(r, "file")
fileID, err := uuid.Parse(fileIDStr)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid file ID.",
})
return
}
chatFile, err := api.Database.GetChatFileByID(ctx, fileID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get chat file.",
Detail: err.Error(),
})
return
}
rw.Header().Set("Content-Type", chatFile.Mimetype)
if chatFile.Name != "" {
rw.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": chatFile.Name}))
} else {
rw.Header().Set("Content-Disposition", "inline")
}
rw.Header().Set("Cache-Control", "private, max-age=31536000, immutable")
rw.Header().Set("Content-Length", strconv.Itoa(len(chatFile.Data)))
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(chatFile.Data)
}
func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) (
func createChatInputFromRequest(req codersdk.CreateChatRequest) (
[]fantasy.Content,
map[int]uuid.UUID,
string,
*codersdk.Response,
) {
return createChatInputFromParts(ctx, db, req.Content, "content")
return createChatInputFromParts(req.Content, "content")
}
func createChatInputFromParts(
ctx context.Context,
db database.Store,
parts []codersdk.ChatInputPart,
fieldName string,
) ([]fantasy.Content, map[int]uuid.UUID, string, *codersdk.Response) {
) ([]fantasy.Content, string, *codersdk.Response) {
if len(parts) == 0 {
return nil, nil, "", &codersdk.Response{
return nil, "", &codersdk.Response{
Message: "Content is required.",
Detail: "Content cannot be empty.",
}
}
content := make([]fantasy.Content, 0, len(parts))
fileIDs := make(map[int]uuid.UUID)
textParts := make([]string, 0, len(parts))
for i, part := range parts {
switch strings.ToLower(strings.TrimSpace(string(part.Type))) {
case string(codersdk.ChatInputPartTypeText):
text := strings.TrimSpace(part.Text)
if text == "" {
return nil, nil, "", &codersdk.Response{
return nil, "", &codersdk.Response{
Message: "Invalid input part.",
Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i),
}
}
content = append(content, fantasy.TextContent{Text: text})
textParts = append(textParts, text)
case string(codersdk.ChatInputPartTypeFile):
if part.FileID == uuid.Nil {
return nil, nil, "", &codersdk.Response{
Message: "Invalid input part.",
Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i),
}
}
// Validate that the file exists and get its media type.
// File data is not loaded here; it's resolved at LLM
// dispatch time via chatFileResolver.
chatFile, err := db.GetChatFileByID(ctx, part.FileID)
if err != nil {
if httpapi.Is404Error(err) {
return nil, nil, "", &codersdk.Response{
Message: "Invalid input part.",
Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i),
}
}
return nil, nil, "", &codersdk.Response{
Message: "Internal error.",
Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i),
}
}
content = append(content, fantasy.FileContent{
MediaType: chatFile.Mimetype,
})
fileIDs[len(content)-1] = part.FileID
case string(codersdk.ChatInputPartTypeFileReference):
if part.FileName == "" {
return nil, nil, "", &codersdk.Response{
Message: "Invalid input part.",
Detail: fmt.Sprintf("%s[%d].file_name cannot be empty for file-reference.", fieldName, i),
}
}
lineRange := fmt.Sprintf("%d", part.StartLine)
if part.StartLine != part.EndLine {
lineRange = fmt.Sprintf("%d-%d", part.StartLine, part.EndLine)
}
var sb strings.Builder
_, _ = fmt.Fprintf(&sb, "[file-reference] %s:%s", part.FileName, lineRange)
if strings.TrimSpace(part.Content) != "" {
_, _ = fmt.Fprintf(&sb, "\n```%s\n%s\n```", part.FileName, strings.TrimSpace(part.Content))
}
text := sb.String()
content = append(content, fantasy.TextContent{Text: text})
textParts = append(textParts, text)
default:
return nil, nil, "", &codersdk.Response{
return nil, "", &codersdk.Response{
Message: "Invalid input part.",
Detail: fmt.Sprintf(
"%s[%d].type %q is not supported.",
@@ -2550,16 +2279,14 @@ func createChatInputFromParts(
}
}
// Allow file-only messages. The titleSource may be empty
// when only file parts are provided, callers handle this.
if len(content) == 0 {
return nil, nil, "", &codersdk.Response{
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
if titleSource == "" {
return nil, "", &codersdk.Response{
Message: "Content is required.",
Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName),
Detail: "Content must include at least one text part.",
}
}
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
return content, fileIDs, titleSource, nil
return content, titleSource, nil
}
func chatTitleFromMessage(message string) string {
+268 -883
View File
File diff suppressed because it is too large Load Diff
+8 -30
View File
@@ -99,7 +99,6 @@ import (
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/derpmetrics"
"github.com/coder/quartz"
"github.com/coder/serpent"
)
@@ -662,7 +661,6 @@ func New(options *Options) *API {
api.SiteHandler, err = site.New(&site.Options{
CacheDir: siteCacheDir,
Database: options.Database,
Authorizer: options.Authorizer,
SiteFS: site.FS(),
OAuth2Configs: oauthConfigs,
DocsURL: options.DeploymentValues.DocsURL.String(),
@@ -901,18 +899,17 @@ func New(options *Options) *API {
apiRateLimiter := httpmw.RateLimit(options.APIRateLimit, time.Minute)
// Register DERP on expvar HTTP handler, which we serve below in the router, c.f. expvar.Handler()
// These are the metrics the DERP server exposes.
// TODO: export via prometheus
expDERPOnce.Do(func() {
// We need to do this via a global Once because expvar registry is global and panics if we
// register multiple times. In production there is only one Coderd and one DERP server per
// process, but in testing, we create multiple of both, so the Once protects us from
// panicking.
if options.DERPServer != nil && expvar.Get("derp") == nil {
if options.DERPServer != nil {
expvar.Publish("derp", api.DERPServer.ExpVar())
}
})
if options.PrometheusRegistry != nil && options.DERPServer != nil {
options.PrometheusRegistry.MustRegister(derpmetrics.NewDERPExpvarCollector(options.DERPServer))
}
cors := httpmw.Cors(options.DeploymentValues.Dangerous.AllowAllCors.Value())
prometheusMW := httpmw.Prometheus(options.PrometheusRegistry)
@@ -927,16 +924,6 @@ func New(options *Options) *API {
loggermw.Logger(api.Logger),
singleSlashMW,
rolestore.CustomRoleMW,
// Validate API key on every request (if present) and store
// the result in context. The rate limiter reads this to key
// by user ID, and downstream ExtractAPIKeyMW reuses it to
// avoid redundant DB lookups. Never rejects requests.
httpmw.PrecheckAPIKey(httpmw.ValidateAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
DisableSessionExpiryRefresh: options.DeploymentValues.Sessions.DisableExpiryRefresh.Value(),
Logger: options.Logger,
}),
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
prometheusMW,
// Build-Version is helpful for debugging.
@@ -1085,6 +1072,8 @@ func New(options *Options) *API {
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
r.Use(
// Specific routes can specify different limits, but every rate
// limit must be configurable by the admin.
apiRateLimiter,
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
)
@@ -1122,11 +1111,6 @@ func New(options *Options) *API {
r.Post("/", api.postChats)
r.Get("/models", api.listChatModels)
r.Get("/watch", api.watchChats)
r.Route("/files", func(r chi.Router) {
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
r.Post("/", api.postChatFile)
r.Get("/{file}", api.chatFileByID)
})
r.Route("/providers", func(r chi.Router) {
r.Get("/", api.listChatProviders)
r.Post("/", api.createChatProvider)
@@ -1146,7 +1130,6 @@ func New(options *Options) *API {
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
r.Get("/git/watch", api.watchChatGit)
r.Post("/archive", api.archiveChat)
r.Post("/unarchive", api.unarchiveChat)
r.Post("/messages", api.postChatMessages)
@@ -1155,6 +1138,7 @@ func New(options *Options) *API {
r.Post("/interrupt", api.interruptChat)
r.Get("/diff-status", api.getChatDiffStatus)
r.Get("/diff", api.getChatDiffContents)
r.Get("/diff/file-content", api.getChatDiffFileContent)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
r.Delete("/", api.deleteChatQueuedMessage)
r.Post("/promote", api.promoteChatQueuedMessage)
@@ -1177,6 +1161,8 @@ func New(options *Options) *API {
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
r.Use(
// Specific routes can specify different limits, but every rate
// limit must be configurable by the admin.
apiRateLimiter,
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
)
@@ -1854,14 +1840,6 @@ func New(options *Options) *API {
"parsing additional CSP headers", slog.Error(cspParseErrors))
}
// Add blob: to img-src for chat file attachment previews when
// the agents experiment is enabled.
if api.Experiments.Enabled(codersdk.ExperimentAgents) {
additionalCSPHeaders[httpmw.CSPDirectiveImgSrc] = append(
additionalCSPHeaders[httpmw.CSPDirectiveImgSrc], "blob:",
)
}
// Add CSP headers to all static assets and pages. CSP headers only affect
// browsers, so these don't make sense on api routes.
cspMW := httpmw.CSPHeaders(
-114
View File
@@ -390,117 +390,3 @@ func TestCSRFExempt(t *testing.T) {
require.NotContains(t, string(data), "CSRF")
})
}
func TestDERPMetrics(t *testing.T) {
t.Parallel()
_, _, api := coderdtest.NewWithAPI(t, nil)
require.NotNil(t, api.Options.DERPServer, "DERP server should be configured")
require.NotNil(t, api.Options.PrometheusRegistry, "Prometheus registry should be configured")
// The registry is created internally by coderd. Gather from it
// to verify DERP metrics were registered during startup.
metrics, err := api.Options.PrometheusRegistry.Gather()
require.NoError(t, err)
names := make(map[string]struct{})
for _, m := range metrics {
names[m.GetName()] = struct{}{}
}
assert.Contains(t, names, "coder_derp_server_connections",
"expected coder_derp_server_connections to be registered")
assert.Contains(t, names, "coder_derp_server_bytes_received_total",
"expected coder_derp_server_bytes_received_total to be registered")
assert.Contains(t, names, "coder_derp_server_packets_dropped_reason_total",
"expected coder_derp_server_packets_dropped_reason_total to be registered")
}
// TestRateLimitByUser verifies that rate limiting keys by user ID when
// an authenticated session is present, rather than falling back to IP.
// This is a regression test for https://github.com/coder/coder/issues/20857
func TestRateLimitByUser(t *testing.T) {
t.Parallel()
const rateLimit = 5
ownerClient := coderdtest.New(t, &coderdtest.Options{
APIRateLimit: rateLimit,
})
firstUser := coderdtest.CreateFirstUser(t, ownerClient)
t.Run("HitsLimit", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// Make rateLimit requests — they should all succeed.
for i := 0; i < rateLimit; i++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
resp, err := ownerClient.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode,
"request %d should succeed", i+1)
}
// The next request should be rate-limited.
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
resp, err := ownerClient.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode,
"request should be rate limited")
})
t.Run("BypassOwner", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// Owner with bypass header should not be rate-limited.
for i := 0; i < rateLimit+5; i++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
resp, err := ownerClient.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode,
"owner bypass request %d should succeed", i+1)
}
})
t.Run("MemberCannotBypass", func(t *testing.T) {
t.Parallel()
memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID)
ctx := testutil.Context(t, testutil.WaitLong)
// A member requesting the bypass header should be rejected
// with 428 Precondition Required — only owners may bypass.
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
memberClient.URL.String()+"/api/v2/buildinfo", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, memberClient.SessionToken())
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
resp, err := memberClient.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode,
"member should not be able to bypass rate limit")
})
}
+2 -2
View File
@@ -17,9 +17,9 @@ const (
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
)
+6 -14
View File
@@ -1156,7 +1156,9 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe
}
var rawBlocks []json.RawMessage
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
if role == string(fantasy.MessageRoleAssistant) {
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
}
parts := make([]codersdk.ChatMessagePart, 0, len(content))
for i, block := range content {
@@ -1164,20 +1166,10 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe
if part.Type == "" {
continue
}
if i < len(rawBlocks) {
switch part.Type {
case codersdk.ChatMessagePartTypeReasoning:
if part.Type == codersdk.ChatMessagePartTypeReasoning {
part.Title = ""
if i < len(rawBlocks) {
part.Title = reasoningStoredTitle(rawBlocks[i])
case codersdk.ChatMessagePartTypeFile:
if fid, err := chatprompt.ExtractFileID(rawBlocks[i]); err == nil {
part.FileID = uuid.NullUUID{UUID: fid, Valid: true}
}
// When a file_id is present, omit inline data
// from the response. Clients fetch content via
// the GET /chats/files/{id} endpoint instead.
if part.FileID.Valid {
part.Data = nil
}
}
}
parts = append(parts, part)
+12 -31
View File
@@ -2457,30 +2457,6 @@ func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uu
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
}
func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
file, err := q.db.GetChatFileByID(ctx, id)
if err != nil {
return database.ChatFile{}, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, file); err != nil {
return database.ChatFile{}, err
}
return file, nil
}
func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
files, err := q.db.GetChatFilesByIDs(ctx, ids)
if err != nil {
return nil, err
}
for _, f := range files {
if err := q.authorizeContext(ctx, policy.ActionRead, f); err != nil {
return nil, err
}
}
return files, nil
}
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
// ChatMessages are authorized through their parent Chat.
// We need to fetch the message first to get its chat_id.
@@ -3433,7 +3409,12 @@ func (q *querier) GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (databa
return database.TaskSnapshot{}, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, task.RBACObject()); err != nil {
obj := rbac.ResourceTask.
WithID(task.ID).
WithOwner(task.OwnerID.String()).
InOrg(task.OrganizationID)
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
return database.TaskSnapshot{}, err
}
@@ -4515,11 +4496,6 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams)
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
}
func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
// Authorize create on chat resource scoped to the owner and org.
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
}
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
// Authorize create on the parent chat (using update permission).
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
@@ -6659,7 +6635,12 @@ func (q *querier) UpsertTaskSnapshot(ctx context.Context, arg database.UpsertTas
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, task.RBACObject()); err != nil {
obj := rbac.ResourceTask.
WithID(task.ID).
WithOwner(task.OwnerID.String()).
InOrg(task.OrganizationID)
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
return err
}
-16
View File
@@ -463,16 +463,6 @@ func (s *MethodTestSuite) TestChats() {
Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).
Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB})
}))
s.Run("GetChatFileByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
file := testutil.Fake(s.T(), faker, database.ChatFile{})
dbm.EXPECT().GetChatFileByID(gomock.Any(), file.ID).Return(file, nil).AnyTimes()
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(file)
}))
s.Run("GetChatFilesByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
file := testutil.Fake(s.T(), faker, database.ChatFile{})
dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes()
check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file})
}))
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
@@ -589,12 +579,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
}))
s.Run("InsertChatFile", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := testutil.Fake(s.T(), faker, database.InsertChatFileParams{})
file := testutil.Fake(s.T(), faker, database.InsertChatFileRow{OwnerID: arg.OwnerID, OrganizationID: arg.OrganizationID})
dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file)
}))
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
-1
View File
@@ -1595,7 +1595,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
Client: seed.Client,
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
ClientSessionID: seed.ClientSessionID,
})
if endedAt != nil {
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
-24
View File
@@ -1007,22 +1007,6 @@ func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, cha
return r0, r1
}
func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
start := time.Now()
r0, r1 := m.s.GetChatFileByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatFileByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
start := time.Now()
r0, r1 := m.s.GetChatFilesByIDs(ctx, ids)
m.queryLatencies.WithLabelValues("GetChatFilesByIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFilesByIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessageByID(ctx, id)
@@ -2959,14 +2943,6 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh
return r0, r1
}
func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
start := time.Now()
r0, r1 := m.s.InsertChatFile(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatFile").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatFile").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.InsertChatMessage(ctx, arg)
-45
View File
@@ -1837,36 +1837,6 @@ func (mr *MockStoreMockRecorder) GetChatDiffStatusesByChatIDs(ctx, chatIds any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusesByChatIDs", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusesByChatIDs), ctx, chatIds)
}
// GetChatFileByID mocks base method.
func (m *MockStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatFileByID", ctx, id)
ret0, _ := ret[0].(database.ChatFile)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatFileByID indicates an expected call of GetChatFileByID.
func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id)
}
// GetChatFilesByIDs mocks base method.
func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatFilesByIDs", ctx, ids)
ret0, _ := ret[0].([]database.ChatFile)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatFilesByIDs indicates an expected call of GetChatFilesByIDs.
func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids)
}
// GetChatMessageByID mocks base method.
func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
m.ctrl.T.Helper()
@@ -5541,21 +5511,6 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
}
// InsertChatFile mocks base method.
func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatFile", ctx, arg)
ret0, _ := ret[0].(database.InsertChatFileRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatFile indicates an expected call of InsertChatFile.
func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg)
}
// InsertChatMessage mocks base method.
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
m.ctrl.T.Helper()
+27 -58
View File
@@ -1046,8 +1046,7 @@ CREATE TABLE aibridge_interceptions (
api_key_id text,
client character varying(64) DEFAULT 'Unknown'::character varying,
thread_parent_id uuid,
thread_root_id uuid,
client_session_id character varying(256)
thread_root_id uuid
);
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
@@ -1058,8 +1057,6 @@ COMMENT ON COLUMN aibridge_interceptions.thread_parent_id IS 'The interception w
COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interception of the thread that this interception belongs to.';
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
CREATE TABLE aibridge_token_usages (
id uuid NOT NULL,
interception_id uuid NOT NULL,
@@ -1190,16 +1187,6 @@ CREATE TABLE chat_diff_statuses (
git_remote_origin text DEFAULT ''::text NOT NULL
);
CREATE TABLE chat_files (
id uuid DEFAULT gen_random_uuid() NOT NULL,
owner_id uuid NOT NULL,
organization_id uuid NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL,
name text DEFAULT ''::text NOT NULL,
mimetype text NOT NULL,
data bytea NOT NULL
);
CREATE TABLE chat_messages (
id bigint NOT NULL,
chat_id uuid NOT NULL,
@@ -2107,31 +2094,6 @@ CREATE TABLE workspace_builds (
CONSTRAINT workspace_builds_deadline_below_max_deadline CHECK ((((deadline <> '0001-01-01 00:00:00+00'::timestamp with time zone) AND (deadline <= max_deadline)) OR (max_deadline = '0001-01-01 00:00:00+00'::timestamp with time zone)))
);
CREATE TABLE workspaces (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
owner_id uuid NOT NULL,
organization_id uuid NOT NULL,
template_id uuid NOT NULL,
deleted boolean DEFAULT false NOT NULL,
name character varying(64) NOT NULL,
autostart_schedule text,
ttl bigint,
last_used_at timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
dormant_at timestamp with time zone,
deleting_at timestamp with time zone,
automatic_updates automatic_updates DEFAULT 'never'::automatic_updates NOT NULL,
favorite boolean DEFAULT false NOT NULL,
next_start_at timestamp with time zone,
group_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
user_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
CONSTRAINT group_acl_is_object CHECK ((jsonb_typeof(group_acl) = 'object'::text)),
CONSTRAINT user_acl_is_object CHECK ((jsonb_typeof(user_acl) = 'object'::text))
);
COMMENT ON COLUMN workspaces.favorite IS 'Favorite is true if the workspace owner has favorited the workspace.';
CREATE VIEW tasks_with_status AS
SELECT tasks.id,
tasks.organization_id,
@@ -2144,8 +2106,6 @@ CREATE VIEW tasks_with_status AS
tasks.created_at,
tasks.deleted_at,
tasks.display_name,
COALESCE(workspaces.group_acl, '{}'::jsonb) AS workspace_group_acl,
COALESCE(workspaces.user_acl, '{}'::jsonb) AS workspace_user_acl,
CASE
WHEN (tasks.workspace_id IS NULL) THEN 'pending'::task_status
WHEN (build_status.status <> 'active'::task_status) THEN build_status.status
@@ -2161,8 +2121,7 @@ CREATE VIEW tasks_with_status AS
task_owner.owner_username,
task_owner.owner_name,
task_owner.owner_avatar_url
FROM (((((((((tasks
LEFT JOIN workspaces ON ((workspaces.id = tasks.workspace_id)))
FROM ((((((((tasks
CROSS JOIN LATERAL ( SELECT vu.username AS owner_username,
vu.name AS owner_name,
vu.avatar_url AS owner_avatar_url
@@ -2905,6 +2864,31 @@ CREATE VIEW workspace_build_with_user AS
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
CREATE TABLE workspaces (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
owner_id uuid NOT NULL,
organization_id uuid NOT NULL,
template_id uuid NOT NULL,
deleted boolean DEFAULT false NOT NULL,
name character varying(64) NOT NULL,
autostart_schedule text,
ttl bigint,
last_used_at timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
dormant_at timestamp with time zone,
deleting_at timestamp with time zone,
automatic_updates automatic_updates DEFAULT 'never'::automatic_updates NOT NULL,
favorite boolean DEFAULT false NOT NULL,
next_start_at timestamp with time zone,
group_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
user_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
CONSTRAINT group_acl_is_object CHECK ((jsonb_typeof(group_acl) = 'object'::text)),
CONSTRAINT user_acl_is_object CHECK ((jsonb_typeof(user_acl) = 'object'::text))
);
COMMENT ON COLUMN workspaces.favorite IS 'Favorite is true if the workspace owner has favorited the workspace.';
CREATE VIEW workspace_latest_builds AS
SELECT latest_build.id,
latest_build.workspace_id,
@@ -3150,9 +3134,6 @@ ALTER TABLE ONLY boundary_usage_stats
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
ALTER TABLE ONLY chat_files
ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_messages
ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
@@ -3466,8 +3447,6 @@ CREATE INDEX idx_agent_stats_user_id ON workspace_agent_stats USING btree (user_
CREATE INDEX idx_aibridge_interceptions_client ON aibridge_interceptions USING btree (client);
CREATE INDEX idx_aibridge_interceptions_client_session_id ON aibridge_interceptions USING btree (client_session_id) WHERE (client_session_id IS NOT NULL);
CREATE INDEX idx_aibridge_interceptions_initiator_id ON aibridge_interceptions USING btree (initiator_id);
CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING btree (model);
@@ -3508,10 +3487,6 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id);
CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id);
CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id);
CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at);
@@ -3791,12 +3766,6 @@ ALTER TABLE ONLY api_keys
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_files
ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_files
ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_messages
ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
@@ -9,8 +9,6 @@ const (
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
+16 -22
View File
@@ -22,12 +22,8 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
# The logic below depends on the exact version being correct :(
sqlc generate
# Work directory for formatting before atomic replacement of
# generated files, ensuring the source tree is never left in a
# partially written state.
mkdir -p ../../_gen
workdir=$(mktemp -d ../../_gen/.dbgen.XXXXXX)
trap 'rm -rf "$workdir"' EXIT
tmpfile=$(mktemp "${TMPDIR:-/tmp}/queries.sql.go.XXXXXX")
trap 'rm -f "$tmpfile"' EXIT
first=true
files=$(find ./queries/ -type f -name "*.sql.go" | LC_ALL=C sort)
@@ -42,34 +38,32 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
# Copy the header from the first file only, ignoring the source comment.
if $first; then
head -n 6 <"$fi" | grep -v "source" >"$workdir/queries.sql.go"
head -n 6 <"$fi" | grep -v "source" >"$tmpfile"
first=false
fi
# Append the file past the imports section into queries.sql.go.
tail -n "+$cut" <"$fi" >>"$workdir/queries.sql.go"
tail -n "+$cut" <"$fi" >>"$tmpfile"
done
# Move sqlc outputs into workdir for formatting.
mv queries/querier.go "$workdir/querier.go"
mv queries/models.go "$workdir/models.go"
# Atomically replace the target file.
mv "$tmpfile" queries.sql.go
# Move the files we want.
mv queries/querier.go .
mv queries/models.go .
# Remove temporary go files.
rm -f queries/*.go
# Fix struct/interface names in the workdir (not the source tree).
gofmt -w -r 'Querier -> sqlcQuerier' -- "$workdir"/*.go
gofmt -w -r 'Queries -> sqlQuerier' -- "$workdir"/*.go
# Fix struct/interface names.
gofmt -w -r 'Querier -> sqlcQuerier' -- *.go
gofmt -w -r 'Queries -> sqlQuerier' -- *.go
# Ensure correct imports exist. Modules must all be downloaded so we
# get correct suggestions.
# Ensure correct imports exist. Modules must all be downloaded so we get correct
# suggestions.
go mod download
go tool golang.org/x/tools/cmd/goimports -w "$workdir/queries.sql.go"
# Atomically replace all three target files.
mv "$workdir/queries.sql.go" queries.sql.go
mv "$workdir/querier.go" querier.go
mv "$workdir/models.go" models.go
go tool golang.org/x/tools/cmd/goimports -w queries.sql.go
go run ../../scripts/dbgen
# This will error if a view is broken. This is in it's own package to avoid
@@ -1,145 +0,0 @@
-- Fix task status logic: pending provisioner job should give pending task status, not initializing.
-- A task is pending when the provisioner hasn't picked up the job yet.
-- A task is initializing when the provisioner is actively running the job.
DROP VIEW IF EXISTS tasks_with_status;
CREATE VIEW
tasks_with_status
AS
SELECT
tasks.*,
-- Combine component statuses with precedence: build -> agent -> app.
CASE
WHEN tasks.workspace_id IS NULL THEN 'pending'::task_status
WHEN build_status.status != 'active' THEN build_status.status::task_status
WHEN agent_status.status != 'active' THEN agent_status.status::task_status
ELSE app_status.status::task_status
END AS status,
-- Attach debug information for troubleshooting status.
jsonb_build_object(
'build', jsonb_build_object(
'transition', latest_build_raw.transition,
'job_status', latest_build_raw.job_status,
'computed', build_status.status
),
'agent', jsonb_build_object(
'lifecycle_state', agent_raw.lifecycle_state,
'computed', agent_status.status
),
'app', jsonb_build_object(
'health', app_raw.health,
'computed', app_status.status
)
) AS status_debug,
task_app.*,
agent_raw.lifecycle_state AS workspace_agent_lifecycle_state,
app_raw.health AS workspace_app_health,
task_owner.*
FROM
tasks
CROSS JOIN LATERAL (
SELECT
vu.username AS owner_username,
vu.name AS owner_name,
vu.avatar_url AS owner_avatar_url
FROM
visible_users vu
WHERE
vu.id = tasks.owner_id
) task_owner
LEFT JOIN LATERAL (
SELECT
task_app.workspace_build_number,
task_app.workspace_agent_id,
task_app.workspace_app_id
FROM
task_workspace_apps task_app
WHERE
task_id = tasks.id
ORDER BY
task_app.workspace_build_number DESC
LIMIT 1
) task_app ON TRUE
-- Join the raw data for computing task status.
LEFT JOIN LATERAL (
SELECT
workspace_build.transition,
provisioner_job.job_status,
workspace_build.job_id
FROM
workspace_builds workspace_build
JOIN
provisioner_jobs provisioner_job
ON provisioner_job.id = workspace_build.job_id
WHERE
workspace_build.workspace_id = tasks.workspace_id
AND workspace_build.build_number = task_app.workspace_build_number
) latest_build_raw ON TRUE
LEFT JOIN LATERAL (
SELECT
workspace_agent.lifecycle_state
FROM
workspace_agents workspace_agent
WHERE
workspace_agent.id = task_app.workspace_agent_id
) agent_raw ON TRUE
LEFT JOIN LATERAL (
SELECT
workspace_app.health
FROM
workspace_apps workspace_app
WHERE
workspace_app.id = task_app.workspace_app_id
) app_raw ON TRUE
-- Compute the status for each component.
CROSS JOIN LATERAL (
SELECT
CASE
WHEN latest_build_raw.job_status IS NULL THEN 'pending'::task_status
WHEN latest_build_raw.job_status IN ('failed', 'canceling', 'canceled') THEN 'error'::task_status
WHEN
latest_build_raw.transition IN ('stop', 'delete')
AND latest_build_raw.job_status = 'succeeded' THEN 'paused'::task_status
-- Job is pending (not picked up by provisioner yet).
WHEN
latest_build_raw.transition = 'start'
AND latest_build_raw.job_status = 'pending' THEN 'pending'::task_status
-- Job is running or done, defer to agent/app status.
WHEN
latest_build_raw.transition = 'start'
AND latest_build_raw.job_status IN ('running', 'succeeded') THEN 'active'::task_status
ELSE 'unknown'::task_status
END AS status
) build_status
CROSS JOIN LATERAL (
SELECT
CASE
-- No agent or connecting.
WHEN
agent_raw.lifecycle_state IS NULL
OR agent_raw.lifecycle_state IN ('created', 'starting') THEN 'initializing'::task_status
-- Agent is running, defer to app status.
-- NOTE(mafredri): The start_error/start_timeout states means connected, but some startup script failed.
-- This may or may not affect the task status but this has to be caught by app health check.
WHEN agent_raw.lifecycle_state IN ('ready', 'start_timeout', 'start_error') THEN 'active'::task_status
-- If the agent is shutting down or turned off, this is an unknown state because we would expect a stop
-- build to be running.
-- This is essentially equal to: `IN ('shutting_down', 'shutdown_timeout', 'shutdown_error', 'off')`,
-- but we cannot use them because the values were added in a migration.
WHEN agent_raw.lifecycle_state NOT IN ('created', 'starting', 'ready', 'start_timeout', 'start_error') THEN 'unknown'::task_status
ELSE 'unknown'::task_status
END AS status
) agent_status
CROSS JOIN LATERAL (
SELECT
CASE
WHEN app_raw.health = 'initializing' THEN 'initializing'::task_status
WHEN app_raw.health = 'unhealthy' THEN 'error'::task_status
WHEN app_raw.health IN ('healthy', 'disabled') THEN 'active'::task_status
ELSE 'unknown'::task_status
END AS status
) app_status
WHERE
tasks.deleted_at IS NULL;
@@ -1,151 +0,0 @@
-- Fix task status logic: pending provisioner job should give pending task status, not initializing.
-- A task is pending when the provisioner hasn't picked up the job yet.
-- A task is initializing when the provisioner is actively running the job.
DROP VIEW IF EXISTS tasks_with_status;
CREATE VIEW
tasks_with_status
AS
SELECT
tasks.*,
coalesce(workspaces.group_acl, '{}'::jsonb) as workspace_group_acl,
coalesce(workspaces.user_acl, '{}'::jsonb) as workspace_user_acl,
-- Combine component statuses with precedence: build -> agent -> app.
CASE
WHEN tasks.workspace_id IS NULL THEN 'pending'::task_status
WHEN build_status.status != 'active' THEN build_status.status::task_status
WHEN agent_status.status != 'active' THEN agent_status.status::task_status
ELSE app_status.status::task_status
END AS status,
-- Attach debug information for troubleshooting status.
jsonb_build_object(
'build', jsonb_build_object(
'transition', latest_build_raw.transition,
'job_status', latest_build_raw.job_status,
'computed', build_status.status
),
'agent', jsonb_build_object(
'lifecycle_state', agent_raw.lifecycle_state,
'computed', agent_status.status
),
'app', jsonb_build_object(
'health', app_raw.health,
'computed', app_status.status
)
) AS status_debug,
task_app.*,
agent_raw.lifecycle_state AS workspace_agent_lifecycle_state,
app_raw.health AS workspace_app_health,
task_owner.*
FROM
tasks
LEFT JOIN
workspaces ON workspaces.id = tasks.workspace_id
CROSS JOIN LATERAL (
SELECT
vu.username AS owner_username,
vu.name AS owner_name,
vu.avatar_url AS owner_avatar_url
FROM
visible_users vu
WHERE
vu.id = tasks.owner_id
) task_owner
LEFT JOIN LATERAL (
SELECT
task_app.workspace_build_number,
task_app.workspace_agent_id,
task_app.workspace_app_id
FROM
task_workspace_apps task_app
WHERE
task_id = tasks.id
ORDER BY
task_app.workspace_build_number DESC
LIMIT 1
) task_app ON TRUE
-- Join the raw data for computing task status.
LEFT JOIN LATERAL (
SELECT
workspace_build.transition,
provisioner_job.job_status,
workspace_build.job_id
FROM
workspace_builds workspace_build
JOIN
provisioner_jobs provisioner_job
ON provisioner_job.id = workspace_build.job_id
WHERE
workspace_build.workspace_id = tasks.workspace_id
AND workspace_build.build_number = task_app.workspace_build_number
) latest_build_raw ON TRUE
LEFT JOIN LATERAL (
SELECT
workspace_agent.lifecycle_state
FROM
workspace_agents workspace_agent
WHERE
workspace_agent.id = task_app.workspace_agent_id
) agent_raw ON TRUE
LEFT JOIN LATERAL (
SELECT
workspace_app.health
FROM
workspace_apps workspace_app
WHERE
workspace_app.id = task_app.workspace_app_id
) app_raw ON TRUE
-- Compute the status for each component.
CROSS JOIN LATERAL (
SELECT
CASE
WHEN latest_build_raw.job_status IS NULL THEN 'pending'::task_status
WHEN latest_build_raw.job_status IN ('failed', 'canceling', 'canceled') THEN 'error'::task_status
WHEN
latest_build_raw.transition IN ('stop', 'delete')
AND latest_build_raw.job_status = 'succeeded' THEN 'paused'::task_status
-- Job is pending (not picked up by provisioner yet).
WHEN
latest_build_raw.transition = 'start'
AND latest_build_raw.job_status = 'pending' THEN 'pending'::task_status
-- Job is running or done, defer to agent/app status.
WHEN
latest_build_raw.transition = 'start'
AND latest_build_raw.job_status IN ('running', 'succeeded') THEN 'active'::task_status
ELSE 'unknown'::task_status
END AS status
) build_status
CROSS JOIN LATERAL (
SELECT
CASE
-- No agent or connecting.
WHEN
agent_raw.lifecycle_state IS NULL
OR agent_raw.lifecycle_state IN ('created', 'starting') THEN 'initializing'::task_status
-- Agent is running, defer to app status.
-- NOTE(mafredri): The start_error/start_timeout states means connected, but some startup script failed.
-- This may or may not affect the task status but this has to be caught by app health check.
WHEN agent_raw.lifecycle_state IN ('ready', 'start_timeout', 'start_error') THEN 'active'::task_status
-- If the agent is shutting down or turned off, this is an unknown state because we would expect a stop
-- build to be running.
-- This is essentially equal to: `IN ('shutting_down', 'shutdown_timeout', 'shutdown_error', 'off')`,
-- but we cannot use them because the values were added in a migration.
WHEN agent_raw.lifecycle_state NOT IN ('created', 'starting', 'ready', 'start_timeout', 'start_error') THEN 'unknown'::task_status
ELSE 'unknown'::task_status
END AS status
) agent_status
CROSS JOIN LATERAL (
SELECT
CASE
WHEN app_raw.health = 'initializing' THEN 'initializing'::task_status
WHEN app_raw.health = 'unhealthy' THEN 'error'::task_status
WHEN app_raw.health IN ('healthy', 'disabled') THEN 'active'::task_status
ELSE 'unknown'::task_status
END AS status
) app_status
WHERE
tasks.deleted_at IS NULL;
@@ -1,4 +0,0 @@
DROP INDEX IF EXISTS idx_aibridge_interceptions_client_session_id;
ALTER TABLE aibridge_interceptions
DROP COLUMN client_session_id;
@@ -1,7 +0,0 @@
ALTER TABLE aibridge_interceptions
ADD COLUMN client_session_id VARCHAR(256) NULL;
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
CREATE INDEX idx_aibridge_interceptions_client_session_id ON aibridge_interceptions (client_session_id)
WHERE client_session_id IS NOT NULL;
@@ -1,2 +0,0 @@
DROP INDEX IF EXISTS idx_chat_files_org;
DROP TABLE IF EXISTS chat_files;
@@ -1,12 +0,0 @@
CREATE TABLE chat_files (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
name TEXT NOT NULL DEFAULT '',
mimetype TEXT NOT NULL,
data BYTEA NOT NULL
);
CREATE INDEX idx_chat_files_owner ON chat_files(owner_id);
CREATE INDEX idx_chat_files_org ON chat_files(organization_id);
@@ -1,13 +0,0 @@
INSERT INTO chat_files (id, owner_id, organization_id, created_at, name, mimetype, data)
SELECT
'00000000-0000-0000-0000-000000000099',
u.id,
om.organization_id,
'2024-01-01 00:00:00+00',
'test.png',
'image/png',
E'\\x89504E47'
FROM users u
JOIN organization_members om ON om.user_id = u.id
ORDER BY u.created_at, u.id
LIMIT 1;
+5 -18
View File
@@ -155,33 +155,20 @@ func (t Task) TaskTable() TaskTable {
}
func (t Task) RBACObject() rbac.Object {
obj := rbac.ResourceTask.
return t.TaskTable().RBACObject()
}
func (t TaskTable) RBACObject() rbac.Object {
return rbac.ResourceTask.
WithID(t.ID).
WithOwner(t.OwnerID.String()).
InOrg(t.OrganizationID)
if rbac.WorkspaceACLDisabled() {
return obj
}
if t.WorkspaceGroupACL != nil {
obj = obj.WithGroupACL(t.WorkspaceGroupACL.RBACACL())
}
if t.WorkspaceUserACL != nil {
obj = obj.WithACLUserList(t.WorkspaceUserACL.RBACACL())
}
return obj
}
func (c Chat) RBACObject() rbac.Object {
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String())
}
func (c ChatFile) RBACObject() rbac.Object {
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
}
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
switch s {
case ApiKeyScopeCoderAll:
-1
View File
@@ -815,7 +815,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
&i.AIBridgeInterception.Client,
&i.AIBridgeInterception.ThreadParentID,
&i.AIBridgeInterception.ThreadRootID,
&i.AIBridgeInterception.ClientSessionID,
&i.VisibleUser.ID,
&i.VisibleUser.Username,
&i.VisibleUser.Name,
-14
View File
@@ -3793,8 +3793,6 @@ type AIBridgeInterception struct {
ThreadParentID uuid.NullUUID `db:"thread_parent_id" json:"thread_parent_id"`
// The root interception of the thread that this interception belongs to.
ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"`
// The session ID supplied by the client (optional and not universally supported).
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
}
// Audit log of tokens used by intercepted requests in AI Bridge
@@ -3926,16 +3924,6 @@ type ChatDiffStatus struct {
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
}
type ChatFile struct {
ID uuid.UUID `db:"id" json:"id"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
Name string `db:"name" json:"name"`
Mimetype string `db:"mimetype" json:"mimetype"`
Data []byte `db:"data" json:"data"`
}
type ChatMessage struct {
ID int64 `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
@@ -4506,8 +4494,6 @@ type Task struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
DeletedAt sql.NullTime `db:"deleted_at" json:"deleted_at"`
DisplayName string `db:"display_name" json:"display_name"`
WorkspaceGroupACL WorkspaceACL `db:"workspace_group_acl" json:"workspace_group_acl"`
WorkspaceUserACL WorkspaceACL `db:"workspace_user_acl" json:"workspace_user_acl"`
Status TaskStatus `db:"status" json:"status"`
StatusDebug json.RawMessage `db:"status_debug" json:"status_debug"`
WorkspaceBuildNumber sql.NullInt32 `db:"workspace_build_number" json:"workspace_build_number"`
-3
View File
@@ -218,8 +218,6 @@ type sqlcQuerier interface {
GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error)
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
@@ -603,7 +601,6 @@ type sqlcQuerier interface {
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error)
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
+7 -282
View File
@@ -3867,37 +3867,6 @@ func TestGetProvisionerJobsByIDsWithQueuePosition(t *testing.T) {
queueSizes: nil, // TODO(yevhenii): should it be empty array instead?
queuePositions: nil,
},
// Many daemons with identical tags should produce same results as one.
{
name: "duplicate-daemons-same-tags",
jobTags: []database.StringMap{
{"a": "1"},
{"a": "1", "b": "2"},
},
daemonTags: []database.StringMap{
{"a": "1", "b": "2"},
{"a": "1", "b": "2"},
{"a": "1", "b": "2"},
},
queueSizes: []int64{2, 2},
queuePositions: []int64{1, 2},
},
// Jobs that don't match any queried job's daemon should still
// have correct queue positions.
{
name: "irrelevant-daemons-filtered",
jobTags: []database.StringMap{
{"a": "1"},
{"x": "9"},
},
daemonTags: []database.StringMap{
{"a": "1"},
{"x": "9"},
},
queueSizes: []int64{1},
queuePositions: []int64{1},
skipJobIDs: map[int]struct{}{1: {}},
},
}
for _, tc := range testCases {
@@ -4223,51 +4192,6 @@ func TestGetProvisionerJobsByIDsWithQueuePosition_OrderValidation(t *testing.T)
assert.EqualValues(t, []int64{1, 2, 3, 4, 5, 6}, queuePositions, "expected queue positions to be set correctly")
}
func TestGetProvisionerJobsByIDsWithQueuePosition_DuplicateDaemons(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
now := dbtime.Now()
ctx := testutil.Context(t, testutil.WaitShort)
// Create 3 pending jobs with the same tags.
jobs := make([]database.ProvisionerJob, 3)
for i := range jobs {
jobs[i] = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
CreatedAt: now.Add(-time.Duration(3-i) * time.Minute),
Tags: database.StringMap{"scope": "organization", "owner": ""},
})
}
// Create 50 daemons with identical tags (simulates scale).
for i := range 50 {
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
Name: fmt.Sprintf("daemon_%d", i),
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: database.StringMap{"scope": "organization", "owner": ""},
})
}
jobIDs := make([]uuid.UUID, len(jobs))
for i, j := range jobs {
jobIDs[i] = j.ID
}
results, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx,
database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: jobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
require.NoError(t, err)
require.Len(t, results, 3)
// All daemons have identical tags, so queue should be same as
// if there were just one daemon.
for i, r := range results {
assert.Equal(t, int64(3), r.QueueSize, "job %d queue size", i)
assert.Equal(t, int64(i+1), r.QueuePosition, "job %d queue position", i)
}
}
func TestGroupRemovalTrigger(t *testing.T) {
t.Parallel()
@@ -8489,7 +8413,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
})
// Agent should still authenticate during stop build execution.
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent.AuthToken)
require.NoError(t, err, "agent should authenticate during stop build execution")
require.Equal(t, agent.ID, row.WorkspaceAgent.ID)
require.Equal(t, startBuild.ID, row.WorkspaceBuild.ID, "should return start build, not stop build")
@@ -8547,7 +8471,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
})
// Agent should NOT authenticate after stop job completes.
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent.AuthToken)
require.ErrorIs(t, err, sql.ErrNoRows, "agent should not authenticate after stop job completes")
})
@@ -8601,7 +8525,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
})
// Agent should NOT authenticate (start build failed).
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent.AuthToken)
require.ErrorIs(t, err, sql.ErrNoRows, "agent from failed start build should not authenticate")
})
@@ -8656,7 +8580,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
})
// Agent should authenticate during pending stop build.
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent.AuthToken)
require.NoError(t, err, "agent should authenticate during pending stop build")
require.Equal(t, agent.ID, row.WorkspaceAgent.ID)
require.Equal(t, startBuild.ID, row.WorkspaceBuild.ID, "should return start build")
@@ -8753,13 +8677,13 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
})
// Agent from build 3 should authenticate.
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent2.AuthToken)
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent2.AuthToken)
require.NoError(t, err, "agent from most recent start should authenticate during stop")
require.Equal(t, agent2.ID, row.WorkspaceAgent.ID)
require.Equal(t, startBuild2.ID, row.WorkspaceBuild.ID)
// Agent from build 1 should NOT authenticate.
_, err = db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent1.AuthToken)
_, err = db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent1.AuthToken)
require.ErrorIs(t, err, sql.ErrNoRows, "agent from old cycle should not authenticate")
})
@@ -8813,7 +8737,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
})
// Agent from build 1 should NOT authenticate (latest is not STOP).
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent1.AuthToken)
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent1.AuthToken)
require.ErrorIs(t, err, sql.ErrNoRows, "agent should not authenticate when latest build is not STOP")
})
}
@@ -8917,202 +8841,3 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
})
}
}
func TestGetChatMessagesForPromptByChatID(t *testing.T) {
t.Parallel()
// This test exercises a complex CTE query for prompt
// reconstruction after compaction. It requires Postgres.
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
// Helper: create a chat model config (required FK for chats).
user := dbgen.User(t, db, database.User{})
// A chat_providers row is required as a FK for model configs.
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
newChat := func(t *testing.T) database.Chat {
t.Helper()
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "test-chat-" + uuid.NewString(),
})
require.NoError(t, err)
return chat
}
insertMsg := func(
t *testing.T,
chatID uuid.UUID,
role string,
vis database.ChatMessageVisibility,
compressed bool,
content string,
) database.ChatMessage {
t.Helper()
msg, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chatID,
Role: role,
Visibility: vis,
Compressed: sql.NullBool{Bool: compressed, Valid: true},
Content: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`"` + content + `"`),
Valid: true,
},
})
require.NoError(t, err)
return msg
}
msgIDs := func(msgs []database.ChatMessage) []int64 {
ids := make([]int64, len(msgs))
for i, m := range msgs {
ids[i] = m.ID
}
return ids
}
t.Run("NoCompaction", func(t *testing.T) {
t.Parallel()
chat := newChat(t)
sys := insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
usr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "hello")
ast := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "hi there")
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, []int64{sys.ID, usr.ID, ast.ID}, msgIDs(got))
})
t.Run("UserOnlyVisibilityExcluded", func(t *testing.T) {
t.Parallel()
chat := newChat(t)
// Messages with visibility=user should NOT appear in the
// prompt (they are only for the UI).
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityUser, false, "user-only msg")
usr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "hello")
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
for _, m := range got {
require.NotEqual(t, database.ChatMessageVisibilityUser, m.Visibility,
"visibility=user messages should not appear in the prompt")
}
require.Contains(t, msgIDs(got), usr.ID)
})
t.Run("AfterCompaction", func(t *testing.T) {
t.Parallel()
chat := newChat(t)
// Pre-compaction conversation.
sys := insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
preUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "old question")
preAsst := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "old answer")
// Compaction messages:
// 1. Summary (role=user, visibility=model, compressed=true).
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "compaction summary")
// 2. Compressed assistant tool-call (visibility=user).
insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityUser, true, "tool call")
// 3. Compressed tool result (visibility=both).
insertMsg(t, chat.ID, "tool", database.ChatMessageVisibilityBoth, true, "tool result")
// Post-compaction messages.
postUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "new question")
postAsst := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "new answer")
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
gotIDs := msgIDs(got)
// Must include: system prompt, summary, post-compaction.
require.Contains(t, gotIDs, sys.ID, "system prompt must be included")
require.Contains(t, gotIDs, summary.ID, "compaction summary must be included")
require.Contains(t, gotIDs, postUser.ID, "post-compaction user msg must be included")
require.Contains(t, gotIDs, postAsst.ID, "post-compaction assistant msg must be included")
// Must exclude: pre-compaction non-system messages.
require.NotContains(t, gotIDs, preUser.ID, "pre-compaction user msg must be excluded")
require.NotContains(t, gotIDs, preAsst.ID, "pre-compaction assistant msg must be excluded")
// Verify ordering.
require.Equal(t, []int64{sys.ID, summary.ID, postUser.ID, postAsst.ID}, gotIDs)
})
t.Run("AfterCompactionSummaryIsUserRole", func(t *testing.T) {
t.Parallel()
chat := newChat(t)
// After compaction the summary must appear as role=user so
// that LLM APIs (e.g. Anthropic) see at least one
// non-system message in the prompt.
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "summary text")
newUsr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "new question")
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
hasNonSystem := false
for _, m := range got {
if m.Role != "system" {
hasNonSystem = true
break
}
}
require.True(t, hasNonSystem,
"prompt must contain at least one non-system message after compaction")
require.Contains(t, msgIDs(got), summary.ID)
require.Contains(t, msgIDs(got), newUsr.ID)
})
t.Run("CompressedToolResultNotPickedAsSummary", func(t *testing.T) {
t.Parallel()
chat := newChat(t)
// The CTE uses visibility='model' (exact match). If it
// used IN ('model','both'), the compressed tool result
// (visibility=both) would be picked as the "summary"
// instead of the actual summary.
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "real summary")
compressedTool := insertMsg(t, chat.ID, "tool", database.ChatMessageVisibilityBoth, true, "tool result")
postUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "follow-up")
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
gotIDs := msgIDs(got)
require.Contains(t, gotIDs, summary.ID, "real summary must be included")
require.NotContains(t, gotIDs, compressedTool.ID,
"compressed tool result must not be included")
require.Contains(t, gotIDs, postUser.ID)
})
}
+21 -143
View File
@@ -378,7 +378,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti
const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one
SELECT
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
FROM
aibridge_interceptions
WHERE
@@ -400,7 +400,6 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU
&i.Client,
&i.ThreadParentID,
&i.ThreadRootID,
&i.ClientSessionID,
)
return i, err
}
@@ -435,7 +434,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Cont
const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many
SELECT
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
FROM
aibridge_interceptions
`
@@ -461,7 +460,6 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn
&i.Client,
&i.ThreadParentID,
&i.ThreadRootID,
&i.ClientSessionID,
); err != nil {
return nil, err
}
@@ -608,11 +606,11 @@ func (q *sqlQuerier) GetAIBridgeUserPromptsByInterceptionID(ctx context.Context,
const insertAIBridgeInterception = `-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, thread_parent_id, thread_root_id
) VALUES (
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9, $10::uuid, $11::uuid
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9::uuid, $10::uuid
)
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
`
type InsertAIBridgeInterceptionParams struct {
@@ -624,7 +622,6 @@ type InsertAIBridgeInterceptionParams struct {
Metadata json.RawMessage `db:"metadata" json:"metadata"`
StartedAt time.Time `db:"started_at" json:"started_at"`
Client sql.NullString `db:"client" json:"client"`
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
ThreadParentInterceptionID uuid.NullUUID `db:"thread_parent_interception_id" json:"thread_parent_interception_id"`
ThreadRootInterceptionID uuid.NullUUID `db:"thread_root_interception_id" json:"thread_root_interception_id"`
}
@@ -639,7 +636,6 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
arg.Metadata,
arg.StartedAt,
arg.Client,
arg.ClientSessionID,
arg.ThreadParentInterceptionID,
arg.ThreadRootInterceptionID,
)
@@ -656,7 +652,6 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
&i.Client,
&i.ThreadParentID,
&i.ThreadRootID,
&i.ClientSessionID,
)
return i, err
}
@@ -798,7 +793,7 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB
const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many
SELECT
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id,
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id,
visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url
FROM
aibridge_interceptions
@@ -909,7 +904,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
&i.AIBridgeInterception.Client,
&i.AIBridgeInterception.ThreadParentID,
&i.AIBridgeInterception.ThreadRootID,
&i.AIBridgeInterception.ClientSessionID,
&i.VisibleUser.ID,
&i.VisibleUser.Username,
&i.VisibleUser.Name,
@@ -1170,7 +1164,7 @@ UPDATE aibridge_interceptions
WHERE
id = $2::uuid
AND ended_at IS NULL
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
`
type UpdateAIBridgeInterceptionEndedParams struct {
@@ -1193,7 +1187,6 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up
&i.Client,
&i.ThreadParentID,
&i.ThreadRootID,
&i.ClientSessionID,
)
return i, err
}
@@ -2214,103 +2207,6 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou
return new_period, err
}
const getChatFileByID = `-- name: GetChatFileByID :one
SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = $1::uuid
`
func (q *sqlQuerier) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) {
row := q.db.QueryRowContext(ctx, getChatFileByID, id)
var i ChatFile
err := row.Scan(
&i.ID,
&i.OwnerID,
&i.OrganizationID,
&i.CreatedAt,
&i.Name,
&i.Mimetype,
&i.Data,
)
return i, err
}
const getChatFilesByIDs = `-- name: GetChatFilesByIDs :many
SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = ANY($1::uuid[])
`
func (q *sqlQuerier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) {
rows, err := q.db.QueryContext(ctx, getChatFilesByIDs, pq.Array(ids))
if err != nil {
return nil, err
}
defer rows.Close()
var items []ChatFile
for rows.Next() {
var i ChatFile
if err := rows.Scan(
&i.ID,
&i.OwnerID,
&i.OrganizationID,
&i.CreatedAt,
&i.Name,
&i.Mimetype,
&i.Data,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertChatFile = `-- name: InsertChatFile :one
INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data)
VALUES ($1::uuid, $2::uuid, $3::text, $4::text, $5::bytea)
RETURNING id, owner_id, organization_id, created_at, name, mimetype
`
type InsertChatFileParams struct {
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
Name string `db:"name" json:"name"`
Mimetype string `db:"mimetype" json:"mimetype"`
Data []byte `db:"data" json:"data"`
}
type InsertChatFileRow struct {
ID uuid.UUID `db:"id" json:"id"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
Name string `db:"name" json:"name"`
Mimetype string `db:"mimetype" json:"mimetype"`
}
func (q *sqlQuerier) InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) {
row := q.db.QueryRowContext(ctx, insertChatFile,
arg.OwnerID,
arg.OrganizationID,
arg.Name,
arg.Mimetype,
arg.Data,
)
var i InsertChatFileRow
err := row.Scan(
&i.ID,
&i.OwnerID,
&i.OrganizationID,
&i.CreatedAt,
&i.Name,
&i.Mimetype,
)
return i, err
}
const deleteChatModelConfigByID = `-- name: DeleteChatModelConfigByID :exec
UPDATE
chat_model_configs
@@ -3321,8 +3217,9 @@ WITH latest_compressed_summary AS (
chat_messages
WHERE
chat_id = $1::uuid
AND role = 'system'
AND visibility IN ('model', 'both')
AND compressed = TRUE
AND visibility = 'model'
ORDER BY
created_at DESC,
id DESC
@@ -12654,7 +12551,7 @@ const getProvisionerJobsByIDsWithQueuePosition = `-- name: GetProvisionerJobsByI
WITH filtered_provisioner_jobs AS (
-- Step 1: Filter provisioner_jobs
SELECT
id, created_at, tags
id, created_at
FROM
provisioner_jobs
WHERE
@@ -12669,32 +12566,21 @@ pending_jobs AS (
WHERE
job_status = 'pending'
),
unique_daemon_tags AS (
SELECT DISTINCT tags FROM provisioner_daemons pd
WHERE pd.last_seen_at IS NOT NULL
AND pd.last_seen_at >= (NOW() - ($2::bigint || ' ms')::interval)
),
relevant_daemon_tags AS (
SELECT udt.tags
FROM unique_daemon_tags udt
WHERE EXISTS (
SELECT 1 FROM filtered_provisioner_jobs fpj
WHERE provisioner_tagset_contains(udt.tags, fpj.tags)
)
online_provisioner_daemons AS (
SELECT id, tags FROM provisioner_daemons pd
WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - ($2::bigint || ' ms')::interval)
),
ranked_jobs AS (
-- Step 3: Rank only pending jobs based on provisioner availability
SELECT
pj.id,
pj.created_at,
ROW_NUMBER() OVER (PARTITION BY rdt.tags ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
COUNT(*) OVER (PARTITION BY rdt.tags) AS queue_size
ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
COUNT(*) OVER (PARTITION BY opd.id) AS queue_size
FROM
pending_jobs pj
INNER JOIN
relevant_daemon_tags rdt
ON
provisioner_tagset_contains(rdt.tags, pj.tags)
INNER JOIN online_provisioner_daemons opd
ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set
),
final_jobs AS (
-- Step 4: Compute best queue position and max queue size per job
@@ -15310,7 +15196,7 @@ func (q *sqlQuerier) DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid
}
const getTaskByID = `-- name: GetTaskByID :one
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, workspace_group_acl, workspace_user_acl, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE id = $1::uuid
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE id = $1::uuid
`
func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error) {
@@ -15328,8 +15214,6 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error
&i.CreatedAt,
&i.DeletedAt,
&i.DisplayName,
&i.WorkspaceGroupACL,
&i.WorkspaceUserACL,
&i.Status,
&i.StatusDebug,
&i.WorkspaceBuildNumber,
@@ -15345,7 +15229,7 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error
}
const getTaskByOwnerIDAndName = `-- name: GetTaskByOwnerIDAndName :one
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, workspace_group_acl, workspace_user_acl, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status
WHERE
owner_id = $1::uuid
AND deleted_at IS NULL
@@ -15372,8 +15256,6 @@ func (q *sqlQuerier) GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByO
&i.CreatedAt,
&i.DeletedAt,
&i.DisplayName,
&i.WorkspaceGroupACL,
&i.WorkspaceUserACL,
&i.Status,
&i.StatusDebug,
&i.WorkspaceBuildNumber,
@@ -15389,7 +15271,7 @@ func (q *sqlQuerier) GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByO
}
const getTaskByWorkspaceID = `-- name: GetTaskByWorkspaceID :one
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, workspace_group_acl, workspace_user_acl, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid
`
func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error) {
@@ -15407,8 +15289,6 @@ func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.
&i.CreatedAt,
&i.DeletedAt,
&i.DisplayName,
&i.WorkspaceGroupACL,
&i.WorkspaceUserACL,
&i.Status,
&i.StatusDebug,
&i.WorkspaceBuildNumber,
@@ -15688,7 +15568,7 @@ func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (Task
}
const listTasks = `-- name: ListTasks :many
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, workspace_group_acl, workspace_user_acl, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status tws
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status tws
WHERE tws.deleted_at IS NULL
AND CASE WHEN $1::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.owner_id = $1::UUID ELSE TRUE END
AND CASE WHEN $2::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.organization_id = $2::UUID ELSE TRUE END
@@ -15723,8 +15603,6 @@ func (q *sqlQuerier) ListTasks(ctx context.Context, arg ListTasksParams) ([]Task
&i.CreatedAt,
&i.DeletedAt,
&i.DisplayName,
&i.WorkspaceGroupACL,
&i.WorkspaceUserACL,
&i.Status,
&i.StatusDebug,
&i.WorkspaceBuildNumber,
+2 -2
View File
@@ -1,8 +1,8 @@
-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, thread_parent_id, thread_root_id
) VALUES (
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
)
RETURNING *;
-10
View File
@@ -1,10 +0,0 @@
-- name: InsertChatFile :one
INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data)
VALUES (@owner_id::uuid, @organization_id::uuid, @name::text, @mimetype::text, @data::bytea)
RETURNING id, owner_id, organization_id, created_at, name, mimetype;
-- name: GetChatFileByID :one
SELECT * FROM chat_files WHERE id = @id::uuid;
-- name: GetChatFilesByIDs :many
SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]);
+2 -1
View File
@@ -54,8 +54,9 @@ WITH latest_compressed_summary AS (
chat_messages
WHERE
chat_id = @chat_id::uuid
AND role = 'system'
AND visibility IN ('model', 'both')
AND compressed = TRUE
AND visibility = 'model'
ORDER BY
created_at DESC,
id DESC
+8 -19
View File
@@ -79,7 +79,7 @@ WHERE
WITH filtered_provisioner_jobs AS (
-- Step 1: Filter provisioner_jobs
SELECT
id, created_at, tags
id, created_at
FROM
provisioner_jobs
WHERE
@@ -94,32 +94,21 @@ pending_jobs AS (
WHERE
job_status = 'pending'
),
unique_daemon_tags AS (
SELECT DISTINCT tags FROM provisioner_daemons pd
WHERE pd.last_seen_at IS NOT NULL
AND pd.last_seen_at >= (NOW() - (@stale_interval_ms::bigint || ' ms')::interval)
),
relevant_daemon_tags AS (
SELECT udt.tags
FROM unique_daemon_tags udt
WHERE EXISTS (
SELECT 1 FROM filtered_provisioner_jobs fpj
WHERE provisioner_tagset_contains(udt.tags, fpj.tags)
)
online_provisioner_daemons AS (
SELECT id, tags FROM provisioner_daemons pd
WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - (@stale_interval_ms::bigint || ' ms')::interval)
),
ranked_jobs AS (
-- Step 3: Rank only pending jobs based on provisioner availability
SELECT
pj.id,
pj.created_at,
ROW_NUMBER() OVER (PARTITION BY rdt.tags ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
COUNT(*) OVER (PARTITION BY rdt.tags) AS queue_size
ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
COUNT(*) OVER (PARTITION BY opd.id) AS queue_size
FROM
pending_jobs pj
INNER JOIN
relevant_daemon_tags rdt
ON
provisioner_tagset_contains(rdt.tags, pj.tags)
INNER JOIN online_provisioner_daemons opd
ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set
),
final_jobs AS (
-- Step 4: Compute best queue position and max queue size per job
-8
View File
@@ -82,12 +82,6 @@ sql:
- column: "template_usage_stats.app_usage_mins"
go_type:
type: "StringMapOfInt"
- column: "tasks_with_status.workspace_user_acl"
go_type:
type: "WorkspaceACL"
- column: "tasks_with_status.workspace_group_acl"
go_type:
type: "WorkspaceACL"
- column: "workspaces.user_acl"
go_type:
type: "WorkspaceACL"
@@ -192,8 +186,6 @@ sql:
jwt: JWT
user_acl: UserACL
group_acl: GroupACL
workspace_user_acl: WorkspaceUserACL
workspace_group_acl: WorkspaceGroupACL
user_acl_display_info: UserACLDisplayInfo
group_acl_display_info: GroupACLDisplayInfo
troubleshooting_url: TroubleshootingURL
-1
View File
@@ -15,7 +15,6 @@ const (
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
+2 -1
View File
@@ -24,6 +24,7 @@ import (
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/externalauth"
@@ -336,7 +337,7 @@ func TestRefreshToken(t *testing.T) {
require.Equal(t, 1, validateCalls, "token is validated")
require.Equal(t, 1, refreshCalls, "token is refreshed")
require.NotEqualf(t, link.OAuthAccessToken, updated.OAuthAccessToken, "token is updated")
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
dbLink, err := db.GetExternalAuthLink(dbauthz.AsSystemRestricted(context.Background()), database.GetExternalAuthLinkParams{
ProviderID: link.ProviderID,
UserID: link.UserID,
})
+194 -391
View File
@@ -30,57 +30,7 @@ import (
"github.com/coder/coder/v2/codersdk"
)
type (
apiKeyContextKey struct{}
apiKeyPrecheckedContextKey struct{}
)
// ValidateAPIKeyConfig holds the settings needed for API key
// validation at the top of the request lifecycle. Unlike
// ExtractAPIKeyConfig it omits route-specific fields
// (RedirectToLogin, Optional, ActivateDormantUser, etc.).
type ValidateAPIKeyConfig struct {
DB database.Store
OAuth2Configs *OAuth2Configs
DisableSessionExpiryRefresh bool
// SessionTokenFunc overrides how the API token is extracted
// from the request. Nil uses the default (cookie/header).
SessionTokenFunc func(*http.Request) string
Logger slog.Logger
}
// ValidateAPIKeyResult is the outcome of successful validation.
type ValidateAPIKeyResult struct {
Key database.APIKey
Subject rbac.Subject
UserStatus database.UserStatus
}
// ValidateAPIKeyError represents a validation failure with enough
// context for downstream middlewares to decide how to respond.
type ValidateAPIKeyError struct {
Code int
Response codersdk.Response
// Hard is true for server errors and active failures (5xx,
// OAuth refresh failures) that must be surfaced even on
// optional-auth routes. Soft errors (missing/expired token)
// may be swallowed on optional routes.
Hard bool
}
func (e *ValidateAPIKeyError) Error() string {
return e.Response.Message
}
// APIKeyPrechecked stores the result of top-level API key
// validation performed by PrecheckAPIKey. It distinguishes
// two states:
// - Validation failed (including no token): Result == nil && Err != nil
// - Validation passed: Result != nil && Err == nil
type APIKeyPrechecked struct {
Result *ValidateAPIKeyResult
Err *ValidateAPIKeyError
}
type apiKeyContextKey struct{}
// APIKeyOptional may return an API key from the ExtractAPIKey handler.
func APIKeyOptional(r *http.Request) (database.APIKey, bool) {
@@ -199,298 +149,6 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
}
}
// PrecheckAPIKey extracts and fully validates the API key on every
// request (if present) and stores the result in context. It never
// writes error responses and always calls next.
//
// The rate limiter reads the stored result to key by user ID and
// check the Owner bypass header. Downstream ExtractAPIKeyMW reads
// it to avoid redundant DB lookups and validation.
func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Already prechecked (shouldn't happen, but guard).
if _, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
next.ServeHTTP(rw, r)
return
}
result, valErr := ValidateAPIKey(ctx, cfg, r)
prechecked := APIKeyPrechecked{
Result: result,
Err: valErr,
}
ctx = context.WithValue(ctx, apiKeyPrecheckedContextKey{}, prechecked)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}
// ValidateAPIKey extracts and validates the API key from the
// request. It performs all security-critical checks:
// - Token extraction and parsing
// - Database lookup + secret hash validation
// - Expiry check
// - OIDC/OAuth token refresh (if applicable)
// - API key LastUsed / ExpiresAt DB updates
// - User role lookup (UserRBACSubject)
//
// It does NOT:
// - Write HTTP error responses
// - Activate dormant users (route-specific)
// - Redirect to login (route-specific)
// - Check OAuth2 audience (route-specific, depends on AccessURL)
// - Set PostAuth headers (route-specific)
// - Check user active status (route-specific, depends on dormant activation)
//
// Returns (result, nil) on success or (nil, error) on failure.
func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) {
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if !ok {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: resp,
}
}
// Log the API key ID for all requests that have a valid key
// format and secret, regardless of whether subsequent validation
// (expiry, user status, etc.) succeeds.
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
rl.WithFields(slog.F("api_key_id", key.ID))
}
now := dbtime.Now()
if key.ExpiresAt.Before(now) {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
},
}
}
// Refresh OIDC/GitHub tokens if applicable.
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
if errors.Is(err, sql.ErrNoRows) {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "You must re-authenticate with the login provider.",
},
}
}
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: "A database error occurred",
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
},
Hard: true,
}
}
// Check if the OAuth token is expired.
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
if cfg.OAuth2Configs.IsZero() {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
},
Hard: true,
}
}
var friendlyName string
var oauthConfig promoauth.OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = cfg.OAuth2Configs.Github
friendlyName = "GitHub"
case database.LoginTypeOIDC:
oauthConfig = cfg.OAuth2Configs.OIDC
friendlyName = "OpenID Connect"
default:
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
},
Hard: true,
}
}
if oauthConfig == nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
},
Hard: true,
}
}
// Soft error: session expired naturally with no
// refresh token. Optional-auth routes treat this as
// unauthenticated.
if link.OAuthRefreshToken == "" {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
},
}
}
// We have a refresh token, so let's try it.
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
Expiry: link.OAuthExpiry,
}).Token()
// Hard error: we actively tried to refresh and the
// provider rejected it — surface even on optional-auth
// routes.
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: fmt.Sprintf(
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
friendlyName),
Detail: err.Error(),
},
Hard: true,
}
}
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
},
Hard: true,
}
}
}
}
// Update LastUsed and session expiry.
changed := false
if now.Sub(key.LastUsed) > time.Hour {
key.LastUsed = now
remoteIP := net.ParseIP(r.RemoteAddr)
if remoteIP == nil {
remoteIP = net.IPv4(0, 0, 0, 0)
}
bitlen := len(remoteIP) * 8
key.IPAddress = pqtype.Inet{
IPNet: net.IPNet{
IP: remoteIP,
Mask: net.CIDRMask(bitlen, bitlen),
},
Valid: true,
}
changed = true
}
if !cfg.DisableSessionExpiryRefresh {
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
key.ExpiresAt = now.Add(apiKeyLifetime)
changed = true
}
}
if changed {
//nolint:gocritic // System needs to update API Key LastUsed
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
IPAddress: key.IPAddress,
})
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
},
Hard: true,
}
}
//nolint:gocritic // system needs to update user last seen at
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
ID: key.UserID,
LastSeenAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
})
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
},
Hard: true,
}
}
}
// Fetch user roles.
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
},
Hard: true,
}
}
return &ValidateAPIKeyResult{
Key: *key,
Subject: actor,
UserStatus: userStatus,
}, nil
}
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
tokenFunc := APITokenFromRequest
if sessionTokenFunc != nil {
@@ -582,60 +240,29 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return nil, nil, false
}
// --- Consume prechecked result if available ---
// Skip prechecked data when cfg has a custom SessionTokenFunc,
// because the precheck used the default token extraction and may
// have validated a different token (e.g. workspace app token
// issuance in workspaceapps/db.go).
var key *database.APIKey
var actor rbac.Subject
var userStatus database.UserStatus
var skipValidation bool
if cfg.SessionTokenFunc == nil {
if pc, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
if pc.Err != nil {
// Validation failed at the top level (includes
// "no token provided").
if pc.Err.Hard {
return write(pc.Err.Code, pc.Err.Response)
}
return optionalWrite(pc.Err.Code, pc.Err.Response)
}
// Valid — use prechecked data, skip to route-specific logic.
key = &pc.Result.Key
actor = pc.Result.Subject
userStatus = pc.Result.UserStatus
skipValidation = true
}
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if !ok {
return optionalWrite(http.StatusUnauthorized, resp)
}
if !skipValidation {
// Full validation path (no prechecked result or custom token func).
result, valErr := ValidateAPIKey(ctx, ValidateAPIKeyConfig{
DB: cfg.DB,
OAuth2Configs: cfg.OAuth2Configs,
DisableSessionExpiryRefresh: cfg.DisableSessionExpiryRefresh,
SessionTokenFunc: cfg.SessionTokenFunc,
Logger: cfg.Logger,
}, r)
if valErr != nil {
if valErr.Hard {
return write(valErr.Code, valErr.Response)
}
return optionalWrite(valErr.Code, valErr.Response)
}
key = &result.Key
actor = result.Subject
userStatus = result.UserStatus
// Log the API key ID for all requests that have a valid key format and secret,
// regardless of whether subsequent validation (expiry, user status, etc.) succeeds.
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
rl.WithFields(slog.F("api_key_id", key.ID))
}
// --- Route-specific logic (always runs) ---
now := dbtime.Now()
if key.ExpiresAt.Before(now) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
}
// Validate OAuth2 provider app token audience (RFC 8707) if applicable.
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
// Log the detailed error for debugging but don't expose it to the client.
// Log the detailed error for debugging but don't expose it to the client
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
return optionalWrite(http.StatusForbidden, codersdk.Response{
Message: "Token audience validation failed",
@@ -643,7 +270,183 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
}
}
// Dormant activation (config-dependent).
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
// refreshing the OIDC token.
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
var err error
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
if errors.Is(err, sql.ErrNoRows) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "You must re-authenticate with the login provider.",
})
}
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: "A database error occurred",
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
})
}
// Check if the OAuth token is expired
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
if cfg.OAuth2Configs.IsZero() {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
})
}
var friendlyName string
var oauthConfig promoauth.OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = cfg.OAuth2Configs.Github
friendlyName = "GitHub"
case database.LoginTypeOIDC:
oauthConfig = cfg.OAuth2Configs.OIDC
friendlyName = "OpenID Connect"
default:
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
})
}
// It's possible for cfg.OAuth2Configs to be non-nil, but still
// missing this type. For example, if a user logged in with GitHub,
// but the administrator later removed GitHub and replaced it with
// OIDC.
if oauthConfig == nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
})
}
if link.OAuthRefreshToken == "" {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
})
}
// We have a refresh token, so let's try it
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
Expiry: link.OAuthExpiry,
}).Token()
if err != nil {
return write(http.StatusUnauthorized, codersdk.Response{
Message: fmt.Sprintf(
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
friendlyName),
Detail: err.Error(),
})
}
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
}
}
}
// Tracks if the API key has properties updated
changed := false
// Only update LastUsed once an hour to prevent database spam.
if now.Sub(key.LastUsed) > time.Hour {
key.LastUsed = now
remoteIP := net.ParseIP(r.RemoteAddr)
if remoteIP == nil {
remoteIP = net.IPv4(0, 0, 0, 0)
}
bitlen := len(remoteIP) * 8
key.IPAddress = pqtype.Inet{
IPNet: net.IPNet{
IP: remoteIP,
Mask: net.CIDRMask(bitlen, bitlen),
},
Valid: true,
}
changed = true
}
// Only update the ExpiresAt once an hour to prevent database spam.
// We extend the ExpiresAt to reduce re-authentication.
if !cfg.DisableSessionExpiryRefresh {
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
key.ExpiresAt = now.Add(apiKeyLifetime)
changed = true
}
}
if changed {
//nolint:gocritic // System needs to update API Key LastUsed
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
IPAddress: key.IPAddress,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
})
}
// We only want to update this occasionally to reduce DB write
// load. We update alongside the UserLink and APIKey since it's
// easier on the DB to colocate writes.
//nolint:gocritic // system needs to update user last seen at
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
ID: key.UserID,
LastSeenAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
})
}
}
// If the key is valid, we also fetch the user roles and status.
// The roles are used for RBAC authorize checks, and the status
// is to block 'suspended' users from accessing the platform.
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
if err != nil {
return write(http.StatusUnauthorized, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
})
}
if userStatus == database.UserStatusDormant && cfg.ActivateDormantUser != nil {
id, _ := uuid.Parse(actor.ID)
user, err := cfg.ActivateDormantUser(ctx, database.User{
+15 -36
View File
@@ -32,56 +32,35 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler
count,
window,
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
// Identify the caller. We check two sources:
//
// 1. apiKeyPrecheckedContextKey — set by PrecheckAPIKey
// at the root of the router. Only fully validated
// keys are used.
// 2. apiKeyContextKey — set by ExtractAPIKeyMW if it
// has already run (e.g. unit tests, workspace-app
// routes that don't go through PrecheckAPIKey).
//
// If neither is present, fall back to IP.
var userID string
var subject *rbac.Subject
if pc, ok := r.Context().Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok && pc.Result != nil {
userID = pc.Result.Key.UserID.String()
subject = &pc.Result.Subject
} else if ak, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey); ok {
userID = ak.UserID.String()
if auth, ok := UserAuthorizationOptional(r.Context()); ok {
subject = &auth
}
} else {
// Prioritize by user, but fallback to IP.
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
if !ok {
return httprate.KeyByIP(r)
}
if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok {
// No bypass attempt, just rate limit by user.
return userID, nil
// No bypass attempt, just ratelimit.
return apiKey.UserID.String(), nil
}
// Allow Owner to bypass rate limiting for load tests
// and automation. We avoid using rbac.Authorizer since
// rego is CPU-intensive and undermines the
// DoS-prevention goal of the rate limiter.
if subject == nil {
// Can't verify roles — rate limit normally.
return userID, nil
}
for _, role := range subject.SafeRoleNames() {
// and automation.
auth := UserAuthorization(r.Context())
// We avoid using rbac.Authorizer since rego is CPU-intensive
// and undermines the DoS-prevention goal of the rate limiter.
for _, role := range auth.SafeRoleNames() {
if role == rbac.RoleOwner() {
// HACK: use a random key each time to
// de facto disable rate limiting. The
// httprate package has no support for
// selectively changing the limit for
// particular keys.
// `httprate` package has no
// support for selectively changing the limit
// for particular keys.
return cryptorand.String(16)
}
}
return userID, xerrors.Errorf(
return apiKey.UserID.String(), xerrors.Errorf(
"%q provided but user is not %v",
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
)
+3 -1
View File
@@ -15,6 +15,7 @@ import (
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/idpsync"
@@ -356,7 +357,7 @@ func TestGroupSyncTable(t *testing.T) {
},
}
defOrg, err := db.GetDefaultOrganization(ctx)
defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
require.NoError(t, err)
SetupOrganization(t, s, db, user, defOrg.ID, def)
asserts = append(asserts, func(t *testing.T) {
@@ -554,6 +555,7 @@ func TestApplyGroupDifference(t *testing.T) {
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
ctx = dbauthz.AsSystemRestricted(ctx)
org := dbgen.Organization(t, db, database.Organization{})
_, err := db.InsertAllUsersGroup(ctx, org.ID)
+2 -1
View File
@@ -13,6 +13,7 @@ import (
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
@@ -272,7 +273,7 @@ func TestRoleSyncTable(t *testing.T) {
}
// Also assert site wide roles
allRoles, err := db.GetAuthorizationUserRoles(ctx, user.ID)
allRoles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), user.ID)
require.NoError(t, err)
allRoleIDs, err := allRoles.RoleNames()
+5 -4
View File
@@ -14,6 +14,7 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/notifications"
@@ -29,6 +30,7 @@ func TestBufferedUpdates(t *testing.T) {
// setup
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
store, ps := dbtestutil.NewDB(t)
logger := testutil.Logger(t)
@@ -55,7 +57,6 @@ func TestBufferedUpdates(t *testing.T) {
user := dbgen.User(t, store, database.User{})
// WHEN: notifications are enqueued which should succeed and fail
ctx := testutil.Context(t, testutil.WaitSuperLong)
_, err = enq.Enqueue(ctx, user.ID, notifications.TemplateWorkspaceDeleted, map[string]string{"nice": "true", "i": "0"}, "") // Will succeed.
require.NoError(t, err)
_, err = enq.Enqueue(ctx, user.ID, notifications.TemplateWorkspaceDeleted, map[string]string{"nice": "true", "i": "1"}, "") // Will succeed.
@@ -105,6 +106,7 @@ func TestBuildPayload(t *testing.T) {
// SETUP
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
store, _ := dbtestutil.NewDB(t)
logger := testutil.Logger(t)
@@ -144,7 +146,6 @@ func TestBuildPayload(t *testing.T) {
require.NoError(t, err)
// WHEN: a notification is enqueued
ctx := testutil.Context(t, testutil.WaitSuperLong)
_, err = enq.Enqueue(ctx, uuid.New(), notifications.TemplateWorkspaceDeleted, map[string]string{
"name": "my-workspace",
}, "test")
@@ -162,6 +163,7 @@ func TestStopBeforeRun(t *testing.T) {
// SETUP
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
store, ps := dbtestutil.NewDB(t)
logger := testutil.Logger(t)
@@ -170,7 +172,6 @@ func TestStopBeforeRun(t *testing.T) {
require.NoError(t, err)
// THEN: validate that the manager can be stopped safely without Run() having been called yet
ctx := testutil.Context(t, testutil.WaitSuperLong)
require.Eventually(t, func() bool {
assert.NoError(t, mgr.Stop(ctx))
return true
@@ -182,6 +183,7 @@ func TestRunStopRace(t *testing.T) {
// SETUP
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
store, ps := dbtestutil.NewDB(t)
logger := testutil.Logger(t)
@@ -192,7 +194,6 @@ func TestRunStopRace(t *testing.T) {
// Start Run and Stop after each other (run does "go loop()").
// This is to catch a (now fixed) race condition where the manager
// would be accessed/stopped while it was being created/starting up.
ctx := testutil.Context(t, testutil.WaitMedium)
mgr.Run(ctx)
err = mgr.Stop(ctx)
require.NoError(t, err)
+5 -5
View File
@@ -18,6 +18,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/dispatch"
@@ -32,6 +33,7 @@ func TestMetrics(t *testing.T) {
// SETUP
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
store, pubsub := dbtestutil.NewDB(t)
logger := testutil.Logger(t)
@@ -55,7 +57,6 @@ func TestMetrics(t *testing.T) {
mgr, err := notifications.NewManager(cfg, store, pubsub, defaultHelpers(), metrics, logger.Named("manager"))
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitSuperLong)
t.Cleanup(func() {
assert.NoError(t, mgr.Stop(ctx))
})
@@ -220,6 +221,7 @@ func TestPendingUpdatesMetric(t *testing.T) {
t.Parallel()
// SETUP
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
store, pubsub := dbtestutil.NewDB(t)
logger := testutil.Logger(t)
@@ -245,7 +247,6 @@ func TestPendingUpdatesMetric(t *testing.T) {
mgr, err := notifications.NewManager(cfg, interceptor, pubsub, defaultHelpers(), metrics, logger.Named("manager"),
notifications.WithTestClock(mClock))
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitSuperLong)
t.Cleanup(func() {
assert.NoError(t, mgr.Stop(ctx))
})
@@ -313,6 +314,7 @@ func TestInflightDispatchesMetric(t *testing.T) {
t.Parallel()
// SETUP
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
store, pubsub := dbtestutil.NewDB(t)
logger := testutil.Logger(t)
@@ -331,7 +333,6 @@ func TestInflightDispatchesMetric(t *testing.T) {
mgr, err := notifications.NewManager(cfg, store, pubsub, defaultHelpers(), metrics, logger.Named("manager"))
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitSuperLong)
t.Cleanup(func() {
assert.NoError(t, mgr.Stop(ctx))
})
@@ -385,6 +386,7 @@ func TestInflightDispatchesMetric(t *testing.T) {
func TestCustomMethodMetricCollection(t *testing.T) {
t.Parallel()
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
store, pubsub := dbtestutil.NewDB(t)
logger := testutil.Logger(t)
@@ -400,8 +402,6 @@ func TestCustomMethodMetricCollection(t *testing.T) {
defaultMethod = database.NotificationMethodSmtp
)
ctx := testutil.Context(t, testutil.WaitSuperLong)
// GIVEN: a template whose notification method differs from the default.
out, err := store.UpdateNotificationTemplateMethodByID(ctx, database.UpdateNotificationTemplateMethodByIDParams{
ID: tmpl,
+2 -2
View File
@@ -1472,12 +1472,12 @@ func TestNotificationTemplates_Golden(t *testing.T) {
// as appearance changes are enterprise features and we do not want to mix those
// can't use the api
if tc.appName != "" {
err = (*db).UpsertApplicationName(ctx, "Custom Application")
err = (*db).UpsertApplicationName(dbauthz.AsSystemRestricted(ctx), "Custom Application")
require.NoError(t, err)
}
if tc.logoURL != "" {
err = (*db).UpsertLogoURL(ctx, "https://custom.application/logo.png")
err = (*db).UpsertLogoURL(dbauthz.AsSystemRestricted(ctx), "https://custom.application/logo.png")
require.NoError(t, err)
}
+1 -1
View File
@@ -516,7 +516,7 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
} else {
require.NoError(t, err)
require.NotEmpty(t, token.AccessToken)
require.True(t, dbtime.Now().Before(token.Expiry))
require.True(t, time.Now().Before(token.Expiry))
// Check that the token works.
newClient := codersdk.New(userClient.URL)
@@ -21,6 +21,7 @@ import (
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/prometheusmetrics/insights"
@@ -126,7 +127,7 @@ func TestCollectInsights(t *testing.T) {
AppStatBatchSize: workspaceapps.DefaultStatsDBReporterBatchSize,
})
refTime := time.Now().Add(-3 * time.Minute).Truncate(time.Minute)
err = reporter.ReportAppStats(context.Background(), []workspaceapps.StatsReport{
err = reporter.ReportAppStats(dbauthz.AsSystemRestricted(context.Background()), []workspaceapps.StatsReport{
{
UserID: user.ID,
WorkspaceID: workspace1.ID,
+10 -16
View File
@@ -564,7 +564,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
// The check `s.OIDCConfig != nil` is not as strict, since it can be an interface
// pointing to a typed nil.
if !reflect.ValueOf(s.OIDCConfig).IsNil() {
workspaceOwnerOIDCAccessToken, err = ObtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
if err != nil {
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
}
@@ -3075,15 +3075,15 @@ func deleteSessionTokenForUserAndWorkspace(ctx context.Context, db database.Stor
return nil
}
func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
func shouldRefreshOIDCToken(link database.UserLink) bool {
if link.OAuthRefreshToken == "" {
// We cannot refresh even if we wanted to
return false, link.OAuthExpiry
return false
}
if link.OAuthExpiry.IsZero() {
// 0 expire means the token never expires, so we shouldn't refresh
return false, link.OAuthExpiry
return false
}
// This handles an edge case where the token is about to expire. A workspace
@@ -3093,19 +3093,15 @@ func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
//
// If an OIDC provider issues short-lived tokens less than our defined period,
// the token will always be refreshed on every workspace build.
//
// By setting the expiration backwards, we are effectively shortening the
// time a token can be alive for by 10 minutes.
// Note: This is how it is done in the oauth2 package's own token refreshing logic.
expiresAt := link.OAuthExpiry.Add(-time.Minute * 10)
assumeExpiredAt := dbtime.Now().Add(-1 * time.Minute * 10)
// Return if the token is assumed to be expired.
return expiresAt.Before(dbtime.Now()), expiresAt
return link.OAuthExpiry.Before(assumeExpiredAt)
}
// ObtainOIDCAccessToken returns a valid OpenID Connect access token
// obtainOIDCAccessToken returns a valid OpenID Connect access token
// for the user if it's able to obtain one, otherwise it returns an empty string.
func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
func obtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: userID,
LoginType: database.LoginTypeOIDC,
@@ -3117,13 +3113,11 @@ func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.
return "", xerrors.Errorf("get owner oidc link: %w", err)
}
if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh {
if shouldRefreshOIDCToken(link) {
token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
// Use the expiresAt returned by shouldRefreshOIDCToken.
// It will force a refresh with an expired time.
Expiry: expiresAt,
Expiry: link.OAuthExpiry,
}).Token()
if err != nil {
// If OIDC fails to refresh, we return an empty string and don't fail.

Some files were not shown because too many files have changed in this diff Show More