Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 603e68cc80 |
@@ -189,8 +189,8 @@ func (q *sqlQuerier) UpdateUser(ctx context.Context, arg UpdateUserParams) (User
|
||||
### Common Debug Commands
|
||||
|
||||
```bash
|
||||
# Run tests (starts Postgres automatically if needed)
|
||||
make test
|
||||
# Check database connection
|
||||
make test-postgres
|
||||
|
||||
# Run specific database tests
|
||||
go test ./coderd/database/... -run TestSpecificFunction
|
||||
|
||||
@@ -67,6 +67,7 @@ coderd/
|
||||
| `make test` | Run all Go tests |
|
||||
| `make test RUN=TestFunctionName` | Run specific test |
|
||||
| `go test -v ./path/to/package -run TestFunctionName` | Run test with verbose output |
|
||||
| `make test-postgres` | Run tests with Postgres database |
|
||||
| `make test-race` | Run tests with Go race detector |
|
||||
| `make test-e2e` | Run end-to-end tests |
|
||||
|
||||
|
||||
@@ -109,6 +109,7 @@
|
||||
|
||||
- Run full test suite: `make test`
|
||||
- Run specific test: `make test RUN=TestFunctionName`
|
||||
- Run with Postgres: `make test-postgres`
|
||||
- Run with race detector: `make test-race`
|
||||
- Run end-to-end tests: `make test-e2e`
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
name: "🐞 Bug"
|
||||
description: "File a bug report."
|
||||
title: "bug: "
|
||||
labels: ["needs-triage"]
|
||||
type: "Bug"
|
||||
body:
|
||||
- type: checkboxes
|
||||
|
||||
@@ -70,7 +70,11 @@ runs:
|
||||
set -euo pipefail
|
||||
|
||||
if [[ ${RACE_DETECTION} == true ]]; then
|
||||
make test-race
|
||||
gotestsum --junitfile="gotests.xml" --packages="${TEST_PACKAGES}" -- \
|
||||
-tags=testsmallbatch \
|
||||
-race \
|
||||
-parallel "${TEST_NUM_PARALLEL_TESTS}" \
|
||||
-p "${TEST_NUM_PARALLEL_PACKAGES}"
|
||||
else
|
||||
make test
|
||||
fi
|
||||
|
||||
@@ -366,9 +366,9 @@ jobs:
|
||||
needs: changes
|
||||
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -475,6 +475,11 @@ jobs:
|
||||
mkdir -p /tmp/tmpfs
|
||||
sudo mount_tmpfs -o noowners -s 8g /tmp/tmpfs
|
||||
|
||||
# Install google-chrome for scaletests.
|
||||
# As another concern, should we really have this kind of external dependency
|
||||
# requirement on standard CI?
|
||||
brew install google-chrome
|
||||
|
||||
# macOS will output "The default interactive shell is now zsh" intermittently in CI.
|
||||
touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile
|
||||
|
||||
@@ -569,9 +574,9 @@ jobs:
|
||||
- changes
|
||||
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
@@ -981,9 +986,6 @@ jobs:
|
||||
run: |
|
||||
make build/coder_docs_"$(./scripts/version.sh)".tgz
|
||||
|
||||
- name: Check for unstaged files
|
||||
run: ./scripts/check_unstaged.sh
|
||||
|
||||
required:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
name: Linear Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
# This event reads the workflow from the default branch (main), not the
|
||||
# release branch. No cherry-pick needed.
|
||||
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#release
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
sync:
|
||||
name: Sync issues to Linear release
|
||||
if: github.event_name == 'push'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Sync issues
|
||||
id: sync
|
||||
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: sync
|
||||
|
||||
- name: Print release URL
|
||||
if: steps.sync.outputs.release-url
|
||||
run: echo "Synced to $RELEASE_URL"
|
||||
env:
|
||||
RELEASE_URL: ${{ steps.sync.outputs.release-url }}
|
||||
|
||||
complete:
|
||||
name: Complete Linear release
|
||||
if: github.event_name == 'release'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Complete release
|
||||
id: complete
|
||||
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: complete
|
||||
version: ${{ github.event.release.tag_name }}
|
||||
|
||||
- name: Print release URL
|
||||
if: steps.complete.outputs.release-url
|
||||
run: echo "Completed $RELEASE_URL"
|
||||
env:
|
||||
RELEASE_URL: ${{ steps.complete.outputs.release-url }}
|
||||
@@ -16,9 +16,9 @@ jobs:
|
||||
# when changing runner sizes
|
||||
runs-on: ${{ matrix.os == 'macos-latest' && github.repository_owner == 'coder' && 'depot-macos-latest' || matrix.os == 'windows-2022' && github.repository_owner == 'coder' && 'depot-windows-2022-16' || matrix.os }}
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -38,7 +38,6 @@ site/.swc
|
||||
|
||||
# Make target for updating generated/golden files (any dir).
|
||||
.gen
|
||||
/_gen/
|
||||
.gen-golden
|
||||
|
||||
# Build
|
||||
|
||||
@@ -37,20 +37,19 @@ Only pause to ask for confirmation when:
|
||||
|
||||
## Essential Commands
|
||||
|
||||
| Task | Command | Notes |
|
||||
|-----------------|--------------------------|-------------------------------------|
|
||||
| **Development** | `./scripts/develop.sh` | ⚠️ Don't use manual build |
|
||||
| **Build** | `make build` | Fat binaries (includes server) |
|
||||
| **Build Slim** | `make build-slim` | Slim binaries |
|
||||
| **Test** | `make test` | Full test suite |
|
||||
| **Test Single** | `make test RUN=TestName` | Faster than full suite |
|
||||
| **Test Race** | `make test-race` | Run tests with Go race detector |
|
||||
| **Lint** | `make lint` | Always run after changes |
|
||||
| **Generate** | `make gen` | After database changes |
|
||||
| **Format** | `make fmt` | Auto-format code |
|
||||
| **Clean** | `make clean` | Clean build artifacts |
|
||||
| **Pre-commit** | `make pre-commit` | Fast CI checks (gen/fmt/lint/build) |
|
||||
| **Pre-push** | `make pre-push` | All CI checks including tests |
|
||||
| Task | Command | Notes |
|
||||
|-------------------|--------------------------|----------------------------------|
|
||||
| **Development** | `./scripts/develop.sh` | ⚠️ Don't use manual build |
|
||||
| **Build** | `make build` | Fat binaries (includes server) |
|
||||
| **Build Slim** | `make build-slim` | Slim binaries |
|
||||
| **Test** | `make test` | Full test suite |
|
||||
| **Test Single** | `make test RUN=TestName` | Faster than full suite |
|
||||
| **Test Postgres** | `make test-postgres` | Run tests with Postgres database |
|
||||
| **Test Race** | `make test-race` | Run tests with Go race detector |
|
||||
| **Lint** | `make lint` | Always run after changes |
|
||||
| **Generate** | `make gen` | After database changes |
|
||||
| **Format** | `make fmt` | Auto-format code |
|
||||
| **Clean** | `make clean` | Clean build artifacts |
|
||||
|
||||
### Documentation Commands
|
||||
|
||||
@@ -104,37 +103,6 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
|
||||
### Full workflows available in imported WORKFLOWS.md
|
||||
|
||||
### Git Hooks (MANDATORY - DO NOT SKIP)
|
||||
|
||||
**You MUST install and use the git hooks. NEVER bypass them with
|
||||
`--no-verify`. Skipping hooks wastes CI cycles and is unacceptable.**
|
||||
|
||||
The first run will be slow as caches warm up. Consecutive runs are
|
||||
**significantly faster** (often 10x) thanks to Go build cache,
|
||||
generated file timestamps, and warm node_modules. This is NOT a
|
||||
reason to skip them. Wait for hooks to complete before proceeding,
|
||||
no matter how long they take.
|
||||
|
||||
```sh
|
||||
git config core.hooksPath scripts/githooks
|
||||
```
|
||||
|
||||
Two hooks run automatically:
|
||||
|
||||
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
|
||||
Fast checks that catch most CI failures. Allow at least 5 minutes.
|
||||
- **pre-push**: `make pre-push` (full CI suite including tests).
|
||||
Runs before pushing to catch everything CI would. Allow at least
|
||||
15 minutes (race tests are slow without cache).
|
||||
|
||||
`git commit` and `git push` will appear to hang while hooks run.
|
||||
This is normal. Do not interrupt, retry, or reduce the timeout.
|
||||
|
||||
NEVER run `git config core.hooksPath` to change or disable hooks.
|
||||
|
||||
If a hook fails, fix the issue and retry. Do not work around the
|
||||
failure by skipping the hook.
|
||||
|
||||
### Git Workflow
|
||||
|
||||
When working on existing PRs, check out the branch first:
|
||||
|
||||
@@ -19,16 +19,6 @@ SHELL := bash
|
||||
.SHELLFLAGS := -ceu
|
||||
.ONESHELL:
|
||||
|
||||
# When MAKE_TIMED=1, replace SHELL with a wrapper that prints
|
||||
# elapsed wall-clock time for each recipe. pre-commit and pre-push
|
||||
# set this on their sub-makes so every parallel job reports its
|
||||
# duration. Ad-hoc usage: make MAKE_TIMED=1 test
|
||||
ifdef MAKE_TIMED
|
||||
SHELL := $(CURDIR)/scripts/lib/timed-shell.sh
|
||||
.SHELLFLAGS = $@ -ceu
|
||||
export MAKE_TIMED
|
||||
endif
|
||||
|
||||
# This doesn't work on directories.
|
||||
# See https://stackoverflow.com/questions/25752543/make-delete-on-error-for-directory-targets
|
||||
.DELETE_ON_ERROR:
|
||||
@@ -43,25 +33,6 @@ endif
|
||||
coderd/database/unique_constraint.go \
|
||||
coderd/database/dbmetrics/querymetrics.go \
|
||||
coderd/database/dbauthz/dbauthz.go \
|
||||
coderd/database/dbmock/dbmock.go \
|
||||
coderd/database/pubsub/psmock/psmock.go \
|
||||
agent/agentcontainers/acmock/acmock.go \
|
||||
coderd/httpmw/loggermw/loggermock/loggermock.go \
|
||||
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
|
||||
tailnet/tailnettest/coordinatormock.go \
|
||||
tailnet/tailnettest/coordinateemock.go \
|
||||
tailnet/tailnettest/workspaceupdatesprovidermock.go \
|
||||
tailnet/tailnettest/subscriptionmock.go \
|
||||
enterprise/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/aibridged/aibridgedmock/poolmock.go \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
agent/proto/agent.pb.go \
|
||||
agent/agentsocket/proto/agentsocket.pb.go \
|
||||
agent/boundarylogproxy/codec/boundary.pb.go \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
site/src/api/typesGenerated.ts \
|
||||
site/e2e/provisionerGenerated.ts \
|
||||
site/src/api/chatModelOptionsGenerated.json \
|
||||
@@ -79,23 +50,6 @@ endif
|
||||
codersdk/rbacresources_gen.go \
|
||||
codersdk/apikey_scopes_gen.go
|
||||
|
||||
# atomic_write runs a command, captures stdout into a temp file, and
|
||||
# atomically replaces $@. An optional second argument is a formatting
|
||||
# command that receives the temp file path as its argument.
|
||||
# Usage: $(call atomic_write,GENERATE_CMD[,FORMAT_CMD])
|
||||
define atomic_write
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
$(1) > "$$tmpfile" && \
|
||||
$(if $(2),$(2) "$$tmpfile" &&) \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
endef
|
||||
|
||||
# Shared temp directory for atomic writes. Lives at the project root
|
||||
# so all targets share the same filesystem, and is gitignored.
|
||||
# Order-only prerequisite: recipes that need it depend on | _gen
|
||||
_gen:
|
||||
mkdir -p _gen
|
||||
|
||||
# Don't print the commands in the file unless you specify VERBOSE. This is
|
||||
# essentially the same as putting "@" at the start of each line.
|
||||
ifndef VERBOSE
|
||||
@@ -113,11 +67,6 @@ VERSION := $(shell ./scripts/version.sh)
|
||||
POSTGRES_VERSION ?= 17
|
||||
POSTGRES_IMAGE ?= us-docker.pkg.dev/coder-v2-images-public/public/postgres:$(POSTGRES_VERSION)
|
||||
|
||||
# Limit parallel Make jobs in pre-commit/pre-push. Defaults to
|
||||
# nproc/4 (min 2) since test and lint targets have internal
|
||||
# parallelism. Override: make pre-push PARALLEL_JOBS=8
|
||||
PARALLEL_JOBS ?= $(shell n=$$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 8); echo $$(( n / 4 > 2 ? n / 4 : 2 )))
|
||||
|
||||
# Use the highest ZSTD compression level in CI.
|
||||
ifdef CI
|
||||
ZSTDFLAGS := -22 --ultra
|
||||
@@ -131,7 +80,7 @@ endif
|
||||
# Note, all find statements should be written with `.` or `./path` as
|
||||
# the search path so that these exclusions match.
|
||||
FIND_EXCLUSIONS= \
|
||||
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
|
||||
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' \) -prune \)
|
||||
# Source files used for make targets, evaluated on use.
|
||||
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
|
||||
# Same as GO_SRC_FILES but excluding certain files that have problematic
|
||||
@@ -621,7 +570,7 @@ endif
|
||||
# GitHub Actions linters are run in a separate CI job (lint-actions) that only
|
||||
# triggers when workflow files change, so we skip them here when CI=true.
|
||||
LINT_ACTIONS_TARGETS := $(if $(CI),,lint/actions/actionlint)
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations lint/bootstrap $(LINT_ACTIONS_TARGETS)
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations $(LINT_ACTIONS_TARGETS)
|
||||
.PHONY: lint
|
||||
|
||||
lint/site-icons:
|
||||
@@ -636,7 +585,7 @@ lint/ts: site/node_modules/.installed
|
||||
lint/go:
|
||||
./scripts/check_enterprise_imports.sh
|
||||
./scripts/check_codersdk_imports.sh
|
||||
linter_ver=$$(grep -oE 'GOLANGCI_LINT_VERSION=\S+' dogfood/coder/Dockerfile | cut -d '=' -f 2)
|
||||
linter_ver=$(shell egrep -o 'GOLANGCI_LINT_VERSION=\S+' dogfood/coder/Dockerfile | cut -d '=' -f 2)
|
||||
go run github.com/golangci/golangci-lint/cmd/golangci-lint@v$$linter_ver run
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
@@ -651,11 +600,6 @@ lint/shellcheck: $(SHELL_SRC_FILES)
|
||||
shellcheck --external-sources $(SHELL_SRC_FILES)
|
||||
.PHONY: lint/shellcheck
|
||||
|
||||
lint/bootstrap:
|
||||
bash scripts/check_bootstrap_quotes.sh
|
||||
.PHONY: lint/bootstrap
|
||||
|
||||
|
||||
lint/helm:
|
||||
cd helm/
|
||||
make lint
|
||||
@@ -690,118 +634,6 @@ lint/migrations:
|
||||
./scripts/check_pg_schema.sh "Fixtures" $(FIXTURE_FILES)
|
||||
.PHONY: lint/migrations
|
||||
|
||||
TYPOS_VERSION := $(shell grep -oP 'crate-ci/typos@\S+\s+\#\s+v\K[0-9.]+' .github/workflows/ci.yaml)
|
||||
|
||||
# Map uname values to typos release asset names.
|
||||
TYPOS_ARCH := $(shell uname -m)
|
||||
ifeq ($(shell uname -s),Darwin)
|
||||
TYPOS_OS := apple-darwin
|
||||
else
|
||||
TYPOS_OS := unknown-linux-musl
|
||||
endif
|
||||
|
||||
build/typos-$(TYPOS_VERSION):
|
||||
mkdir -p build/
|
||||
curl -sSfL "https://github.com/crate-ci/typos/releases/download/v$(TYPOS_VERSION)/typos-v$(TYPOS_VERSION)-$(TYPOS_ARCH)-$(TYPOS_OS).tar.gz" \
|
||||
| tar -xzf - -C build/ ./typos
|
||||
mv build/typos "$@"
|
||||
|
||||
lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
build/typos-$(TYPOS_VERSION) --config .github/workflows/typos.toml
|
||||
.PHONY: lint/typos
|
||||
|
||||
# pre-commit and pre-push mirror CI "required" jobs locally.
|
||||
# See the "required" job's needs list in .github/workflows/ci.yaml.
|
||||
#
|
||||
# pre-commit runs checks that don't need external services (Docker,
|
||||
# Playwright). This is the git pre-commit hook default since test
|
||||
# and Docker failures in the local environment would otherwise block
|
||||
# all commits.
|
||||
#
|
||||
# pre-push runs the full CI suite including tests. This is the git
|
||||
# pre-push hook default, catching everything CI would before pushing.
|
||||
#
|
||||
# pre-push uses two-phase execution: gen+fmt+test-postgres-docker
|
||||
# first (writes files, starts Docker), then lint+build+test in
|
||||
# parallel. pre-commit uses two phases: gen+fmt first, then
|
||||
# lint+build. This avoids races where gen's `go run` creates
|
||||
# temporary .go files that lint's find-based checks pick up.
|
||||
# Within each phase, targets run in parallel via -j. Both fail if
|
||||
# any tracked files have unstaged changes afterward.
|
||||
#
|
||||
# Both pre-commit and pre-push:
|
||||
# gen, fmt, lint, lint/typos, slim binary (local arch)
|
||||
#
|
||||
# pre-push only (need external services or are slow):
|
||||
# site/out/index.html (pnpm build)
|
||||
# test-postgres-docker + test (needs Docker)
|
||||
# test-js, test-e2e (needs Playwright)
|
||||
# sqlc-vet (needs Docker)
|
||||
# offlinedocs/check
|
||||
#
|
||||
# Omitted:
|
||||
# test-go-pg-17 (same tests, different PG version)
|
||||
|
||||
define check-unstaged
|
||||
unstaged="$$(git diff --name-only)"
|
||||
if [[ -n $$unstaged ]]; then
|
||||
echo "ERROR: unstaged changes in tracked files:"
|
||||
echo "$$unstaged"
|
||||
echo
|
||||
echo "Review each change (git diff), verify correctness, then stage:"
|
||||
echo " git add -u && git commit"
|
||||
exit 1
|
||||
fi
|
||||
untracked=$$(git ls-files --other --exclude-standard)
|
||||
if [[ -n $$untracked ]]; then
|
||||
echo "WARNING: untracked files (not in this commit, won't be in CI):"
|
||||
echo "$$untracked"
|
||||
echo
|
||||
fi
|
||||
endef
|
||||
|
||||
pre-commit:
|
||||
start=$$(date +%s)
|
||||
echo "=== Phase 1/2: gen + fmt ==="
|
||||
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt
|
||||
$(check-unstaged)
|
||||
echo "=== Phase 2/2: lint + build ==="
|
||||
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
|
||||
lint \
|
||||
lint/typos \
|
||||
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
$(check-unstaged)
|
||||
echo "$(BOLD)$(GREEN)=== pre-commit passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
|
||||
.PHONY: pre-commit
|
||||
|
||||
pre-push:
|
||||
start=$$(date +%s)
|
||||
echo "=== Phase 1/2: gen + fmt + postgres ==="
|
||||
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt test-postgres-docker
|
||||
$(check-unstaged)
|
||||
echo "=== Phase 2/2: lint + build + test ==="
|
||||
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
|
||||
lint \
|
||||
lint/typos \
|
||||
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT) \
|
||||
site/out/index.html \
|
||||
test \
|
||||
test-js \
|
||||
test-e2e \
|
||||
test-race \
|
||||
sqlc-vet \
|
||||
offlinedocs/check
|
||||
$(check-unstaged)
|
||||
echo "$(BOLD)$(GREEN)=== pre-push passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
|
||||
.PHONY: pre-push
|
||||
|
||||
offlinedocs/check: offlinedocs/node_modules/.installed
|
||||
cd offlinedocs/
|
||||
pnpm format:check
|
||||
pnpm lint
|
||||
pnpm export
|
||||
.PHONY: offlinedocs/check
|
||||
|
||||
# All files generated by the database should be added here, and this can be used
|
||||
# as a target for jobs that need to run after the database is generated.
|
||||
DB_GEN_FILES := \
|
||||
@@ -990,7 +822,7 @@ $(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go
|
||||
touch "$@"
|
||||
|
||||
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -998,7 +830,7 @@ tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
|
||||
./tailnet/proto/tailnet.proto
|
||||
|
||||
agent/proto/agent.pb.go: agent/proto/agent.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -1006,7 +838,7 @@ agent/proto/agent.pb.go: agent/proto/agent.proto
|
||||
./agent/proto/agent.proto
|
||||
|
||||
agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto agent/proto/agent.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -1014,7 +846,7 @@ agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.p
|
||||
./agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -1022,7 +854,7 @@ provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
./provisionersdk/proto/provisioner.proto
|
||||
|
||||
provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -1030,110 +862,132 @@ provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto
|
||||
./provisionerd/proto/provisionerd.proto
|
||||
|
||||
vpn/vpn.pb.go: vpn/vpn.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
./vpn/vpn.proto
|
||||
|
||||
agent/boundarylogproxy/codec/boundary.pb.go: agent/boundarylogproxy/codec/boundary.proto agent/proto/agent.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
./agent/boundarylogproxy/codec/boundary.proto
|
||||
|
||||
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
|
||||
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
|
||||
# Generate to a temp file, format it, then atomically move to
|
||||
# the target so that an interrupt never leaves a partial or
|
||||
# unformatted file in the working tree.
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run -C ./scripts/apitypings main.go > "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
(cd site/ && pnpm run gen:provisioner)
|
||||
touch "$@"
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run ./scripts/gensite/ -icons "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
mv "$$tmpfile" "$@"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
|
||||
$(call atomic_write,go run ./scripts/examplegen/main.go)
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
|
||||
go run ./scripts/examplegen/main.go > "$@.tmp" && mv "$@.tmp" "$@"
|
||||
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
|
||||
tempdir=$(shell mktemp -d /tmp/typegen_rbac_object.XXXXXX)
|
||||
go run ./scripts/typegen/main.go rbac object > "$$tempdir/object_gen.go"
|
||||
mv -v "$$tempdir/object_gen.go" coderd/rbac/object_gen.go
|
||||
rmdir -v "$$tempdir"
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go because `go run` compiles
|
||||
# coderd/rbac which includes it.
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go | _gen
|
||||
# Write to a temp file first to avoid truncating the package
|
||||
# during build since the generator imports the rbac package.
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
|
||||
coderd/rbac/object_gen.go
|
||||
# Generate typed low-level ScopeName constants from RBACPermissions
|
||||
# Write to a temp file first to avoid truncating the package during build
|
||||
# since the generator imports the rbac package.
|
||||
tempfile=$(shell mktemp /tmp/scopes_constants_gen.XXXXXX)
|
||||
go run ./scripts/typegen/main.go rbac scopenames > "$$tempfile"
|
||||
mv -v "$$tempfile" coderd/rbac/scopes_constants_gen.go
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
# Write to a temp file to avoid truncating the target, which
|
||||
# would break the codersdk package and any parallel build targets.
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
|
||||
# Do no overwrite codersdk/rbacresources_gen.go directly, as it would make the file empty, breaking
|
||||
# the `codersdk` package and any parallel build targets.
|
||||
go run scripts/typegen/main.go rbac codersdk > /tmp/rbacresources_gen.go
|
||||
mv /tmp/rbacresources_gen.go codersdk/rbacresources_gen.go
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
|
||||
# Generate SDK constants for external API key scopes.
|
||||
$(call atomic_write,go run ./scripts/apikeyscopesgen)
|
||||
go run ./scripts/apikeyscopesgen > /tmp/apikey_scopes_gen.go
|
||||
mv /tmp/apikey_scopes_gen.go codersdk/apikey_scopes_gen.go
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run scripts/typegen/main.go rbac typescript > "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run scripts/typegen/main.go countries > "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
|
||||
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run ./scripts/modeloptionsgen/main.go | tail -n +2 > "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
|
||||
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
|
||||
go run ./scripts/metricsdocgen/scanner > $@.tmp && mv $@.tmp $@
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
mv "$$tmpfile" "$@"
|
||||
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && \
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES)
|
||||
tmpdir=$$(mktemp -d) && \
|
||||
mkdir -p "$$tmpdir/docs/reference/cli" && \
|
||||
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
|
||||
cp "$$tmpdir/docs/reference/cli/"*.md docs/reference/cli/ && \
|
||||
rm -rf "$$tmpdir"
|
||||
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
mv "$$tmpfile" "$@"
|
||||
|
||||
coderd/apidoc/.gen: \
|
||||
node_modules/.installed \
|
||||
@@ -1148,27 +1002,25 @@ coderd/apidoc/.gen: \
|
||||
scripts/apidocgen/generate.sh \
|
||||
scripts/apidocgen/swaginit/main.go \
|
||||
$(wildcard scripts/apidocgen/postprocess/*) \
|
||||
$(wildcard scripts/apidocgen/markdown-template/*) | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && swagtmp=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && swagtmp=$$(realpath "$$swagtmp") && \
|
||||
$(wildcard scripts/apidocgen/markdown-template/*)
|
||||
tmpdir=$$(mktemp -d) && swagtmp=$$(mktemp -d) && \
|
||||
mkdir -p "$$tmpdir/reference/api" && \
|
||||
cp docs/manifest.json "$$tmpdir/manifest.json" && \
|
||||
SWAG_OUTPUT_DIR="$$swagtmp" APIDOCGEN_DOCS_DIR="$$tmpdir" ./scripts/apidocgen/generate.sh && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/reference/api/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/reference/api/*.md" && \
|
||||
./scripts/biome_format.sh "$$swagtmp/swagger.json" && \
|
||||
for f in "$$tmpdir/reference/api/"*.md; do mv "$$f" "docs/reference/api/$$(basename "$$f")"; done && \
|
||||
mv "$$tmpdir/manifest.json" _gen/manifest-staging.json && \
|
||||
mv "$$swagtmp/docs.go" coderd/apidoc/docs.go && \
|
||||
mv "$$swagtmp/swagger.json" coderd/apidoc/swagger.json && \
|
||||
cp "$$tmpdir/reference/api/"*.md docs/reference/api/ && \
|
||||
cp "$$tmpdir/manifest.json" docs/manifest.json && \
|
||||
cp "$$swagtmp/docs.go" coderd/apidoc/docs.go && \
|
||||
cp "$$swagtmp/swagger.json" coderd/apidoc/swagger.json && \
|
||||
rm -rf "$$tmpdir" "$$swagtmp"
|
||||
touch "$@"
|
||||
|
||||
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
cp _gen/manifest-staging.json "$$tmpfile" && \
|
||||
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
mv "$$tmpfile" "$@"
|
||||
|
||||
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
|
||||
touch "$@"
|
||||
@@ -1255,22 +1107,10 @@ else
|
||||
GOTESTSUM_RETRY_FLAGS :=
|
||||
endif
|
||||
|
||||
# Default to 8x8 parallelism to avoid overwhelming our workspaces.
|
||||
# Race detection defaults to 4x4 because the detector adds significant
|
||||
# CPU overhead. Override via TEST_NUM_PARALLEL_PACKAGES /
|
||||
# TEST_NUM_PARALLEL_TESTS.
|
||||
TEST_PARALLEL_PACKAGES := $(or $(TEST_NUM_PARALLEL_PACKAGES),8)
|
||||
TEST_PARALLEL_TESTS := $(or $(TEST_NUM_PARALLEL_TESTS),8)
|
||||
RACE_PARALLEL_PACKAGES := $(or $(TEST_NUM_PARALLEL_PACKAGES),4)
|
||||
RACE_PARALLEL_TESTS := $(or $(TEST_NUM_PARALLEL_TESTS),4)
|
||||
|
||||
# Use testsmallbatch tag to reduce wireguard memory allocation in tests
|
||||
# (from ~18GB to negligible). Recursively expanded so target-specific
|
||||
# overrides of TEST_PARALLEL_* take effect (e.g. test-race lowers
|
||||
# parallelism). CI job timeout is 25m (see test-go-pg in ci.yaml),
|
||||
# keep the Go timeout 5m shorter so tests produce goroutine dumps
|
||||
# instead of the CI runner killing the process with no output.
|
||||
GOTEST_FLAGS = -tags=testsmallbatch -v -timeout 20m -p $(TEST_PARALLEL_PACKAGES) -parallel=$(TEST_PARALLEL_TESTS)
|
||||
# default to 8x8 parallelism to avoid overwhelming our workspaces. Hopefully we can remove these defaults
|
||||
# when we get our test suite's resource utilization under control.
|
||||
# Use testsmallbatch tag to reduce wireguard memory allocation in tests (from ~18GB to negligible).
|
||||
GOTEST_FLAGS := -tags=testsmallbatch -v -p $(or $(TEST_NUM_PARALLEL_PACKAGES),"8") -parallel=$(or $(TEST_NUM_PARALLEL_TESTS),"8")
|
||||
|
||||
# The most common use is to set TEST_COUNT=1 to avoid Go's test cache.
|
||||
ifdef TEST_COUNT
|
||||
@@ -1296,34 +1136,13 @@ endif
|
||||
TEST_PACKAGES ?= ./...
|
||||
|
||||
test:
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="$(TEST_PACKAGES)" \
|
||||
-- \
|
||||
$(GOTEST_FLAGS)
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet $(GOTESTSUM_RETRY_FLAGS) --packages="$(TEST_PACKAGES)" -- $(GOTEST_FLAGS)
|
||||
.PHONY: test
|
||||
|
||||
test-race: TEST_PARALLEL_PACKAGES := $(RACE_PARALLEL_PACKAGES)
|
||||
test-race: TEST_PARALLEL_TESTS := $(RACE_PARALLEL_TESTS)
|
||||
test-race:
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet \
|
||||
--junitfile="gotests.xml" \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="$(TEST_PACKAGES)" \
|
||||
-- \
|
||||
-race \
|
||||
$(GOTEST_FLAGS)
|
||||
.PHONY: test-race
|
||||
|
||||
test-cli:
|
||||
$(MAKE) test TEST_PACKAGES="./cli..."
|
||||
.PHONY: test-cli
|
||||
|
||||
test-js: site/node_modules/.installed
|
||||
cd site/
|
||||
pnpm test:ci
|
||||
.PHONY: test-js
|
||||
|
||||
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
|
||||
# dependency for any sqlc-cloud related targets.
|
||||
sqlc-cloud-is-setup:
|
||||
@@ -1335,22 +1154,37 @@ sqlc-cloud-is-setup:
|
||||
|
||||
sqlc-push: sqlc-cloud-is-setup test-postgres-docker
|
||||
echo "--- sqlc push"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
sqlc push -f coderd/database/sqlc.yaml && echo "Passed sqlc push"
|
||||
.PHONY: sqlc-push
|
||||
|
||||
sqlc-verify: sqlc-cloud-is-setup test-postgres-docker
|
||||
echo "--- sqlc verify"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
sqlc verify -f coderd/database/sqlc.yaml && echo "Passed sqlc verify"
|
||||
.PHONY: sqlc-verify
|
||||
|
||||
sqlc-vet: test-postgres-docker
|
||||
echo "--- sqlc vet"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
sqlc vet -f coderd/database/sqlc.yaml && echo "Passed sqlc vet"
|
||||
.PHONY: sqlc-vet
|
||||
|
||||
# When updating -timeout for this test, keep in sync with
|
||||
# test-go-postgres (.github/workflows/coder.yaml).
|
||||
# Do add coverage flags so that test caching works.
|
||||
test-postgres: test-postgres-docker
|
||||
# The postgres test is prone to failure, so we limit parallelism for
|
||||
# more consistent execution.
|
||||
$(GIT_FLAGS) gotestsum \
|
||||
--junitfile="gotests.xml" \
|
||||
--jsonfile="gotests.json" \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="./..." -- \
|
||||
-tags=testsmallbatch \
|
||||
-timeout=20m \
|
||||
-count=1
|
||||
.PHONY: test-postgres
|
||||
|
||||
test-migrations: test-postgres-docker
|
||||
echo "--- test migrations"
|
||||
@@ -1366,24 +1200,13 @@ test-migrations: test-postgres-docker
|
||||
|
||||
# NOTE: we set --memory to the same size as a GitHub runner.
|
||||
test-postgres-docker:
|
||||
# If our container is already running, nothing to do.
|
||||
if docker ps --filter "name=test-postgres-docker-${POSTGRES_VERSION}" --format '{{.Names}}' | grep -q .; then \
|
||||
echo "test-postgres-docker-${POSTGRES_VERSION} is already running."; \
|
||||
exit 0; \
|
||||
fi
|
||||
# If something else is on 5432, warn but don't fail.
|
||||
if pg_isready -h 127.0.0.1 -q 2>/dev/null; then \
|
||||
echo "WARNING: PostgreSQL is already running on 127.0.0.1:5432 (not our container)."; \
|
||||
echo "Tests will use this instance. To use the Makefile's container, stop it first."; \
|
||||
exit 0; \
|
||||
fi
|
||||
docker rm -f test-postgres-docker-${POSTGRES_VERSION} || true
|
||||
|
||||
# Try pulling up to three times to avoid CI flakes.
|
||||
docker pull ${POSTGRES_IMAGE} || {
|
||||
retries=2
|
||||
for try in $$(seq 1 $${retries}); do
|
||||
echo "Failed to pull image, retrying ($${try}/$${retries})..."
|
||||
for try in $(seq 1 ${retries}); do
|
||||
echo "Failed to pull image, retrying (${try}/${retries})..."
|
||||
sleep 1
|
||||
if docker pull ${POSTGRES_IMAGE}; then
|
||||
break
|
||||
@@ -1424,11 +1247,16 @@ test-postgres-docker:
|
||||
-c log_statement=all
|
||||
while ! pg_isready -h 127.0.0.1
|
||||
do
|
||||
echo "$$(date) - waiting for database to start"
|
||||
echo "$(date) - waiting for database to start"
|
||||
sleep 0.5
|
||||
done
|
||||
.PHONY: test-postgres-docker
|
||||
|
||||
# Make sure to keep this in sync with test-go-race from .github/workflows/ci.yaml.
|
||||
test-race:
|
||||
$(GIT_FLAGS) gotestsum --junitfile="gotests.xml" -- -tags=testsmallbatch -race -count=1 -parallel 4 -p 4 ./...
|
||||
.PHONY: test-race
|
||||
|
||||
test-tailnet-integration:
|
||||
env \
|
||||
CODER_TAILNET_TESTS=true \
|
||||
@@ -1457,7 +1285,6 @@ site/e2e/bin/coder: go.mod go.sum $(GO_SRC_FILES)
|
||||
|
||||
test-e2e: site/e2e/bin/coder site/node_modules/.installed site/out/index.html
|
||||
cd site/
|
||||
pnpm playwright:install
|
||||
ifdef CI
|
||||
DEBUG=pw:api pnpm playwright:test --forbid-only --workers 1
|
||||
else
|
||||
|
||||
+2
-10
@@ -41,7 +41,6 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/agent/agentproc"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
@@ -103,7 +102,6 @@ type Options struct {
|
||||
Execer agentexec.Execer
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
GitAPIOptions []agentgit.Option
|
||||
Clock quartz.Clock
|
||||
SocketServerEnabled bool
|
||||
SocketPath string // Path for the agent socket server socket
|
||||
@@ -219,7 +217,6 @@ func New(options Options) Agent {
|
||||
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
gitAPIOptions: options.GitAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
socketServerEnabled: options.SocketServerEnabled,
|
||||
boundaryLogProxySocketPath: options.BoundaryLogProxySocketPath,
|
||||
@@ -305,10 +302,8 @@ type agent struct {
|
||||
devcontainers bool
|
||||
containerAPIOptions []agentcontainers.Option
|
||||
containerAPI *agentcontainers.API
|
||||
gitAPIOptions []agentgit.Option
|
||||
|
||||
filesAPI *agentfiles.API
|
||||
gitAPI *agentgit.API
|
||||
processAPI *agentproc.API
|
||||
|
||||
socketServerEnabled bool
|
||||
@@ -381,11 +376,8 @@ func (a *agent) init() {
|
||||
|
||||
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem)
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv)
|
||||
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
|
||||
@@ -7,21 +7,18 @@ import (
|
||||
"github.com/spf13/afero"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
)
|
||||
|
||||
// API exposes file-related operations performed through the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
filesystem afero.Fs
|
||||
pathStore *agentgit.PathStore
|
||||
}
|
||||
|
||||
func NewAPI(logger slog.Logger, filesystem afero.Fs, pathStore *agentgit.PathStore) *API {
|
||||
func NewAPI(logger slog.Logger, filesystem afero.Fs) *API {
|
||||
api := &API{
|
||||
logger: logger,
|
||||
filesystem: filesystem,
|
||||
pathStore: pathStore,
|
||||
}
|
||||
return api
|
||||
}
|
||||
|
||||
@@ -13,12 +13,10 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -303,13 +301,6 @@ func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Track edited path for git watch.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), []string{path})
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
|
||||
Message: fmt.Sprintf("Successfully wrote to %q", path),
|
||||
})
|
||||
@@ -389,17 +380,6 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Track edited paths for git watch.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
filePaths := make([]string, 0, len(req.Files))
|
||||
for _, f := range req.Files {
|
||||
filePaths = append(filePaths, f.Path)
|
||||
}
|
||||
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), filePaths)
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
|
||||
Message: "Successfully edited file(s)",
|
||||
})
|
||||
|
||||
@@ -11,12 +11,9 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -24,7 +21,6 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -120,7 +116,7 @@ func TestReadFile(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
api := agentfiles.NewAPI(logger, fs)
|
||||
|
||||
dirPath := filepath.Join(tmpdir, "a-directory")
|
||||
err := fs.MkdirAll(dirPath, 0o755)
|
||||
@@ -300,7 +296,7 @@ func TestWriteFile(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
api := agentfiles.NewAPI(logger, fs)
|
||||
|
||||
dirPath := filepath.Join(tmpdir, "directory")
|
||||
err := fs.MkdirAll(dirPath, 0o755)
|
||||
@@ -418,7 +414,7 @@ func TestEditFiles(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
api := agentfiles.NewAPI(logger, fs)
|
||||
|
||||
dirPath := filepath.Join(tmpdir, "directory")
|
||||
err := fs.MkdirAll(dirPath, 0o755)
|
||||
@@ -842,169 +838,6 @@ func TestEditFiles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
testPath := filepath.Join(os.TempDir(), "test.txt")
|
||||
|
||||
chatID := uuid.New()
|
||||
ancestorID := uuid.New()
|
||||
ancestorJSON, _ := json.Marshal([]string{ancestorID.String()})
|
||||
|
||||
body := strings.NewReader("hello world")
|
||||
req := httptest.NewRequest(http.MethodPost, "/write-file?path="+testPath, body)
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
req.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, string(ancestorJSON))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify PathStore was updated for both chat and ancestor.
|
||||
paths := pathStore.GetPaths(chatID)
|
||||
require.Equal(t, []string{testPath}, paths)
|
||||
|
||||
ancestorPaths := pathStore.GetPaths(ancestorID)
|
||||
require.Equal(t, []string{testPath}, ancestorPaths)
|
||||
}
|
||||
|
||||
func TestHandleWriteFile_NoChatHeaders_NoPathStoreUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
testPath := filepath.Join(os.TempDir(), "test.txt")
|
||||
|
||||
body := strings.NewReader("hello world")
|
||||
req := httptest.NewRequest(http.MethodPost, "/write-file?path="+testPath, body)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// PathStore should be globally empty since no chat headers were set.
|
||||
require.Equal(t, 0, pathStore.Len())
|
||||
}
|
||||
|
||||
func TestHandleWriteFile_Failure_NoPathStoreUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
chatID := uuid.New()
|
||||
|
||||
// Write to a relative path (should fail with 400).
|
||||
body := strings.NewReader("hello world")
|
||||
req := httptest.NewRequest(http.MethodPost, "/write-file?path=relative/path.txt", body)
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
// PathStore should NOT be updated on failure.
|
||||
paths := pathStore.GetPaths(chatID)
|
||||
require.Empty(t, paths)
|
||||
}
|
||||
|
||||
func TestHandleEditFiles_ChatHeaders_UpdatesPathStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
testPath := filepath.Join(os.TempDir(), "test.txt")
|
||||
|
||||
// Create the file first.
|
||||
require.NoError(t, afero.WriteFile(fs, testPath, []byte("hello"), 0o644))
|
||||
|
||||
chatID := uuid.New()
|
||||
editReq := workspacesdk.FileEditRequest{
|
||||
Files: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: testPath,
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{Search: "hello", Replace: "world"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(editReq)
|
||||
req := httptest.NewRequest(http.MethodPost, "/edit-files", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/edit-files", api.HandleEditFiles)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
paths := pathStore.GetPaths(chatID)
|
||||
require.Equal(t, []string{testPath}, paths)
|
||||
}
|
||||
|
||||
func TestHandleEditFiles_Failure_NoPathStoreUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
chatID := uuid.New()
|
||||
|
||||
// Edit a non-existent file (should fail with 404).
|
||||
editReq := workspacesdk.FileEditRequest{
|
||||
Files: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: "/nonexistent/file.txt",
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{Search: "hello", Replace: "world"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(editReq)
|
||||
req := httptest.NewRequest(http.MethodPost, "/edit-files", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/edit-files", api.HandleEditFiles)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.NotEqual(t, http.StatusOK, rr.Code)
|
||||
|
||||
// PathStore should NOT be updated on failure.
|
||||
paths := pathStore.GetPaths(chatID)
|
||||
require.Empty(t, paths)
|
||||
}
|
||||
|
||||
func TestReadFileLines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1018,7 +851,7 @@ func TestReadFileLines(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
api := agentfiles.NewAPI(logger, fs)
|
||||
|
||||
dirPath := filepath.Join(tmpdir, "a-directory-lines")
|
||||
err := fs.MkdirAll(dirPath, 0o755)
|
||||
|
||||
@@ -1,441 +0,0 @@
|
||||
// Package agentgit provides a WebSocket-based service for watching git
|
||||
// repository changes on the agent. It is mounted at /api/v0/git/watch
|
||||
// and allows clients to subscribe to file paths, triggering scans of
|
||||
// the corresponding git repositories.
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dustin/go-humanize"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Option configures the git watch service.
|
||||
type Option func(*Handler)
|
||||
|
||||
// WithClock sets a controllable clock for testing. Defaults to
|
||||
// quartz.NewReal().
|
||||
func WithClock(c quartz.Clock) Option {
|
||||
return func(h *Handler) {
|
||||
h.clock = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithGitBinary overrides the git binary path (for testing).
|
||||
func WithGitBinary(path string) Option {
|
||||
return func(h *Handler) {
|
||||
h.gitBin = path
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// scanCooldown is the minimum interval between successive scans.
|
||||
scanCooldown = 1 * time.Second
|
||||
// fallbackPollInterval is the safety-net poll period used when no
|
||||
// filesystem events arrive.
|
||||
fallbackPollInterval = 30 * time.Second
|
||||
// maxTotalDiffSize is the maximum size of the combined
|
||||
// unified diff for an entire repository sent over the wire.
|
||||
// This must stay under the WebSocket message size limit.
|
||||
maxTotalDiffSize = 3 * 1024 * 1024 // 3 MiB
|
||||
)
|
||||
|
||||
// Handler manages per-connection git watch state.
|
||||
type Handler struct {
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
gitBin string // path to git binary; empty means "git" (from PATH)
|
||||
|
||||
mu sync.Mutex
|
||||
repoRoots map[string]struct{} // watched repo roots
|
||||
lastSnapshots map[string]repoSnapshot // last emitted snapshot per repo
|
||||
lastScanAt time.Time // when the last scan completed
|
||||
scanTrigger chan struct{} // buffered(1), poked by triggers
|
||||
}
|
||||
|
||||
// repoSnapshot captures the last emitted state for delta comparison.
|
||||
type repoSnapshot struct {
|
||||
branch string
|
||||
remoteOrigin string
|
||||
unifiedDiff string
|
||||
}
|
||||
|
||||
// NewHandler creates a new git watch handler.
|
||||
func NewHandler(logger slog.Logger, opts ...Option) *Handler {
|
||||
h := &Handler{
|
||||
logger: logger,
|
||||
clock: quartz.NewReal(),
|
||||
gitBin: "git",
|
||||
repoRoots: make(map[string]struct{}),
|
||||
lastSnapshots: make(map[string]repoSnapshot),
|
||||
scanTrigger: make(chan struct{}, 1),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(h)
|
||||
}
|
||||
|
||||
// Check if git is available.
|
||||
if _, err := exec.LookPath(h.gitBin); err != nil {
|
||||
h.logger.Warn(context.Background(), "git binary not found, git scanning disabled")
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// gitAvailable returns true if the configured git binary can be found
|
||||
// in PATH.
|
||||
func (h *Handler) gitAvailable() bool {
|
||||
_, err := exec.LookPath(h.gitBin)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Subscribe processes a subscribe message, resolving paths to git repo
|
||||
// roots and adding new repos to the watch set. Returns true if any new
|
||||
// repo roots were added.
|
||||
func (h *Handler) Subscribe(paths []string) bool {
|
||||
if !h.gitAvailable() {
|
||||
return false
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
added := false
|
||||
for _, p := range paths {
|
||||
if !filepath.IsAbs(p) {
|
||||
continue
|
||||
}
|
||||
p = filepath.Clean(p)
|
||||
|
||||
root, err := findRepoRoot(h.gitBin, p)
|
||||
if err != nil {
|
||||
// Not a git path — silently ignore.
|
||||
continue
|
||||
}
|
||||
if _, ok := h.repoRoots[root]; ok {
|
||||
continue
|
||||
}
|
||||
h.repoRoots[root] = struct{}{}
|
||||
added = true
|
||||
}
|
||||
return added
|
||||
}
|
||||
|
||||
// RequestScan pokes the scan trigger so the run loop performs a scan.
|
||||
func (h *Handler) RequestScan() {
|
||||
select {
|
||||
case h.scanTrigger <- struct{}{}:
|
||||
default:
|
||||
// Already pending.
|
||||
}
|
||||
}
|
||||
|
||||
// Scan performs a scan of all subscribed repos and computes deltas
|
||||
// against the previously emitted snapshots.
|
||||
func (h *Handler) Scan(ctx context.Context) *codersdk.WorkspaceAgentGitServerMessage {
|
||||
if !h.gitAvailable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
roots := make([]string, 0, len(h.repoRoots))
|
||||
for r := range h.repoRoots {
|
||||
roots = append(roots, r)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
if len(roots) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := h.clock.Now().UTC()
|
||||
var repos []codersdk.WorkspaceAgentRepoChanges
|
||||
|
||||
// Perform all I/O outside the lock to avoid blocking
|
||||
// AddPaths/GetPaths/Subscribe callers during disk-heavy scans.
|
||||
type scanResult struct {
|
||||
root string
|
||||
changes codersdk.WorkspaceAgentRepoChanges
|
||||
err error
|
||||
}
|
||||
results := make([]scanResult, 0, len(roots))
|
||||
for _, root := range roots {
|
||||
changes, err := getRepoChanges(ctx, h.logger, h.gitBin, root)
|
||||
results = append(results, scanResult{root: root, changes: changes, err: err})
|
||||
}
|
||||
|
||||
// Re-acquire the lock only to commit snapshot updates.
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for _, res := range results {
|
||||
if res.err != nil {
|
||||
if isRepoDeleted(h.gitBin, res.root) {
|
||||
// Repo root or .git directory was removed.
|
||||
// Emit a removal entry, then evict from watch set.
|
||||
removal := codersdk.WorkspaceAgentRepoChanges{
|
||||
RepoRoot: res.root,
|
||||
Removed: true,
|
||||
}
|
||||
delete(h.repoRoots, res.root)
|
||||
delete(h.lastSnapshots, res.root)
|
||||
repos = append(repos, removal)
|
||||
} else {
|
||||
// Transient error — log and skip without
|
||||
// removing the repo from the watch set.
|
||||
h.logger.Warn(ctx, "scan repo failed",
|
||||
slog.F("root", res.root),
|
||||
slog.Error(res.err),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
prev, hasPrev := h.lastSnapshots[res.root]
|
||||
if hasPrev &&
|
||||
prev.branch == res.changes.Branch &&
|
||||
prev.remoteOrigin == res.changes.RemoteOrigin &&
|
||||
prev.unifiedDiff == res.changes.UnifiedDiff {
|
||||
// No change in this repo since last emit.
|
||||
continue
|
||||
}
|
||||
|
||||
// Update snapshot.
|
||||
h.lastSnapshots[res.root] = repoSnapshot{
|
||||
branch: res.changes.Branch,
|
||||
remoteOrigin: res.changes.RemoteOrigin,
|
||||
unifiedDiff: res.changes.UnifiedDiff,
|
||||
}
|
||||
|
||||
repos = append(repos, res.changes)
|
||||
}
|
||||
|
||||
h.lastScanAt = now
|
||||
|
||||
if len(repos) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &codersdk.WorkspaceAgentGitServerMessage{
|
||||
Type: codersdk.WorkspaceAgentGitServerMessageTypeChanges,
|
||||
ScannedAt: &now,
|
||||
Repositories: repos,
|
||||
}
|
||||
}
|
||||
|
||||
// RunLoop runs the main event loop that listens for refresh requests
|
||||
// and fallback poll ticks. It calls scanFn whenever a scan should
|
||||
// happen (rate-limited to scanCooldown). It blocks until ctx is
|
||||
// canceled.
|
||||
func (h *Handler) RunLoop(ctx context.Context, scanFn func()) {
|
||||
fallbackTicker := h.clock.NewTicker(fallbackPollInterval)
|
||||
defer fallbackTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case <-h.scanTrigger:
|
||||
h.rateLimitedScan(ctx, scanFn)
|
||||
|
||||
case <-fallbackTicker.C:
|
||||
h.rateLimitedScan(ctx, scanFn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) rateLimitedScan(ctx context.Context, scanFn func()) {
|
||||
h.mu.Lock()
|
||||
elapsed := h.clock.Since(h.lastScanAt)
|
||||
if elapsed < scanCooldown {
|
||||
h.mu.Unlock()
|
||||
|
||||
// Wait for cooldown then scan.
|
||||
remaining := scanCooldown - elapsed
|
||||
timer := h.clock.NewTimer(remaining)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
scanFn()
|
||||
return
|
||||
}
|
||||
h.mu.Unlock()
|
||||
scanFn()
|
||||
}
|
||||
|
||||
// isRepoDeleted returns true when the repo root directory or its .git
|
||||
// entry no longer represents a valid git repository. This
|
||||
// distinguishes a genuine repo deletion from a transient scan error
|
||||
// (e.g. lock contention).
|
||||
//
|
||||
// It handles three deletion cases:
|
||||
// 1. The repo root directory itself was removed.
|
||||
// 2. The .git entry (directory or file) was removed.
|
||||
// 3. The .git entry is a file (worktree/submodule) whose target
|
||||
// gitdir was removed. In this case .git exists on disk but
|
||||
// `git rev-parse --git-dir` fails because the referenced
|
||||
// directory is gone.
|
||||
func isRepoDeleted(gitBin string, repoRoot string) bool {
|
||||
if _, err := os.Stat(repoRoot); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
gitPath := filepath.Join(repoRoot, ".git")
|
||||
fi, err := os.Stat(gitPath)
|
||||
if os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
// If .git is a regular file (worktree or submodule), the actual
|
||||
// git object store lives elsewhere. Validate that the target is
|
||||
// still reachable by running git rev-parse.
|
||||
if err == nil && !fi.IsDir() {
|
||||
cmd := exec.CommandContext(context.Background(), gitBin, "-C", repoRoot, "rev-parse", "--git-dir")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// findRepoRoot uses `git rev-parse --show-toplevel` to find the
|
||||
// repository root for the given path.
|
||||
func findRepoRoot(gitBin string, p string) (string, error) {
|
||||
// If p is a file, start from its parent directory.
|
||||
dir := p
|
||||
if info, err := os.Stat(dir); err != nil || !info.IsDir() {
|
||||
dir = filepath.Dir(dir)
|
||||
}
|
||||
cmd := exec.CommandContext(context.Background(), gitBin, "rev-parse", "--show-toplevel")
|
||||
cmd.Dir = dir
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("no git repo found for %s", p)
|
||||
}
|
||||
root := filepath.FromSlash(strings.TrimSpace(string(out)))
|
||||
// Resolve symlinks and short (8.3) names on Windows so the
|
||||
// returned root matches paths produced by Go's filepath APIs.
|
||||
if resolved, evalErr := filepath.EvalSymlinks(root); evalErr == nil {
|
||||
root = resolved
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// getRepoChanges reads the current state of a git repository using
|
||||
// the git CLI. It returns branch, remote origin, and a unified diff.
|
||||
func getRepoChanges(ctx context.Context, logger slog.Logger, gitBin string, repoRoot string) (codersdk.WorkspaceAgentRepoChanges, error) {
|
||||
result := codersdk.WorkspaceAgentRepoChanges{
|
||||
RepoRoot: repoRoot,
|
||||
}
|
||||
|
||||
// Verify this is still a valid git repository before doing
|
||||
// anything else. This catches deleted repos early.
|
||||
verifyCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "rev-parse", "--git-dir")
|
||||
if err := verifyCmd.Run(); err != nil {
|
||||
return result, xerrors.Errorf("not a git repository: %w", err)
|
||||
}
|
||||
|
||||
// Read branch name.
|
||||
branchCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "symbolic-ref", "--short", "HEAD")
|
||||
if out, err := branchCmd.Output(); err == nil {
|
||||
result.Branch = strings.TrimSpace(string(out))
|
||||
} else {
|
||||
logger.Debug(ctx, "failed to read HEAD", slog.F("root", repoRoot), slog.Error(err))
|
||||
}
|
||||
|
||||
// Read remote origin URL.
|
||||
remoteCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "config", "--get", "remote.origin.url")
|
||||
if out, err := remoteCmd.Output(); err == nil {
|
||||
result.RemoteOrigin = strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// Compute unified diff.
|
||||
// `git diff HEAD` shows both staged and unstaged changes vs HEAD.
|
||||
// For repos with no commits yet, fall back to showing untracked
|
||||
// files only.
|
||||
diff, err := computeGitDiff(ctx, logger, gitBin, repoRoot)
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("compute diff: %w", err)
|
||||
}
|
||||
|
||||
result.UnifiedDiff = diff
|
||||
if len(result.UnifiedDiff) > maxTotalDiffSize {
|
||||
result.UnifiedDiff = "Total diff too large to show. Size: " + humanize.IBytes(uint64(len(result.UnifiedDiff))) + ". Showing branch and remote only."
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// computeGitDiff produces a unified diff string for the repository by
|
||||
// combining `git diff HEAD` (staged + unstaged changes) with diffs
|
||||
// for untracked files.
|
||||
func computeGitDiff(ctx context.Context, logger slog.Logger, gitBin string, repoRoot string) (string, error) {
|
||||
var diffParts []string
|
||||
|
||||
// Check if the repo has any commits.
|
||||
hasCommits := true
|
||||
checkCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "rev-parse", "HEAD")
|
||||
if err := checkCmd.Run(); err != nil {
|
||||
hasCommits = false
|
||||
}
|
||||
|
||||
if hasCommits {
|
||||
// `git diff HEAD` captures both staged and unstaged changes
|
||||
// relative to HEAD in a single unified diff.
|
||||
cmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "diff", "HEAD")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("git diff HEAD: %w", err)
|
||||
}
|
||||
if len(out) > 0 {
|
||||
diffParts = append(diffParts, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
// Show untracked files as diffs too.
|
||||
// `git ls-files --others --exclude-standard` lists untracked,
|
||||
// non-ignored files.
|
||||
lsCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "ls-files", "--others", "--exclude-standard")
|
||||
lsOut, err := lsCmd.Output()
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to list untracked files", slog.F("root", repoRoot), slog.Error(err))
|
||||
return strings.Join(diffParts, ""), nil
|
||||
}
|
||||
|
||||
untrackedFiles := strings.Split(strings.TrimSpace(string(lsOut)), "\n")
|
||||
for _, f := range untrackedFiles {
|
||||
f = strings.TrimSpace(f)
|
||||
if f == "" {
|
||||
continue
|
||||
}
|
||||
// Use `git diff --no-index /dev/null <file>` to generate
|
||||
// a unified diff for untracked files.
|
||||
var stdout bytes.Buffer
|
||||
untrackedCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "diff", "--no-index", "--", "/dev/null", f)
|
||||
untrackedCmd.Stdout = &stdout
|
||||
// git diff --no-index exits with 1 when files differ,
|
||||
// which is expected. We ignore the error and check for
|
||||
// output instead.
|
||||
_ = untrackedCmd.Run()
|
||||
if stdout.Len() > 0 {
|
||||
diffParts = append(diffParts, stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(diffParts, ""), nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,147 +0,0 @@
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// API exposes the git watch HTTP routes for the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
opts []Option
|
||||
pathStore *PathStore
|
||||
}
|
||||
|
||||
// NewAPI creates a new git watch API.
|
||||
func NewAPI(logger slog.Logger, pathStore *PathStore, opts ...Option) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
pathStore: pathStore,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// Routes returns the chi router for mounting at /api/v0/git.
|
||||
func (a *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/watch", a.handleWatch)
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionNoContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to accept WebSocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 4 MiB read limit — subscribe messages with many paths can exceed the
|
||||
// default 32 KB limit. Matches the SDK/proxy side.
|
||||
conn.SetReadLimit(1 << 22)
|
||||
|
||||
stream := wsjson.NewStream[
|
||||
codersdk.WorkspaceAgentGitClientMessage,
|
||||
codersdk.WorkspaceAgentGitServerMessage,
|
||||
](conn, websocket.MessageText, websocket.MessageText, a.logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, a.logger, cancel, conn)
|
||||
|
||||
handler := NewHandler(a.logger, a.opts...)
|
||||
|
||||
// scanAndSend performs a scan and sends results if there are
|
||||
// changes.
|
||||
scanAndSend := func() {
|
||||
msg := handler.Scan(ctx)
|
||||
if msg != nil {
|
||||
if err := stream.Send(*msg); err != nil {
|
||||
a.logger.Debug(ctx, "failed to send changes", slog.Error(err))
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a chat_id query parameter is provided and the PathStore is
|
||||
// available, subscribe to path updates for this chat.
|
||||
chatIDStr := r.URL.Query().Get("chat_id")
|
||||
if chatIDStr != "" && a.pathStore != nil {
|
||||
chatID, parseErr := uuid.Parse(chatIDStr)
|
||||
if parseErr == nil {
|
||||
// Subscribe to future path updates BEFORE reading
|
||||
// existing paths. This ordering guarantees no
|
||||
// notification from AddPaths is lost: any call that
|
||||
// lands before Subscribe is picked up by GetPaths
|
||||
// below, and any call after Subscribe delivers a
|
||||
// notification on the channel.
|
||||
notifyCh, unsubscribe := a.pathStore.Subscribe(chatID)
|
||||
defer unsubscribe()
|
||||
|
||||
// Load any paths that are already tracked for this chat.
|
||||
existingPaths := a.pathStore.GetPaths(chatID)
|
||||
if len(existingPaths) > 0 {
|
||||
handler.Subscribe(existingPaths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-notifyCh:
|
||||
paths := a.pathStore.GetPaths(chatID)
|
||||
handler.Subscribe(paths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Start the main run loop in a goroutine.
|
||||
go handler.RunLoop(ctx, scanAndSend)
|
||||
|
||||
// Read client messages.
|
||||
updates := stream.Chan()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = stream.Close(websocket.StatusGoingAway)
|
||||
return
|
||||
case msg, ok := <-updates:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case codersdk.WorkspaceAgentGitClientMessageTypeRefresh:
|
||||
handler.RequestScan()
|
||||
default:
|
||||
if err := stream.Send(codersdk.WorkspaceAgentGitServerMessage{
|
||||
Type: codersdk.WorkspaceAgentGitServerMessageTypeError,
|
||||
Message: "unknown message type",
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// ExtractChatContext reads chat identity headers from the request.
|
||||
// Returns zero values if headers are absent (non-chat request).
|
||||
func ExtractChatContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) {
|
||||
raw := r.Header.Get(workspacesdk.CoderChatIDHeader)
|
||||
if raw == "" {
|
||||
return uuid.Nil, nil, false
|
||||
}
|
||||
chatID, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
return uuid.Nil, nil, false
|
||||
}
|
||||
rawAncestors := r.Header.Get(workspacesdk.CoderAncestorChatIDsHeader)
|
||||
if rawAncestors != "" {
|
||||
var ids []string
|
||||
if err := json.Unmarshal([]byte(rawAncestors), &ids); err == nil {
|
||||
for _, s := range ids {
|
||||
if id, err := uuid.Parse(s); err == nil {
|
||||
ancestorIDs = append(ancestorIDs, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return chatID, ancestorIDs, true
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
package agentgit_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
func TestExtractChatContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
|
||||
ancestor1 := uuid.MustParse("11111111-2222-3333-4444-555555555555")
|
||||
ancestor2 := uuid.MustParse("66666666-7777-8888-9999-aaaaaaaaaaaa")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
chatID string // empty means header not set
|
||||
setChatID bool // whether to set the chat ID header at all
|
||||
ancestors string // empty means header not set
|
||||
setAncestors bool // whether to set the ancestor header at all
|
||||
wantChatID uuid.UUID
|
||||
wantAncestorIDs []uuid.UUID
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "NoHeadersPresent",
|
||||
setChatID: false,
|
||||
setAncestors: false,
|
||||
wantChatID: uuid.Nil,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "ValidChatID_NoAncestors",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
setAncestors: false,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "ValidChatID_ValidAncestors",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
ancestors: mustMarshalJSON(t, []string{
|
||||
ancestor1.String(),
|
||||
ancestor2.String(),
|
||||
}),
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: []uuid.UUID{ancestor1, ancestor2},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "MalformedChatID",
|
||||
chatID: "not-a-uuid",
|
||||
setChatID: true,
|
||||
setAncestors: false,
|
||||
wantChatID: uuid.Nil,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "ValidChatID_MalformedAncestorJSON",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
ancestors: `{this is not json}`,
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
// Only valid UUIDs in the array are returned; invalid
|
||||
// entries are silently skipped.
|
||||
name: "ValidChatID_PartialValidAncestorUUIDs",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
ancestors: mustMarshalJSON(t, []string{
|
||||
ancestor1.String(),
|
||||
"bad-uuid",
|
||||
ancestor2.String(),
|
||||
}),
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: []uuid.UUID{ancestor1, ancestor2},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
// Header is explicitly set to an empty string, which
|
||||
// Header.Get returns as "".
|
||||
name: "EmptyChatIDHeader",
|
||||
chatID: "",
|
||||
setChatID: true,
|
||||
setAncestors: false,
|
||||
wantChatID: uuid.Nil,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "ValidChatID_EmptyAncestorHeader",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
ancestors: "",
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
if tt.setChatID {
|
||||
r.Header.Set(workspacesdk.CoderChatIDHeader, tt.chatID)
|
||||
}
|
||||
if tt.setAncestors {
|
||||
r.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, tt.ancestors)
|
||||
}
|
||||
|
||||
chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r)
|
||||
|
||||
require.Equal(t, tt.wantOK, ok, "ok mismatch")
|
||||
require.Equal(t, tt.wantChatID, chatID, "chatID mismatch")
|
||||
require.Equal(t, tt.wantAncestorIDs, ancestorIDs, "ancestorIDs mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mustMarshalJSON marshals v to a JSON string, failing the test on error.
|
||||
func mustMarshalJSON(t *testing.T, v any) string {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
return string(b)
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// PathStore tracks which file paths each chat has touched.
|
||||
// It is safe for concurrent use.
|
||||
type PathStore struct {
|
||||
mu sync.RWMutex
|
||||
chatPaths map[uuid.UUID]map[string]struct{}
|
||||
subscribers map[uuid.UUID][]chan<- struct{}
|
||||
}
|
||||
|
||||
// NewPathStore creates a new PathStore.
|
||||
func NewPathStore() *PathStore {
|
||||
return &PathStore{
|
||||
chatPaths: make(map[uuid.UUID]map[string]struct{}),
|
||||
subscribers: make(map[uuid.UUID][]chan<- struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// AddPaths adds paths to every chat in chatIDs and notifies
|
||||
// their subscribers. Zero-value UUIDs are silently skipped.
|
||||
func (ps *PathStore) AddPaths(chatIDs []uuid.UUID, paths []string) {
|
||||
affected := make([]uuid.UUID, 0, len(chatIDs))
|
||||
for _, id := range chatIDs {
|
||||
if id != uuid.Nil {
|
||||
affected = append(affected, id)
|
||||
}
|
||||
}
|
||||
if len(affected) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ps.mu.Lock()
|
||||
for _, id := range affected {
|
||||
m, ok := ps.chatPaths[id]
|
||||
if !ok {
|
||||
m = make(map[string]struct{})
|
||||
ps.chatPaths[id] = m
|
||||
}
|
||||
for _, p := range paths {
|
||||
m[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
ps.mu.Unlock()
|
||||
|
||||
ps.notifySubscribers(affected)
|
||||
}
|
||||
|
||||
// Notify sends a signal to all subscribers of the given chat IDs
|
||||
// without adding any paths. Zero-value UUIDs are silently skipped.
|
||||
func (ps *PathStore) Notify(chatIDs []uuid.UUID) {
|
||||
affected := make([]uuid.UUID, 0, len(chatIDs))
|
||||
for _, id := range chatIDs {
|
||||
if id != uuid.Nil {
|
||||
affected = append(affected, id)
|
||||
}
|
||||
}
|
||||
if len(affected) == 0 {
|
||||
return
|
||||
}
|
||||
ps.notifySubscribers(affected)
|
||||
}
|
||||
|
||||
// notifySubscribers sends a non-blocking signal to all subscriber
|
||||
// channels for the given chat IDs.
|
||||
func (ps *PathStore) notifySubscribers(chatIDs []uuid.UUID) {
|
||||
ps.mu.RLock()
|
||||
toNotify := make([]chan<- struct{}, 0)
|
||||
for _, id := range chatIDs {
|
||||
toNotify = append(toNotify, ps.subscribers[id]...)
|
||||
}
|
||||
ps.mu.RUnlock()
|
||||
|
||||
for _, ch := range toNotify {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetPaths returns all paths tracked for a chat, deduplicated
|
||||
// and sorted lexicographically.
|
||||
func (ps *PathStore) GetPaths(chatID uuid.UUID) []string {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
|
||||
m := ps.chatPaths[chatID]
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(m))
|
||||
for p := range m {
|
||||
out = append(out, p)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// Len returns the number of chat IDs that have tracked paths.
|
||||
func (ps *PathStore) Len() int {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
return len(ps.chatPaths)
|
||||
}
|
||||
|
||||
// Subscribe returns a channel that receives a signal whenever
|
||||
// paths change for chatID, along with an unsubscribe function
|
||||
// that removes the channel.
|
||||
func (ps *PathStore) Subscribe(chatID uuid.UUID) (<-chan struct{}, func()) {
|
||||
ch := make(chan struct{}, 1)
|
||||
|
||||
ps.mu.Lock()
|
||||
ps.subscribers[chatID] = append(ps.subscribers[chatID], ch)
|
||||
ps.mu.Unlock()
|
||||
|
||||
unsub := func() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
subs := ps.subscribers[chatID]
|
||||
for i, s := range subs {
|
||||
if s == ch {
|
||||
ps.subscribers[chatID] = append(subs[:i], subs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ch, unsub
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
package agentgit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPathStore_AddPaths_StoresForChatAndAncestors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
ancestor1 := uuid.New()
|
||||
ancestor2 := uuid.New()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID, ancestor1, ancestor2}, []string{"/a", "/b"})
|
||||
|
||||
// All three IDs should see the paths.
|
||||
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(chatID))
|
||||
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(ancestor1))
|
||||
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(ancestor2))
|
||||
|
||||
// An unrelated chat should see nothing.
|
||||
require.Nil(t, ps.GetPaths(uuid.New()))
|
||||
}
|
||||
|
||||
func TestPathStore_AddPaths_SkipsNilUUIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
|
||||
// A nil chatID should be a no-op.
|
||||
ps.AddPaths([]uuid.UUID{uuid.Nil}, []string{"/x"})
|
||||
require.Nil(t, ps.GetPaths(uuid.Nil))
|
||||
|
||||
// A nil ancestor should be silently skipped.
|
||||
chatID := uuid.New()
|
||||
ps.AddPaths([]uuid.UUID{chatID, uuid.Nil}, []string{"/y"})
|
||||
require.Equal(t, []string{"/y"}, ps.GetPaths(chatID))
|
||||
require.Nil(t, ps.GetPaths(uuid.Nil))
|
||||
}
|
||||
|
||||
func TestPathStore_GetPaths_DeduplicatedSorted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/z", "/a", "/m", "/a", "/z"})
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/a", "/b"})
|
||||
|
||||
got := ps.GetPaths(chatID)
|
||||
require.Equal(t, []string{"/a", "/b", "/m", "/z"}, got)
|
||||
}
|
||||
|
||||
func TestPathStore_Subscribe_ReceivesNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch, unsub := ps.Subscribe(chatID)
|
||||
defer unsub()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
select {
|
||||
case <-ch:
|
||||
// Success.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for notification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathStore_Subscribe_MultipleSubscribers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch1, unsub1 := ps.Subscribe(chatID)
|
||||
defer unsub1()
|
||||
ch2, unsub2 := ps.Subscribe(chatID)
|
||||
defer unsub2()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
for i, ch := range []<-chan struct{}{ch1, ch2} {
|
||||
select {
|
||||
case <-ch:
|
||||
// OK
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("subscriber %d did not receive notification", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathStore_Unsubscribe_StopsNotifications(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch, unsub := ps.Subscribe(chatID)
|
||||
unsub()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
|
||||
|
||||
// AddPaths sends synchronously via a non-blocking send to the
|
||||
// buffered channel, so if a notification were going to arrive
|
||||
// it would already be in the channel by now.
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("received notification after unsubscribe")
|
||||
default:
|
||||
// Expected: no notification.
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathStore_Subscribe_AncestorNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
ancestor := uuid.New()
|
||||
|
||||
// Subscribe to the ancestor, then add paths via the child.
|
||||
ch, unsub := ps.Subscribe(ancestor)
|
||||
defer unsub()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID, ancestor}, []string{"/file"})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
select {
|
||||
case <-ch:
|
||||
// Success.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("ancestor subscriber did not receive notification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathStore_Notify_NotifiesWithoutAddingPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch, unsub := ps.Subscribe(chatID)
|
||||
defer unsub()
|
||||
|
||||
ps.Notify([]uuid.UUID{chatID})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
select {
|
||||
case <-ch:
|
||||
// Success.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for notification")
|
||||
}
|
||||
|
||||
require.Nil(t, ps.GetPaths(chatID))
|
||||
}
|
||||
|
||||
func TestPathStore_Notify_SkipsNilUUIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch, unsub := ps.Subscribe(chatID)
|
||||
defer unsub()
|
||||
|
||||
ps.Notify([]uuid.UUID{uuid.Nil})
|
||||
|
||||
// Notify sends synchronously via a non-blocking send to the
|
||||
// buffered channel, so if a notification were going to arrive
|
||||
// it would already be in the channel by now.
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("received notification for nil UUID")
|
||||
default:
|
||||
// Expected: no notification.
|
||||
}
|
||||
|
||||
require.Nil(t, ps.GetPaths(chatID))
|
||||
}
|
||||
|
||||
func TestPathStore_Notify_AncestorNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
ancestorID := uuid.New()
|
||||
|
||||
// Subscribe to the ancestor, then notify via the child.
|
||||
ch, unsub := ps.Subscribe(ancestorID)
|
||||
defer unsub()
|
||||
|
||||
ps.Notify([]uuid.UUID{chatID, ancestorID})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
select {
|
||||
case <-ch:
|
||||
// Success.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("ancestor subscriber did not receive notification")
|
||||
}
|
||||
|
||||
require.Nil(t, ps.GetPaths(ancestorID))
|
||||
}
|
||||
|
||||
func TestPathStore_ConcurrentSafety(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
const goroutines = 20
|
||||
const iterations = 50
|
||||
|
||||
chatIDs := make([]uuid.UUID, goroutines)
|
||||
for i := range chatIDs {
|
||||
chatIDs[i] = uuid.New()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 2) // writers + readers
|
||||
|
||||
// Writers.
|
||||
for i := range goroutines {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
for j := range iterations {
|
||||
ancestors := []uuid.UUID{chatIDs[(idx+1)%goroutines]}
|
||||
path := []string{
|
||||
"/file-" + chatIDs[idx].String() + "-" + time.Now().Format(time.RFC3339Nano),
|
||||
"/iter-" + string(rune('0'+j%10)),
|
||||
}
|
||||
ps.AddPaths(append([]uuid.UUID{chatIDs[idx]}, ancestors...), path)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Readers.
|
||||
for i := range goroutines {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_ = ps.GetPaths(chatIDs[idx])
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify every chat has at least the paths it wrote.
|
||||
for _, id := range chatIDs {
|
||||
paths := ps.GetPaths(id)
|
||||
require.NotEmpty(t, paths, "chat %s should have paths", id)
|
||||
}
|
||||
}
|
||||
+5
-26
@@ -7,11 +7,9 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -19,17 +17,15 @@ import (
|
||||
|
||||
// API exposes process-related operations through the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
manager *manager
|
||||
pathStore *agentgit.PathStore
|
||||
logger slog.Logger
|
||||
manager *manager
|
||||
}
|
||||
|
||||
// NewAPI creates a new process API handler.
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
manager: newManager(logger, execer, updateEnv),
|
||||
pathStore: pathStore,
|
||||
logger: logger,
|
||||
manager: newManager(logger, execer, updateEnv),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,23 +74,6 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Notify git watchers after the process finishes so that
|
||||
// file changes made by the command are visible in the scan.
|
||||
// If a workdir is provided, track it as a path as well.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
allIDs := append([]uuid.UUID{chatID}, ancestorIDs...)
|
||||
go func() {
|
||||
<-proc.done
|
||||
if req.WorkDir != "" {
|
||||
api.pathStore.AddPaths(allIDs, []string{req.WorkDir})
|
||||
} else {
|
||||
api.pathStore.Notify(allIDs)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.StartProcessResponse{
|
||||
ID: proc.id,
|
||||
Started: true,
|
||||
|
||||
@@ -12,14 +12,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/agent/agentproc"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -101,7 +99,7 @@ func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, e
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv)
|
||||
t.Cleanup(func() {
|
||||
_ = api.Close()
|
||||
})
|
||||
@@ -572,46 +570,6 @@ func TestSignalProcess(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
ch, unsub := pathStore.Subscribe(chatID)
|
||||
defer unsub()
|
||||
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
|
||||
return current, nil
|
||||
}, pathStore)
|
||||
defer api.Close()
|
||||
|
||||
routes := api.Routes()
|
||||
|
||||
body, err := json.Marshal(workspacesdk.StartProcessRequest{
|
||||
Command: "echo hello",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/start", bytes.NewReader(body))
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
routes.ServeHTTP(rw, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
// The subscriber should be notified even though no paths
|
||||
// were added.
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for path store notification")
|
||||
}
|
||||
|
||||
// No paths should have been stored for this chat.
|
||||
require.Nil(t, pathStore.GetPaths(chatID))
|
||||
}
|
||||
|
||||
func TestProcessLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -110,11 +110,6 @@ type Config struct {
|
||||
// X11DisplayOffset is the offset to add to the X11 display number.
|
||||
// Default is 10.
|
||||
X11DisplayOffset *int
|
||||
// X11MaxPort overrides the highest port used for X11 forwarding
|
||||
// listeners. Defaults to X11MaxPort (6200). Useful in tests
|
||||
// to shrink the port range and reduce the number of sessions
|
||||
// required.
|
||||
X11MaxPort *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// ReportConnection.
|
||||
@@ -163,10 +158,6 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
offset := X11DefaultDisplayOffset
|
||||
config.X11DisplayOffset = &offset
|
||||
}
|
||||
if config.X11MaxPort == nil {
|
||||
maxPort := X11MaxPort
|
||||
config.X11MaxPort = &maxPort
|
||||
}
|
||||
if config.UpdateEnv == nil {
|
||||
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
|
||||
}
|
||||
@@ -210,7 +201,6 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
x11HandlerErrors: metrics.x11HandlerErrors,
|
||||
fs: fs,
|
||||
displayOffset: *config.X11DisplayOffset,
|
||||
maxPort: *config.X11MaxPort,
|
||||
sessions: make(map[*x11Session]struct{}),
|
||||
connections: make(map[net.Conn]struct{}),
|
||||
network: func() X11Network {
|
||||
|
||||
@@ -57,7 +57,6 @@ type x11Forwarder struct {
|
||||
x11HandlerErrors *prometheus.CounterVec
|
||||
fs afero.Fs
|
||||
displayOffset int
|
||||
maxPort int
|
||||
|
||||
// network creates X11 listener sockets. Defaults to osNet{}.
|
||||
network X11Network
|
||||
@@ -315,7 +314,7 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
|
||||
// the next available port starting from X11StartPort and displayOffset.
|
||||
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
|
||||
// Look for an open port to listen on.
|
||||
for port := X11StartPort + x.displayOffset; port <= x.maxPort; port++ {
|
||||
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
|
||||
if ctx.Err() != nil {
|
||||
return nil, -1, ctx.Err()
|
||||
}
|
||||
|
||||
@@ -142,13 +142,8 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
|
||||
// Use in-process networking for X11 forwarding.
|
||||
inproc := testutil.NewInProcNet()
|
||||
|
||||
// Limit port range so we only need a handful of sessions to fill it
|
||||
// (the default 190 ports may easily timeout or conflict with other
|
||||
// ports on the system).
|
||||
maxPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset + 5
|
||||
cfg := &agentssh.Config{
|
||||
X11Net: inproc,
|
||||
X11MaxPort: &maxPort,
|
||||
X11Net: inproc,
|
||||
}
|
||||
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
|
||||
@@ -177,7 +172,7 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
|
||||
// configured port range.
|
||||
|
||||
startPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset
|
||||
maxSessions := maxPort - startPort + 1 - 1 // -1 for the blocked port
|
||||
maxSessions := agentssh.X11MaxPort - startPort + 1 - 1 // -1 for the blocked port
|
||||
require.Greater(t, maxSessions, 0, "expected a positive maxSessions value")
|
||||
|
||||
// shellSession holds references to the session and its standard streams so
|
||||
|
||||
@@ -28,7 +28,6 @@ func (a *agent) apiHandler() http.Handler {
|
||||
})
|
||||
|
||||
r.Mount("/api/v0", a.filesAPI.Routes())
|
||||
r.Mount("/api/v0/git", a.gitAPI.Routes())
|
||||
r.Mount("/api/v0/processes", a.processAPI.Routes())
|
||||
|
||||
if a.devcontainers {
|
||||
|
||||
@@ -42,20 +42,9 @@ func WithLogger(logger slog.Logger) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithDone sets a channel that, when closed, stops the reaper
|
||||
// goroutine. Callers that invoke ForkReap more than once in the
|
||||
// same process (e.g. tests) should use this to prevent goroutine
|
||||
// accumulation.
|
||||
func WithDone(ch chan struct{}) Option {
|
||||
return func(o *options) {
|
||||
o.Done = ch
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
ExecArgs []string
|
||||
PIDs reap.PidCh
|
||||
CatchSignals []os.Signal
|
||||
Logger slog.Logger
|
||||
Done chan struct{}
|
||||
}
|
||||
|
||||
@@ -18,15 +18,6 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// withDone returns an option that stops the reaper goroutine when t
|
||||
// completes, preventing goroutine accumulation across subtests.
|
||||
func withDone(t *testing.T) reaper.Option {
|
||||
t.Helper()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() { close(done) })
|
||||
return reaper.WithDone(done)
|
||||
}
|
||||
|
||||
// TestReap checks that's the reaper is successfully reaping
|
||||
// exited processes and passing the PIDs through the shared
|
||||
// channel.
|
||||
@@ -45,7 +36,6 @@ func TestReap(t *testing.T) {
|
||||
reaper.WithPIDCallback(pids),
|
||||
// Provide some argument that immediately exits.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", "exit 0"),
|
||||
withDone(t),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, exitCode)
|
||||
@@ -99,7 +89,6 @@ func TestForkReapExitCodes(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
exitCode, err := reaper.ForkReap(
|
||||
reaper.WithExecArgs("/bin/sh", "-c", tt.command),
|
||||
withDone(t),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command)
|
||||
@@ -129,7 +118,6 @@ func TestReapInterrupt(t *testing.T) {
|
||||
exitCode, err := reaper.ForkReap(
|
||||
reaper.WithPIDCallback(pids),
|
||||
reaper.WithCatchSignals(os.Interrupt),
|
||||
withDone(t),
|
||||
// Signal propagation does not extend to children of children, so
|
||||
// we create a little bash script to ensure sleep is interrupted.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
|
||||
@@ -141,17 +129,6 @@ func TestReapInterrupt(t *testing.T) {
|
||||
}()
|
||||
|
||||
require.Equal(t, <-usrSig, syscall.SIGUSR1)
|
||||
|
||||
// Prevent SIGINT from terminating the test process. Under the
|
||||
// race detector, the catchSignals goroutine in ForkReap may not
|
||||
// have called signal.Notify yet, so the default Go handler
|
||||
// could kill us. Registering our own Notify disables the
|
||||
// default behavior. Both this channel and the one inside
|
||||
// catchSignals receive independent copies of the signal.
|
||||
intC := make(chan os.Signal, 1)
|
||||
signal.Notify(intC, os.Interrupt)
|
||||
defer signal.Stop(intC)
|
||||
|
||||
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, <-usrSig, syscall.SIGUSR2)
|
||||
|
||||
+16
-21
@@ -6,7 +6,6 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-reap"
|
||||
@@ -20,7 +19,20 @@ func IsInitProcess() bool {
|
||||
return os.Getpid() == 1
|
||||
}
|
||||
|
||||
func catchSignals(logger slog.Logger, pid int, sc <-chan os.Signal) {
|
||||
func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
|
||||
if len(sigs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sc := make(chan os.Signal, 1)
|
||||
signal.Notify(sc, sigs...)
|
||||
defer signal.Stop(sc)
|
||||
|
||||
logger.Info(context.Background(), "reaper catching signals",
|
||||
slog.F("signals", sigs),
|
||||
slog.F("child_pid", pid),
|
||||
)
|
||||
|
||||
for {
|
||||
s := <-sc
|
||||
sig, ok := s.(syscall.Signal)
|
||||
@@ -52,17 +64,10 @@ func ForkReap(opt ...Option) (int, error) {
|
||||
o(opts)
|
||||
}
|
||||
|
||||
// Use the reapLock to prevent the reaper's Wait4(-1) from
|
||||
// stealing the direct child's exit status. The reaper takes
|
||||
// a write lock; we hold a read lock during our own Wait4.
|
||||
var reapLock sync.RWMutex
|
||||
reapLock.RLock()
|
||||
|
||||
go reap.ReapChildren(opts.PIDs, nil, opts.Done, &reapLock)
|
||||
go reap.ReapChildren(opts.PIDs, nil, nil, nil)
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
reapLock.RUnlock()
|
||||
return 1, xerrors.Errorf("get wd: %w", err)
|
||||
}
|
||||
|
||||
@@ -82,26 +87,16 @@ func ForkReap(opt ...Option) (int, error) {
|
||||
//#nosec G204
|
||||
pid, err := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
|
||||
if err != nil {
|
||||
reapLock.RUnlock()
|
||||
return 1, xerrors.Errorf("fork exec: %w", err)
|
||||
}
|
||||
|
||||
// Register the signal handler before spawning the goroutine
|
||||
// so it is active by the time the child process starts. This
|
||||
// avoids a race where a signal arrives before the goroutine
|
||||
// has called signal.Notify.
|
||||
if len(opts.CatchSignals) > 0 {
|
||||
sc := make(chan os.Signal, 1)
|
||||
signal.Notify(sc, opts.CatchSignals...)
|
||||
go catchSignals(opts.Logger, pid, sc)
|
||||
}
|
||||
go catchSignals(opts.Logger, pid, opts.CatchSignals)
|
||||
|
||||
var wstatus syscall.WaitStatus
|
||||
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
||||
for xerrors.Is(err, syscall.EINTR) {
|
||||
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
||||
}
|
||||
reapLock.RUnlock()
|
||||
|
||||
// Convert wait status to exit code using standard Unix conventions:
|
||||
// - Normal exit: use the exit code
|
||||
|
||||
@@ -123,10 +123,6 @@ func Select(inv *serpent.Invocation, opts SelectOptions) (string, error) {
|
||||
initialModel.height = defaultSelectModelHeight
|
||||
}
|
||||
|
||||
if idx := slices.Index(opts.Options, opts.Default); idx >= 0 {
|
||||
initialModel.cursor = idx
|
||||
}
|
||||
|
||||
initialModel.search.Prompt = ""
|
||||
initialModel.search.Focus()
|
||||
|
||||
|
||||
+3
-3
@@ -109,13 +109,13 @@ func (RootCmd) promptExample() *serpent.Command {
|
||||
Options: []string{
|
||||
"Blue", "Green", "Yellow", "Red", "Something else",
|
||||
},
|
||||
Default: "Green",
|
||||
Default: "",
|
||||
Message: "Select your favorite color:",
|
||||
Size: 5,
|
||||
HideSearch: !useSearch,
|
||||
})
|
||||
if value == "Something else" {
|
||||
_, _ = fmt.Fprint(inv.Stdout, "I would have picked green.\n")
|
||||
_, _ = fmt.Fprint(inv.Stdout, "I would have picked blue.\n")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "%s is a nice color.\n", value)
|
||||
}
|
||||
@@ -128,7 +128,7 @@ func (RootCmd) promptExample() *serpent.Command {
|
||||
Options: []string{
|
||||
"Car", "Bike", "Plane", "Boat", "Train",
|
||||
},
|
||||
Default: "Bike",
|
||||
Default: "Car",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
+12
-12
@@ -41,11 +41,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.Name, "--output", "json")
|
||||
inv, root := clitest.New(t, "task", "logs", task.Name, "--output", "json")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -64,11 +64,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String(), "--output", "json")
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String(), "--output", "json")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -87,11 +87,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String())
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String())
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -141,10 +141,10 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String())
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String())
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
|
||||
+26
-22
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -20,12 +21,12 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name, "--yes")
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// Then: Expect the task to be paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -33,7 +34,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been paused")
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -45,13 +46,13 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A different user's running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
adminClient, _, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause their task
|
||||
identifier := fmt.Sprintf("%s/%s", setup.task.OwnerName, setup.task.Name)
|
||||
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
inv, root := clitest.New(t, "task", "pause", identifier, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, setup.ownerClient, root)
|
||||
clitest.SetupConfig(t, adminClient, root)
|
||||
|
||||
// Then: We expect the task to be paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -59,7 +60,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been paused")
|
||||
|
||||
updated, err := setup.ownerClient.TaskByIdentifier(ctx, identifier)
|
||||
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -69,11 +70,11 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// And: We confirm we want to pause the task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -87,7 +88,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
pty.ExpectMatchContext(ctx, "has been paused")
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -97,11 +98,11 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// But: We say no at the confirmation screen
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -113,7 +114,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.Error(t, w.Wait())
|
||||
|
||||
// Then: We expect the task to not be paused
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -123,18 +124,21 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// And: We paused the running task
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := userClient.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, resp.WorkspaceBuild.ID)
|
||||
|
||||
// When: We attempt to pause the task again
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name, "--yes")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// Then: We expect to get an error that the task is already paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "is already paused")
|
||||
})
|
||||
}
|
||||
|
||||
+43
-31
@@ -1,6 +1,7 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@@ -16,18 +17,29 @@ import (
|
||||
func TestExpTaskResume(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// pauseTask is a helper that pauses a task and waits for the stop
|
||||
// build to complete.
|
||||
pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pauseResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
t.Run("WithYesFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes")
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -35,7 +47,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been resumed")
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -47,14 +59,14 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A different user's paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume their task
|
||||
identifier := fmt.Sprintf("%s/%s", setup.task.OwnerName, setup.task.Name)
|
||||
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, setup.ownerClient, root)
|
||||
clitest.SetupConfig(t, adminClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -62,7 +74,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been resumed")
|
||||
|
||||
updated, err := setup.ownerClient.TaskByIdentifier(ctx, identifier)
|
||||
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -72,13 +84,13 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume the task (and specify no wait)
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes", "--no-wait")
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed in the background
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -87,11 +99,11 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.Contains(t, output.Stdout(), "in the background")
|
||||
|
||||
// And: The task to eventually be resumed
|
||||
require.True(t, setup.task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws := coderdtest.MustWorkspace(t, setup.userClient, setup.task.WorkspaceID.UUID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, setup.userClient, ws.LatestBuild.ID)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -101,12 +113,12 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// And: We confirm we want to resume the task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -120,7 +132,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
pty.ExpectMatchContext(ctx, "has been resumed")
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -130,12 +142,12 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// But: Say no at the confirmation screen
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -147,7 +159,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.Error(t, w.Wait())
|
||||
|
||||
// Then: We expect the task to still be paused
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -157,11 +169,11 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to resume the task that is not paused
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// Then: We expect to get an error that the task is not paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
+9
-154
@@ -1,15 +1,10 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -20,15 +15,13 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "send <task> [<input> | --stdin]",
|
||||
Short: "Send input to a task",
|
||||
Long: `Send input to a task. If the task is paused, it will be automatically resumed before input is sent. If the task is initializing, it will wait for the task to become ready.
|
||||
` +
|
||||
FormatExamples(Example{
|
||||
Description: "Send direct input to a task",
|
||||
Command: `coder task send task1 "Please also add unit tests"`,
|
||||
}, Example{
|
||||
Description: "Send input from stdin to a task",
|
||||
Command: `echo "Please also add unit tests" | coder task send task1 --stdin`,
|
||||
}),
|
||||
Long: FormatExamples(Example{
|
||||
Description: "Send direct input to a task.",
|
||||
Command: "coder task send task1 \"Please also add unit tests\"",
|
||||
}, Example{
|
||||
Description: "Send input from stdin to a task.",
|
||||
Command: "echo \"Please also add unit tests\" | coder task send task1 --stdin",
|
||||
}),
|
||||
Middleware: serpent.RequireRangeArgs(1, 2),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
@@ -71,48 +64,8 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
return xerrors.Errorf("resolve task: %w", err)
|
||||
}
|
||||
|
||||
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
|
||||
// Before attempting to send, check the task status and
|
||||
// handle non-active states.
|
||||
var workspaceBuildID uuid.UUID
|
||||
|
||||
switch task.Status {
|
||||
case codersdk.TaskStatusActive:
|
||||
// Already active, no build to watch.
|
||||
|
||||
case codersdk.TaskStatusPaused:
|
||||
resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resume task %q: %w", display, err)
|
||||
} else if resp.WorkspaceBuild == nil {
|
||||
return xerrors.Errorf("resume task %q", display)
|
||||
}
|
||||
|
||||
workspaceBuildID = resp.WorkspaceBuild.ID
|
||||
|
||||
case codersdk.TaskStatusInitializing:
|
||||
if !task.WorkspaceID.Valid {
|
||||
return xerrors.Errorf("send input to task %q: task has no backing workspace", display)
|
||||
}
|
||||
|
||||
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace for task %q: %w", display, err)
|
||||
}
|
||||
|
||||
workspaceBuildID = workspace.LatestBuild.ID
|
||||
|
||||
default:
|
||||
return xerrors.Errorf("task %q has status %s and cannot be sent input", display, task.Status)
|
||||
}
|
||||
|
||||
if err := waitForTaskIdle(ctx, inv, client, task, workspaceBuildID); err != nil {
|
||||
return xerrors.Errorf("wait for task %q to be idle: %w", display, err)
|
||||
}
|
||||
|
||||
if err := client.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
return xerrors.Errorf("send input to task %q: %w", display, err)
|
||||
if err = client.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
return xerrors.Errorf("send input to task: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -121,101 +74,3 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// waitForTaskIdle optionally watches a workspace build to completion,
|
||||
// then polls until the task becomes active and its app state is idle.
|
||||
// This merges build-watching and idle-polling into a single loop so
|
||||
// that status changes (e.g. paused) are never missed between phases.
|
||||
func waitForTaskIdle(ctx context.Context, inv *serpent.Invocation, client *codersdk.Client, task codersdk.Task, workspaceBuildID uuid.UUID) error {
|
||||
if workspaceBuildID != uuid.Nil {
|
||||
if err := cliui.WorkspaceBuild(ctx, inv.Stdout, client, workspaceBuildID); err != nil {
|
||||
return xerrors.Errorf("watch workspace build: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cliui.Infof(inv.Stdout, "Waiting for task to become idle...")
|
||||
|
||||
// NOTE(DanielleMaywood):
|
||||
// It has been observed that the `TaskStatusError` state has
|
||||
// appeared during a typical healthy startup [^0]. To combat
|
||||
// this, we allow a 5 minute grace period where we allow
|
||||
// `TaskStatusError` to surface without immediately failing.
|
||||
//
|
||||
// TODO(DanielleMaywood):
|
||||
// Remove this grace period once the upstream agentapi health
|
||||
// check no longer reports transient error states during normal
|
||||
// startup.
|
||||
//
|
||||
// [0]: https://github.com/coder/coder/pull/22203#discussion_r2858002569
|
||||
const errorGracePeriod = 5 * time.Minute
|
||||
gracePeriodDeadline := time.Now().Add(errorGracePeriod)
|
||||
|
||||
// NOTE(DanielleMaywood):
|
||||
// On resume the MCP may not report an initial app status,
|
||||
// leaving CurrentState nil indefinitely. To avoid hanging
|
||||
// forever we treat Active with nil CurrentState as idle
|
||||
// after a grace period, giving the MCP time to report
|
||||
// during normal startup.
|
||||
const nilStateGracePeriod = 30 * time.Second
|
||||
var nilStateDeadline time.Time
|
||||
|
||||
// TODO(DanielleMaywood):
|
||||
// When we have a streaming Task API, this should be converted
|
||||
// away from polling.
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
task, err := client.TaskByID(ctx, task.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get task by id: %w", err)
|
||||
}
|
||||
|
||||
switch task.Status {
|
||||
case codersdk.TaskStatusInitializing,
|
||||
codersdk.TaskStatusPending:
|
||||
// Not yet active, keep polling.
|
||||
continue
|
||||
case codersdk.TaskStatusActive:
|
||||
// Task is active; check app state.
|
||||
if task.CurrentState == nil {
|
||||
// The MCP may not have reported state yet.
|
||||
// Start a grace period on first observation
|
||||
// and treat as idle once it expires.
|
||||
if nilStateDeadline.IsZero() {
|
||||
nilStateDeadline = time.Now().Add(nilStateGracePeriod)
|
||||
}
|
||||
if time.Now().After(nilStateDeadline) {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Reset nil-state deadline since we got a real
|
||||
// state report.
|
||||
nilStateDeadline = time.Time{}
|
||||
switch task.CurrentState.State {
|
||||
case codersdk.TaskStateIdle,
|
||||
codersdk.TaskStateComplete,
|
||||
codersdk.TaskStateFailed:
|
||||
return nil
|
||||
default:
|
||||
// Still working, keep polling.
|
||||
continue
|
||||
}
|
||||
case codersdk.TaskStatusError:
|
||||
if time.Now().After(gracePeriodDeadline) {
|
||||
return xerrors.Errorf("task entered %s state while waiting for it to become idle", task.Status)
|
||||
}
|
||||
case codersdk.TaskStatusPaused:
|
||||
return xerrors.Errorf("task was paused while waiting for it to become idle")
|
||||
case codersdk.TaskStatusUnknown:
|
||||
return xerrors.Errorf("task entered %s state while waiting for it to become idle", task.Status)
|
||||
default:
|
||||
return xerrors.Errorf("task entered unexpected state (%s) while waiting for it to become idle", task.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+13
-224
@@ -12,14 +12,9 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentapisdk "github.com/coder/agentapi-sdk-go"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -30,12 +25,12 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "carry on with the task")
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "carry on with the task")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -46,12 +41,12 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.ID.String(), "carry on with the task")
|
||||
inv, root := clitest.New(t, "task", "send", task.ID.String(), "carry on with the task")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -62,13 +57,13 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "--stdin")
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "--stdin")
|
||||
inv.Stdout = &stdout
|
||||
inv.Stdin = strings.NewReader("carry on with the task")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -115,223 +110,17 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(assert.AnError))
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "some task input")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, assert.AnError.Error())
|
||||
})
|
||||
|
||||
t.Run("WaitsForInitializingTask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Close the first agent, pause, then resume the task so the
|
||||
// workspace is started but no agent is connected.
|
||||
// This puts the task in "initializing" state.
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
resumeTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the initializing task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
// Use a pty so we can wait for the command to produce build
|
||||
// output, confirming it has entered the initializing code
|
||||
// path before we connect the agent.
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to observe the initializing state and
|
||||
// start watching the workspace build. This ensures the command
|
||||
// has entered the waiting code path.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Connect a new agent so the task can transition to active.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
setup.agent = agenttest.New(t, setup.userClient.URL, setup.agentToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, setup.userClient, setup.task.WorkspaceID.UUID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Report the task app as idle so waitForTaskIdle can proceed.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusActive, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("ResumesPausedTask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Close the first agent before pausing so it does not conflict
|
||||
// with the agent we reconnect after the workspace is resumed.
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the paused task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
// Use a pty so we can wait for the command to produce build
|
||||
// output, confirming it has entered the paused code path and
|
||||
// triggered a resume before we connect the agent.
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to observe the paused state, trigger
|
||||
// a resume, and start watching the workspace build.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Connect a new agent so the task can transition to active.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
setup.agent = agenttest.New(t, setup.userClient.URL, setup.agentToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, setup.userClient, setup.task.WorkspaceID.UUID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Report the task app as idle so waitForTaskIdle can proceed.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusActive, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("PausedDuringWaitForReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: An initializing task (workspace running, no agent
|
||||
// connected).
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
resumeTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the initializing task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to enter the build-watching phase
|
||||
// of waitForTaskReady.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Pause the task while waitForTaskReady is polling. Since
|
||||
// no agent is connected, the task stays initializing until
|
||||
// we pause it, at which point the status becomes paused.
|
||||
pauseTask(ctx, t, setup.userClient, setup.task)
|
||||
|
||||
// Then: The command should fail because the task was paused.
|
||||
err := w.Wait()
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "was paused while waiting for it to become idle")
|
||||
})
|
||||
|
||||
t.Run("WaitsForWorkingAppState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: An active task whose app is in "working" state.
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Move the app into "working" state before running the command.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
require.NoError(t, agentClient.PatchAppStatus(setupCtx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateWorking,
|
||||
Message: "busy",
|
||||
}))
|
||||
|
||||
// When: We send input while the app is working.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Transition the app back to idle so waitForTaskIdle proceeds.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
})
|
||||
|
||||
t.Run("SendToNonIdleAppState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, appState := range []codersdk.WorkspaceAppStatusState{
|
||||
codersdk.WorkspaceAppStatusStateComplete,
|
||||
codersdk.WorkspaceAppStatusStateFailure,
|
||||
} {
|
||||
t.Run(string(appState), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some input", "some response"))
|
||||
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
require.NoError(t, agentClient.PatchAppStatus(setupCtx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: appState,
|
||||
Message: "done",
|
||||
}))
|
||||
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) map[string]http.HandlerFunc {
|
||||
@@ -362,7 +151,7 @@ func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) m
|
||||
}
|
||||
}
|
||||
|
||||
func fakeAgentAPITaskSendErr(returnErr error) map[string]http.HandlerFunc {
|
||||
func fakeAgentAPITaskSendErr(t *testing.T, returnErr error) map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/status": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
+5
-56
@@ -88,13 +88,6 @@ func Test_Tasks(t *testing.T) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, userClient, tasks[0].WorkspaceID.UUID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
|
||||
// Report the task app as idle so that waitForTaskIdle
|
||||
// can proceed during the "send task message" step.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -279,19 +272,10 @@ func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Mes
|
||||
// setupCLITaskTest creates a test workspace with an AI task template and agent,
|
||||
// with a fake agent API configured with the provided set of handlers.
|
||||
// Returns the user client and workspace.
|
||||
// setupCLITaskTestResult holds the return values from setupCLITaskTest.
|
||||
type setupCLITaskTestResult struct {
|
||||
ownerClient *codersdk.Client
|
||||
userClient *codersdk.Client
|
||||
task codersdk.Task
|
||||
agentToken string
|
||||
agent agent.Agent
|
||||
}
|
||||
|
||||
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) setupCLITaskTestResult {
|
||||
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (ownerClient *codersdk.Client, memberClient *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
ownerClient = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
userClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)
|
||||
|
||||
@@ -308,56 +292,21 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the task's underlying workspace to be built.
|
||||
// Wait for the task's underlying workspace to be built
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, workspace.LatestBuild.ID)
|
||||
|
||||
agentClient := agentsdk.New(userClient.URL, agentsdk.WithFixedToken(authToken))
|
||||
agt := agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
|
||||
_ = agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, userClient, workspace.ID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Report the task app as idle so that waitForTaskIdle can proceed.
|
||||
err = agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return setupCLITaskTestResult{
|
||||
ownerClient: ownerClient,
|
||||
userClient: userClient,
|
||||
task: task,
|
||||
agentToken: authToken,
|
||||
agent: agt,
|
||||
}
|
||||
}
|
||||
|
||||
// pauseTask pauses the task and waits for the stop build to complete.
|
||||
func pauseTask(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pauseResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
// resumeTask resumes the task waits for the start build to complete. The task
|
||||
// will be in "initializing" state after this returns because no agent is connected.
|
||||
func resumeTask(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
resumeResp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resumeResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, resumeResp.WorkspaceBuild.ID)
|
||||
return ownerClient, userClient, task
|
||||
}
|
||||
|
||||
// setupCLITaskTestWithSnapshot creates a task in the specified status with a log snapshot.
|
||||
|
||||
+2
-5
@@ -5,14 +5,11 @@ USAGE:
|
||||
|
||||
Send input to a task
|
||||
|
||||
Send input to a task. If the task is paused, it will be automatically resumed
|
||||
before input is sent. If the task is initializing, it will wait for the task
|
||||
to become ready.
|
||||
- Send direct input to a task:
|
||||
- Send direct input to a task.:
|
||||
|
||||
$ coder task send task1 "Please also add unit tests"
|
||||
|
||||
- Send input from stdin to a task:
|
||||
- Send input from stdin to a task.:
|
||||
|
||||
$ echo "Please also add unit tests" | coder task send task1 --stdin
|
||||
|
||||
|
||||
+2
-13
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -185,22 +184,12 @@ func TestTokens(t *testing.T) {
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
|
||||
// Precondition: validate token is not expired before expiring
|
||||
var expiredAtBefore time.Time
|
||||
token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two")
|
||||
require.NoError(t, err)
|
||||
now := dbtime.Now()
|
||||
require.True(t, token.ExpiresAt.After(now), "token should not be expired yet (expiresAt=%s, now=%s)", token.ExpiresAt.UTC(), now)
|
||||
expiredAtBefore = token.ExpiresAt
|
||||
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
// Validate that token was expired
|
||||
if token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two"); assert.NoError(t, err) {
|
||||
now := dbtime.Now()
|
||||
require.NotEqual(t, token.ExpiresAt, expiredAtBefore, "token expiresAt is the same as before expiring, but should have been updated")
|
||||
require.False(t, token.ExpiresAt.After(now), "token expiresAt should not be in the future after expiring, but was %s (now=%s)", token.ExpiresAt.UTC(), now)
|
||||
now := time.Now()
|
||||
require.False(t, token.ExpiresAt.After(now), "token expiresAt should not be in the future, but was %s (now=%s)", token.ExpiresAt, now)
|
||||
}
|
||||
|
||||
// Delete by ID (explicit delete flag)
|
||||
|
||||
@@ -387,9 +387,9 @@ func (b *Batcher) flush(ctx context.Context, reason string) {
|
||||
b.Metrics.BatchSize.Observe(float64(count))
|
||||
b.Metrics.MetadataTotal.Add(float64(count))
|
||||
b.Metrics.BatchesTotal.WithLabelValues(reason).Inc()
|
||||
elapsed = b.clock.Since(start)
|
||||
b.Metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
|
||||
b.Metrics.FlushDuration.WithLabelValues(reason).Observe(time.Since(start).Seconds())
|
||||
|
||||
elapsed = time.Since(start)
|
||||
b.log.Debug(ctx, "flush complete",
|
||||
slog.F("count", count),
|
||||
slog.F("elapsed", elapsed),
|
||||
|
||||
+1
-13
@@ -315,18 +315,6 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
|
||||
}
|
||||
}
|
||||
|
||||
// appStatusStateToTaskState converts a WorkspaceAppStatusState to a
|
||||
// TaskState. The two enums mostly share values but "failure" in the
|
||||
// app status maps to "failed" in the public task API.
|
||||
func appStatusStateToTaskState(s codersdk.WorkspaceAppStatusState) codersdk.TaskState {
|
||||
switch s {
|
||||
case codersdk.WorkspaceAppStatusStateFailure:
|
||||
return codersdk.TaskStateFailed
|
||||
default:
|
||||
return codersdk.TaskState(s)
|
||||
}
|
||||
}
|
||||
|
||||
// deriveTaskCurrentState determines the current state of a task based on the
|
||||
// workspace's latest app status and initialization phase.
|
||||
// Returns nil if no valid state can be determined.
|
||||
@@ -346,7 +334,7 @@ func deriveTaskCurrentState(
|
||||
if ws.LatestBuild.Transition != codersdk.WorkspaceTransitionStart || ws.LatestAppStatus.CreatedAt.After(ws.LatestBuild.CreatedAt) {
|
||||
currentState = &codersdk.TaskStateEntry{
|
||||
Timestamp: ws.LatestAppStatus.CreatedAt,
|
||||
State: appStatusStateToTaskState(ws.LatestAppStatus.State),
|
||||
State: codersdk.TaskState(ws.LatestAppStatus.State),
|
||||
Message: ws.LatestAppStatus.Message,
|
||||
URI: ws.LatestAppStatus.URI,
|
||||
}
|
||||
|
||||
Generated
-160
@@ -481,128 +481,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/files": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/octet-stream"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Upload a chat file",
|
||||
"operationId": "upload-chat-file",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)",
|
||||
"name": "Content-Type",
|
||||
"in": "header",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "query",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UploadChatFileResponse"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "Unauthorized",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"413": {
|
||||
"description": "Request Entity Too Large",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/files/{file}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get a chat file",
|
||||
"operationId": "get-chat-file",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "File ID",
|
||||
"name": "file",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "Unauthorized",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/archive": {
|
||||
"post": {
|
||||
"tags": [
|
||||
@@ -617,35 +495,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/git/watch": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Watch git changes for a chat.",
|
||||
"operationId": "watch-chat-git",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"101": {
|
||||
"description": "Switching Protocols"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/unarchive": {
|
||||
"post": {
|
||||
"tags": [
|
||||
@@ -20456,15 +20305,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UploadChatFileResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UploadResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
Generated
-150
@@ -410,120 +410,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/files": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": ["application/octet-stream"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Upload a chat file",
|
||||
"operationId": "upload-chat-file",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)",
|
||||
"name": "Content-Type",
|
||||
"in": "header",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "query",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UploadChatFileResponse"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "Unauthorized",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"413": {
|
||||
"description": "Request Entity Too Large",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/files/{file}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Get a chat file",
|
||||
"operationId": "get-chat-file",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "File ID",
|
||||
"name": "file",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "Unauthorized",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/archive": {
|
||||
"post": {
|
||||
"tags": ["Chats"],
|
||||
@@ -536,33 +422,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/git/watch": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Watch git changes for a chat.",
|
||||
"operationId": "watch-chat-git",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"101": {
|
||||
"description": "Switching Protocols"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/unarchive": {
|
||||
"post": {
|
||||
"tags": ["Chats"],
|
||||
@@ -18764,15 +18623,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UploadChatFileResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UploadResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
+11
-11
@@ -48,8 +48,8 @@ func TestTokenCRUD(t *testing.T) {
|
||||
require.EqualValues(t, len(keys), 1)
|
||||
require.Contains(t, res.Key, keys[0].ID)
|
||||
// expires_at should default to 30 days
|
||||
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*8))
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*8))
|
||||
require.Equal(t, codersdk.APIKeyScopeAll, keys[0].Scope)
|
||||
require.Len(t, keys[0].AllowList, 1)
|
||||
require.Equal(t, "*:*", keys[0].AllowList[0].String())
|
||||
@@ -194,8 +194,8 @@ func TestUserSetTokenDuration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
keys, err := client.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*6*24))
|
||||
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*8*24))
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*6*24))
|
||||
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*8*24))
|
||||
}
|
||||
|
||||
func TestDefaultTokenDuration(t *testing.T) {
|
||||
@@ -210,8 +210,8 @@ func TestDefaultTokenDuration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
keys, err := client.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*8))
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*8))
|
||||
}
|
||||
|
||||
func TestTokenUserSetMaxLifetime(t *testing.T) {
|
||||
@@ -518,7 +518,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Verify the token is not expired.
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.After(dbtime.Now()))
|
||||
require.True(t, key.ExpiresAt.After(time.Now()))
|
||||
|
||||
auditor.ResetLogs()
|
||||
|
||||
@@ -529,7 +529,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Verify the token is expired.
|
||||
key, err = adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
|
||||
// Verify audit log.
|
||||
als := auditor.AuditLogs()
|
||||
@@ -556,7 +556,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Verify the token is expired.
|
||||
key, err := memberClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
})
|
||||
|
||||
t.Run("MemberCannotExpireOtherUsersToken", func(t *testing.T) {
|
||||
@@ -607,7 +607,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Invariant: make sure it's actually expired
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.LessOrEqual(t, key.ExpiresAt, dbtime.Now(), "key should be expired")
|
||||
require.LessOrEqual(t, key.ExpiresAt, time.Now(), "key should be expired")
|
||||
|
||||
// Expire it again - should succeed (idempotent).
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
@@ -636,7 +636,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Verify it's expired.
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
|
||||
// Delete the expired token - should succeed.
|
||||
err = adminClient.DeleteAPIKey(ctx, codersdk.Me, keyID)
|
||||
|
||||
@@ -599,11 +599,8 @@ func TestExecutorAutostopAIAgentActivity(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// When: the autobuild executor ticks after the bumped deadline.
|
||||
// Use time.Now() to account for elapsed time since the test's
|
||||
// "now" variable, because the activity bump uses the database
|
||||
// NOW() which advances with wall clock time.
|
||||
go func() {
|
||||
tickTime := time.Now().Add(time.Hour).Add(time.Minute)
|
||||
tickTime := now.Add(time.Hour).Add(time.Minute)
|
||||
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
|
||||
tickCh <- tickTime
|
||||
close(tickCh)
|
||||
|
||||
+253
-532
File diff suppressed because it is too large
Load Diff
+21
-106
@@ -73,11 +73,10 @@ func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
|
||||
return event.Status.Status == codersdk.ChatStatusWaiting
|
||||
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
||||
return false
|
||||
}
|
||||
t.Logf("skipping unexpected event: type=%s", event.Type)
|
||||
return false
|
||||
return event.Status.Status == codersdk.ChatStatusWaiting
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -367,7 +366,7 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
|
||||
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -399,31 +398,26 @@ func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The message should be queued, not inserted directly.
|
||||
require.True(t, result.Queued)
|
||||
require.NotNil(t, result.QueuedMessage)
|
||||
|
||||
// The chat should transition to waiting (interrupt signal),
|
||||
// not pending.
|
||||
require.Equal(t, database.ChatStatusWaiting, result.Chat.Status)
|
||||
require.False(t, result.Queued)
|
||||
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
|
||||
require.False(t, result.Chat.WorkerID.Valid)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
||||
require.Equal(t, database.ChatStatusPending, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
|
||||
// The message should be in the queue, not in chat_messages.
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 1)
|
||||
require.Len(t, queued, 0)
|
||||
|
||||
// Only the initial user message should be in chat_messages.
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
require.Len(t, messages, 2)
|
||||
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
|
||||
}
|
||||
|
||||
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
@@ -871,15 +865,15 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
||||
// events — the snapshot already contained everything. Before
|
||||
// the fix, localSnapshot was replayed into the channel,
|
||||
// causing duplicates.
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-events:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
if ok {
|
||||
t.Fatalf("unexpected event from channel (would be a duplicate): type=%s", event.Type)
|
||||
}
|
||||
}, 200*time.Millisecond, testutil.IntervalFast,
|
||||
"expected no duplicate events after snapshot")
|
||||
// Channel closed without events is fine.
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// No events — correct behavior.
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
@@ -1468,17 +1462,10 @@ func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
|
||||
// mockWebpushDispatcher implements webpush.Dispatcher and records Dispatch calls.
|
||||
type mockWebpushDispatcher struct {
|
||||
dispatchCount atomic.Int32
|
||||
mu sync.Mutex
|
||||
lastMessage codersdk.WebpushMessage
|
||||
lastUserID uuid.UUID
|
||||
}
|
||||
|
||||
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
|
||||
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error {
|
||||
m.dispatchCount.Add(1)
|
||||
m.mu.Lock()
|
||||
m.lastMessage = msg
|
||||
m.lastUserID = userID
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1490,78 +1477,6 @@ func (*mockWebpushDispatcher) PublicKey() string {
|
||||
return "test-vapid-public-key"
|
||||
}
|
||||
|
||||
func TestSuccessfulChatSendsWebPushWithNavigationData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Set up a mock OpenAI that returns a simple successful response.
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("done")...,
|
||||
)
|
||||
})
|
||||
|
||||
// Mock webpush dispatcher that captures the dispatched message.
|
||||
mockPush := &mockWebpushDispatcher{}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
WebpushDispatcher: mockPush,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "push-nav-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the chat to complete and return to waiting status.
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid && mockPush.dispatchCount.Load() == 1
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Verify a web push notification was dispatched exactly once.
|
||||
require.Equal(t, int32(1), mockPush.dispatchCount.Load(),
|
||||
"expected exactly one web push dispatch for a completed chat")
|
||||
|
||||
// Verify the notification was sent to the correct user.
|
||||
mockPush.mu.Lock()
|
||||
capturedMsg := mockPush.lastMessage
|
||||
capturedUserID := mockPush.lastUserID
|
||||
mockPush.mu.Unlock()
|
||||
|
||||
require.Equal(t, user.ID, capturedUserID,
|
||||
"web push should be dispatched to the chat owner")
|
||||
|
||||
// Verify the Data field contains the correct navigation URL.
|
||||
expectedURL := fmt.Sprintf("/agents/%s", chat.ID)
|
||||
require.Equal(t, expectedURL, capturedMsg.Data["url"],
|
||||
"web push Data should contain the chat navigation URL")
|
||||
}
|
||||
|
||||
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -73,9 +73,7 @@ type RunOptions struct {
|
||||
// OnRetry is called before each retry attempt when the LLM
|
||||
// stream fails with a retryable error. It provides the attempt
|
||||
// number, error, and backoff delay so callers can publish status
|
||||
// events to connected clients. Callers should also clear any
|
||||
// buffered stream state from the failed attempt in this callback
|
||||
// to avoid sending duplicated content.
|
||||
// events to connected clients.
|
||||
OnRetry chatretry.OnRetryFn
|
||||
|
||||
OnInterruptedPersistError func(error)
|
||||
@@ -211,10 +209,6 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
var lastUsage fantasy.Usage
|
||||
var lastProviderMetadata fantasy.ProviderMetadata
|
||||
|
||||
totalSteps := 0
|
||||
// When totalSteps reaches MaxSteps the inner loop exits immediately
|
||||
// (its condition is false), stoppedByModel stays false, and the
|
||||
// post-loop guard breaks the outer compaction loop.
|
||||
for compactionAttempt := 0; ; compactionAttempt++ {
|
||||
alreadyCompacted := false
|
||||
// stoppedByModel is true when the inner step loop
|
||||
@@ -228,8 +222,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
// agent never had a chance to use the compacted context.
|
||||
compactedOnFinalStep := false
|
||||
|
||||
for step := 0; totalSteps < opts.MaxSteps; step++ {
|
||||
totalSteps++
|
||||
for step := 0; step < opts.MaxSteps; step++ {
|
||||
// Copy messages so that provider-specific caching
|
||||
// mutations don't leak back to the caller's slice.
|
||||
// copy copies Message structs by value, so field
|
||||
@@ -328,12 +321,6 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
lastUsage = result.usage
|
||||
lastProviderMetadata = result.providerMetadata
|
||||
|
||||
// Append the step's response messages so that both
|
||||
// inline and post-loop compaction see the full
|
||||
// conversation including the latest assistant reply.
|
||||
stepMessages := result.toResponseMessages()
|
||||
messages = append(messages, stepMessages...)
|
||||
|
||||
// Inline compaction.
|
||||
if opts.Compaction != nil && opts.ReloadMessages != nil {
|
||||
did, compactErr := tryCompact(
|
||||
@@ -367,11 +354,17 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
// The agent is continuing with tool calls, so any
|
||||
// prior compaction has already been consumed.
|
||||
compactedOnFinalStep = false
|
||||
|
||||
// Build messages from the step for the next iteration.
|
||||
// toResponseMessages produces assistant-role content
|
||||
// (text, reasoning, tool calls) and tool-result content.
|
||||
stepMessages := result.toResponseMessages()
|
||||
messages = append(messages, stepMessages...)
|
||||
}
|
||||
|
||||
// Post-run compaction safety net: if we never compacted
|
||||
// during the loop, try once at the end.
|
||||
if !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil {
|
||||
if !alreadyCompacted && opts.Compaction != nil {
|
||||
did, err := tryCompact(
|
||||
ctx,
|
||||
opts.Model,
|
||||
@@ -390,6 +383,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
compactedOnFinalStep = true
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enter the step loop when compaction fired on the
|
||||
// model's final step. This lets the agent continue
|
||||
// working with fresh summarized context instead of
|
||||
@@ -520,6 +514,7 @@ func processStepStream(
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
case fantasy.StreamPartTypeToolInputStart:
|
||||
activeToolCalls[part.ID] = &fantasy.ToolCallContent{
|
||||
ToolCallID: part.ID,
|
||||
|
||||
@@ -123,8 +123,7 @@ func tryCompact(
|
||||
config.SystemSummaryPrefix + "\n\n" + summary,
|
||||
)
|
||||
|
||||
persistCtx := context.WithoutCancel(ctx)
|
||||
err = config.Persist(persistCtx, CompactionResult{
|
||||
err = config.Persist(ctx, CompactionResult{
|
||||
SystemSummary: systemSummary,
|
||||
SummaryReport: summary,
|
||||
ThresholdPercent: config.ThresholdPercent,
|
||||
|
||||
@@ -76,20 +76,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Compaction fires twice: once inline when the threshold is
|
||||
// reached on step 0 (the only step, since MaxSteps=1), and
|
||||
// once from the post-run safety net during the re-entry
|
||||
// iteration (where totalSteps already equals MaxSteps so the
|
||||
// inner loop doesn't execute, but lastUsage still exceeds
|
||||
// the threshold).
|
||||
require.Equal(t, 2, persistCompactionCalls)
|
||||
require.Equal(t, 1, persistCompactionCalls)
|
||||
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
|
||||
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
|
||||
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
|
||||
@@ -162,25 +151,13 @@ func TestRun_Compaction(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Compaction fires twice (see PersistsWhenThresholdReached
|
||||
// for the full explanation). Each cycle follows the order:
|
||||
// publish_tool_call → generate → persist → publish_tool_result.
|
||||
require.Equal(t, []string{
|
||||
"publish_tool_call",
|
||||
"generate",
|
||||
"persist",
|
||||
"publish_tool_result",
|
||||
"publish_tool_call",
|
||||
"generate",
|
||||
"persist",
|
||||
"publish_tool_result",
|
||||
}, callOrder)
|
||||
})
|
||||
|
||||
@@ -480,11 +457,6 @@ func TestRun_Compaction(t *testing.T) {
|
||||
compactionErr = err
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Error(t, compactionErr)
|
||||
@@ -600,117 +572,4 @@ func TestRun_Compaction(t *testing.T) {
|
||||
// Two stream calls: one before compaction, one after re-entry.
|
||||
require.Equal(t, 2, streamCallCount)
|
||||
})
|
||||
|
||||
t.Run("PostRunCompactionReEntryIncludesUserSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// After compaction the summary is stored as a user-role
|
||||
// message. When the loop re-enters, the reloaded prompt
|
||||
// must contain this user message so the LLM provider
|
||||
// receives a valid prompt (providers like Anthropic
|
||||
// require at least one non-system message).
|
||||
|
||||
var mu sync.Mutex
|
||||
var streamCallCount int
|
||||
var reEntryPrompt []fantasy.Message
|
||||
persistCompactionCalls := 0
|
||||
|
||||
const summaryText = "post-run compacted summary"
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCallCount
|
||||
streamCallCount++
|
||||
mu.Unlock()
|
||||
|
||||
switch step {
|
||||
case 0:
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "initial response"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 80,
|
||||
TotalTokens: 85,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
default:
|
||||
mu.Lock()
|
||||
reEntryPrompt = append([]fantasy.Message(nil), call.Prompt...)
|
||||
mu.Unlock()
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-2"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-2", Delta: "continued"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-2"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 20,
|
||||
TotalTokens: 25,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
},
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate real post-compaction DB state: the summary is
|
||||
// a user-role message (the only non-system content).
|
||||
compactedMessages := []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleSystem, "system prompt"),
|
||||
textMessage(fantasy.MessageRoleUser, "Summary of earlier chat context:\n\ncompacted summary"),
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 5,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
ContextLimitFallback: 100,
|
||||
Compaction: &CompactionOptions{
|
||||
ThresholdPercent: 70,
|
||||
SummaryPrompt: "summarize now",
|
||||
Persist: func(_ context.Context, _ CompactionResult) error {
|
||||
persistCompactionCalls++
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return compactedMessages, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, persistCompactionCalls, 1)
|
||||
// Re-entry happened: stream was called at least twice.
|
||||
require.Equal(t, 2, streamCallCount)
|
||||
// The re-entry prompt must contain the user summary.
|
||||
require.NotEmpty(t, reEntryPrompt)
|
||||
hasUser := false
|
||||
for _, msg := range reEntryPrompt {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
hasUser = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package chatprompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -18,156 +16,12 @@ import (
|
||||
|
||||
var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
|
||||
|
||||
// FileData holds resolved file content for LLM prompt building.
|
||||
type FileData struct {
|
||||
Data []byte
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// FileResolver fetches file content by ID for LLM prompt building.
|
||||
type FileResolver func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]FileData, error)
|
||||
|
||||
// ExtractFileID parses the file_id from a serialized file content
|
||||
// block envelope. Returns uuid.Nil and an error when the block is
|
||||
// not a file-type block or has no file_id.
|
||||
func ExtractFileID(raw json.RawMessage) (uuid.UUID, error) {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
FileID string `json:"file_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &envelope); err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("unmarshal content block: %w", err)
|
||||
}
|
||||
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeFile)) {
|
||||
return uuid.Nil, xerrors.Errorf("not a file content block: %s", envelope.Type)
|
||||
}
|
||||
if envelope.Data.FileID == "" {
|
||||
return uuid.Nil, xerrors.New("no file_id")
|
||||
}
|
||||
return uuid.Parse(envelope.Data.FileID)
|
||||
}
|
||||
|
||||
// extractFileIDs scans raw message content for file_id references.
|
||||
// Returns a map of block index to file ID. Returns nil for
|
||||
// non-array content or content with no file references.
|
||||
func extractFileIDs(raw pqtype.NullRawMessage) map[int]uuid.UUID {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil
|
||||
}
|
||||
var rawBlocks []json.RawMessage
|
||||
if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
|
||||
return nil
|
||||
}
|
||||
var result map[int]uuid.UUID
|
||||
for i, block := range rawBlocks {
|
||||
fid, err := ExtractFileID(block)
|
||||
if err == nil {
|
||||
if result == nil {
|
||||
result = make(map[int]uuid.UUID)
|
||||
}
|
||||
result[i] = fid
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// patchFileContent fills in empty Data on FileContent blocks from
|
||||
// resolved file data. Blocks that already have inline data (backward
|
||||
// compat) or have no resolved data are left unchanged.
|
||||
func patchFileContent(
|
||||
content []fantasy.Content,
|
||||
fileIDs map[int]uuid.UUID,
|
||||
resolved map[uuid.UUID]FileData,
|
||||
) {
|
||||
for blockIdx, fid := range fileIDs {
|
||||
if blockIdx >= len(content) {
|
||||
continue
|
||||
}
|
||||
switch fc := content[blockIdx].(type) {
|
||||
case fantasy.FileContent:
|
||||
if len(fc.Data) > 0 {
|
||||
continue
|
||||
}
|
||||
if data, found := resolved[fid]; found {
|
||||
fc.Data = data.Data
|
||||
content[blockIdx] = fc
|
||||
}
|
||||
case *fantasy.FileContent:
|
||||
if len(fc.Data) > 0 {
|
||||
continue
|
||||
}
|
||||
if data, found := resolved[fid]; found {
|
||||
fc.Data = data.Data
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertMessages converts persisted chat messages into LLM prompt
|
||||
// messages without resolving file references from storage. Inline
|
||||
// file data is preserved when present (backward compat).
|
||||
func ConvertMessages(
|
||||
messages []database.ChatMessage,
|
||||
) ([]fantasy.Message, error) {
|
||||
return ConvertMessagesWithFiles(context.Background(), messages, nil)
|
||||
}
|
||||
|
||||
// ConvertMessagesWithFiles converts persisted chat messages into LLM
|
||||
// prompt messages, resolving file references via the provided
|
||||
// resolver. When resolver is nil, file blocks without inline data
|
||||
// are passed through as-is (same behavior as ConvertMessages).
|
||||
func ConvertMessagesWithFiles(
|
||||
ctx context.Context,
|
||||
messages []database.ChatMessage,
|
||||
resolver FileResolver,
|
||||
) ([]fantasy.Message, error) {
|
||||
// Phase 1: Pre-scan user messages for file_id references.
|
||||
var allFileIDs []uuid.UUID
|
||||
seenFileIDs := make(map[uuid.UUID]struct{})
|
||||
fileIDsByMsg := make(map[int]map[int]uuid.UUID)
|
||||
|
||||
if resolver != nil {
|
||||
for i, msg := range messages {
|
||||
visibility := msg.Visibility
|
||||
if visibility == "" {
|
||||
visibility = database.ChatMessageVisibilityBoth
|
||||
}
|
||||
if visibility != database.ChatMessageVisibilityModel &&
|
||||
visibility != database.ChatMessageVisibilityBoth {
|
||||
continue
|
||||
}
|
||||
if msg.Role != string(fantasy.MessageRoleUser) {
|
||||
continue
|
||||
}
|
||||
fids := extractFileIDs(msg.Content)
|
||||
if len(fids) > 0 {
|
||||
fileIDsByMsg[i] = fids
|
||||
for _, fid := range fids {
|
||||
if _, seen := seenFileIDs[fid]; !seen {
|
||||
seenFileIDs[fid] = struct{}{}
|
||||
allFileIDs = append(allFileIDs, fid)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Batch resolve file data.
|
||||
var resolved map[uuid.UUID]FileData
|
||||
if len(allFileIDs) > 0 {
|
||||
var err error
|
||||
resolved, err = resolver(ctx, allFileIDs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("resolve chat files: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Convert messages, patching file content as needed.
|
||||
prompt := make([]fantasy.Message, 0, len(messages))
|
||||
toolNameByCallID := make(map[string]string)
|
||||
for i, message := range messages {
|
||||
for _, message := range messages {
|
||||
visibility := message.Visibility
|
||||
if visibility == "" {
|
||||
visibility = database.ChatMessageVisibilityBoth
|
||||
@@ -197,9 +51,6 @@ func ConvertMessagesWithFiles(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fids, ok := fileIDsByMsg[i]; ok {
|
||||
patchFileContent(content, fids, resolved)
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: ToMessageParts(content),
|
||||
@@ -549,10 +400,7 @@ func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
|
||||
}
|
||||
|
||||
// MarshalContent encodes message content blocks for persistence.
|
||||
// fileIDs optionally maps block indices to chat_files IDs, which
|
||||
// are injected into the JSON envelope for file-type blocks so
|
||||
// the reference survives round-trips through storage.
|
||||
func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype.NullRawMessage, error) {
|
||||
func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
|
||||
if len(blocks) == 0 {
|
||||
return pqtype.NullRawMessage{}, nil
|
||||
}
|
||||
@@ -567,16 +415,6 @@ func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype
|
||||
err,
|
||||
)
|
||||
}
|
||||
if fid, ok := fileIDs[i]; ok {
|
||||
encoded, err = injectFileID(encoded, fid)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf(
|
||||
"inject file_id into content block %d: %w",
|
||||
i,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
encodedBlocks = append(encodedBlocks, encoded)
|
||||
}
|
||||
|
||||
@@ -587,27 +425,6 @@ func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
|
||||
}
|
||||
|
||||
// injectFileID adds a file_id field into the data sub-object of a
|
||||
// serialized content block envelope. This follows the same pattern
|
||||
// as the reasoning title injection in marshalContentBlock.
|
||||
func injectFileID(encoded json.RawMessage, fileID uuid.UUID) (json.RawMessage, error) {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
MediaType string `json:"media_type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
FileID string `json:"file_id,omitempty"`
|
||||
ProviderMetadata *json.RawMessage `json:"provider_metadata,omitempty"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(encoded, &envelope); err != nil {
|
||||
return encoded, err
|
||||
}
|
||||
envelope.Data.FileID = fileID.String()
|
||||
envelope.Data.Data = nil // Strip inline data; resolved at LLM dispatch time.
|
||||
return json.Marshal(envelope)
|
||||
}
|
||||
|
||||
// MarshalToolResult encodes a single tool result for persistence as
|
||||
// an opaque JSON blob. The stored shape is
|
||||
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
package chatprompt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
@@ -55,7 +52,7 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
|
||||
ToolName: "execute",
|
||||
Input: tc.input,
|
||||
},
|
||||
}, nil)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
toolContent, err := chatprompt.MarshalToolResult(
|
||||
@@ -92,139 +89,3 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessagesWithFiles_ResolvesFileData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fileID := uuid.New()
|
||||
fileData := []byte("fake-image-bytes")
|
||||
|
||||
// Build a user message with file_id but no inline data, as
|
||||
// would be stored after injectFileID strips the data.
|
||||
rawContent := mustJSON(t, []json.RawMessage{
|
||||
mustJSON(t, map[string]any{
|
||||
"type": "file",
|
||||
"data": map[string]any{
|
||||
"media_type": "image/png",
|
||||
"file_id": fileID.String(),
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
resolver := func(_ context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) {
|
||||
result := make(map[uuid.UUID]chatprompt.FileData)
|
||||
for _, id := range ids {
|
||||
if id == fileID {
|
||||
result[id] = chatprompt.FileData{
|
||||
Data: fileData,
|
||||
MediaType: "image/png",
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
context.Background(),
|
||||
[]database.ChatMessage{
|
||||
{
|
||||
Role: string(fantasy.MessageRoleUser),
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true},
|
||||
},
|
||||
},
|
||||
resolver,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompt, 1)
|
||||
require.Equal(t, fantasy.MessageRoleUser, prompt[0].Role)
|
||||
require.Len(t, prompt[0].Content, 1)
|
||||
|
||||
filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0])
|
||||
require.True(t, ok, "expected FilePart")
|
||||
require.Equal(t, fileData, filePart.Data)
|
||||
require.Equal(t, "image/png", filePart.MediaType)
|
||||
}
|
||||
|
||||
func TestConvertMessagesWithFiles_BackwardCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// A message with inline data and a file_id should use the
|
||||
// inline data even when the resolver returns nothing.
|
||||
fileID := uuid.New()
|
||||
inlineData := []byte("inline-image-data")
|
||||
|
||||
rawContent := mustJSON(t, []json.RawMessage{
|
||||
mustJSON(t, map[string]any{
|
||||
"type": "file",
|
||||
"data": map[string]any{
|
||||
"media_type": "image/png",
|
||||
"data": inlineData,
|
||||
"file_id": fileID.String(),
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
context.Background(),
|
||||
[]database.ChatMessage{
|
||||
{
|
||||
Role: string(fantasy.MessageRoleUser),
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true},
|
||||
},
|
||||
},
|
||||
nil, // No resolver.
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompt, 1)
|
||||
require.Len(t, prompt[0].Content, 1)
|
||||
|
||||
filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0])
|
||||
require.True(t, ok, "expected FilePart")
|
||||
require.Equal(t, inlineData, filePart.Data)
|
||||
}
|
||||
|
||||
func TestInjectFileID_StripsInlineData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fileID := uuid.New()
|
||||
imageData := []byte("raw-image-bytes")
|
||||
|
||||
// Marshal a file content block with inline data, then inject
|
||||
// a file_id. The result should have file_id but no data.
|
||||
content, err := chatprompt.MarshalContent([]fantasy.Content{
|
||||
fantasy.FileContent{
|
||||
MediaType: "image/png",
|
||||
Data: imageData,
|
||||
},
|
||||
}, map[int]uuid.UUID{0: fileID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the stored content to verify shape.
|
||||
var blocks []json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(content.RawMessage, &blocks))
|
||||
require.Len(t, blocks, 1)
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
MediaType string `json:"media_type"`
|
||||
Data *json.RawMessage `json:"data,omitempty"`
|
||||
FileID string `json:"file_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(blocks[0], &envelope))
|
||||
require.Equal(t, "file", envelope.Type)
|
||||
require.Equal(t, "image/png", envelope.Data.MediaType)
|
||||
require.Equal(t, fileID.String(), envelope.Data.FileID)
|
||||
// Data should be nil (omitted) since injectFileID strips it.
|
||||
require.Nil(t, envelope.Data.Data, "inline data should be stripped")
|
||||
}
|
||||
|
||||
func mustJSON(t *testing.T, v any) json.RawMessage {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,12 +18,6 @@ const (
|
||||
// MaxDelay is the upper bound for the exponential backoff
|
||||
// duration. Matches the cap used in coder/mux.
|
||||
MaxDelay = 60 * time.Second
|
||||
|
||||
// MaxAttempts is the upper bound on retry attempts before
|
||||
// giving up. With a 60s max backoff this allows roughly
|
||||
// 25 minutes of retries, which is reasonable for transient
|
||||
// LLM provider issues.
|
||||
MaxAttempts = 25
|
||||
)
|
||||
|
||||
// nonRetryablePatterns are substrings that indicate a permanent error
|
||||
@@ -139,8 +131,9 @@ type RetryFn func(ctx context.Context) error
|
||||
type OnRetryFn func(attempt int, err error, delay time.Duration)
|
||||
|
||||
// Retry calls fn repeatedly until it succeeds, returns a
|
||||
// non-retryable error, ctx is canceled, or MaxAttempts is reached.
|
||||
// Retries use exponential backoff capped at MaxDelay.
|
||||
// non-retryable error, or ctx is canceled. There is no max attempt
|
||||
// limit — retries continue indefinitely with exponential backoff
|
||||
// (capped at 60s), matching the behavior of coder/mux.
|
||||
//
|
||||
// The onRetry callback (if non-nil) is called before each retry
|
||||
// attempt, giving the caller a chance to reset state, log, or
|
||||
@@ -163,15 +156,10 @@ func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
attempt++
|
||||
if attempt >= MaxAttempts {
|
||||
return xerrors.Errorf("max retry attempts (%d) exceeded: %w", MaxAttempts, err)
|
||||
}
|
||||
|
||||
delay := Delay(attempt - 1)
|
||||
delay := Delay(attempt)
|
||||
|
||||
if onRetry != nil {
|
||||
onRetry(attempt, err, delay)
|
||||
onRetry(attempt+1, err, delay)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(delay)
|
||||
@@ -181,5 +169,7 @@ func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
)
|
||||
|
||||
// OpenAIHandler handles OpenAI API requests and returns a response.
|
||||
@@ -307,17 +306,6 @@ func writeChatCompletionsStreaming(w http.ResponseWriter, r *http.Request, chunk
|
||||
}
|
||||
}
|
||||
|
||||
// writeSSEEvent marshals v as JSON and writes it as an SSE data
|
||||
// frame. Returns any write error.
|
||||
func writeSSEEvent(w http.ResponseWriter, v interface{}) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
return err
|
||||
}
|
||||
|
||||
func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
@@ -341,23 +329,7 @@ func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <
|
||||
return
|
||||
case chunk, ok = <-chunks:
|
||||
if !ok {
|
||||
// Emit Responses API lifecycle events so
|
||||
// the fantasy client closes open text
|
||||
// blocks and persists the step content.
|
||||
for outputIndex, itemID := range itemIDs {
|
||||
_ = writeSSEEvent(w, responses.ResponseTextDoneEvent{
|
||||
ItemID: itemID,
|
||||
OutputIndex: int64(outputIndex),
|
||||
})
|
||||
_ = writeSSEEvent(w, responses.ResponseOutputItemDoneEvent{
|
||||
OutputIndex: int64(outputIndex),
|
||||
Item: responses.ResponseOutputItemUnion{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
},
|
||||
})
|
||||
}
|
||||
_ = writeSSEEvent(w, responses.ResponseCompletedEvent{})
|
||||
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
return
|
||||
}
|
||||
@@ -372,19 +344,6 @@ func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <
|
||||
if !found {
|
||||
itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8])
|
||||
itemIDs[outputIndex] = itemID
|
||||
|
||||
// Emit response.output_item.added so the
|
||||
// fantasy client triggers TextStart.
|
||||
if err := writeSSEEvent(w, responses.ResponseOutputItemAddedEvent{
|
||||
OutputIndex: int64(outputIndex),
|
||||
Item: responses.ResponseOutputItemUnion{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
chunkData := map[string]interface{}{
|
||||
|
||||
+229
-502
@@ -1,22 +1,17 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -40,8 +35,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -251,7 +244,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, contentFileIDs, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req)
|
||||
contentBlocks, titleSource, inputError := createChatInputFromRequest(req)
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError)
|
||||
return
|
||||
@@ -286,7 +279,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ModelConfigID: modelConfigID,
|
||||
SystemPrompt: defaultChatSystemPrompt(),
|
||||
InitialUserContent: contentBlocks,
|
||||
ContentFileIDs: contentFileIDs,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsForeignKeyViolation(
|
||||
@@ -421,162 +413,6 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Watch git changes for a chat.
|
||||
// @ID watch-chat-git
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Param chat path string true "Chat ID" format(uuid)
|
||||
// @Success 101
|
||||
// @Router /chats/{chat}/git/watch [get]
|
||||
//
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) watchChatGit(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
chat = httpmw.ChatParam(r)
|
||||
logger = api.Logger.Named("chat_git_watcher").With(slog.F("chat_id", chat.ID))
|
||||
)
|
||||
|
||||
if !chat.WorkspaceID.Valid {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat has no workspace to watch.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
agents, err := api.Database.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, chat.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace agents.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat workspace has no agents.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
apiAgent, err := db2sdk.WorkspaceAgent(
|
||||
api.DERPMap(),
|
||||
*api.TailnetCoordinator.Load(),
|
||||
agents[0],
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
api.AgentInactiveDisconnectTimeout,
|
||||
api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
|
||||
)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error reading workspace agent.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer dialCancel()
|
||||
|
||||
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agents[0].ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error dialing workspace agent.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
|
||||
agentStream, err := agentConn.WatchGit(ctx, logger, chat.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error watching agent's git state.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer agentStream.Close(websocket.StatusGoingAway)
|
||||
|
||||
clientConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionNoContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
clientStream := wsjson.NewStream[
|
||||
codersdk.WorkspaceAgentGitClientMessage,
|
||||
codersdk.WorkspaceAgentGitServerMessage,
|
||||
](clientConn, websocket.MessageText, websocket.MessageText, logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, clientConn)
|
||||
|
||||
// Proxy agent → client.
|
||||
agentCh := agentStream.Chan()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-api.ctx.Done():
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-agentCh:
|
||||
if !ok {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
if err := clientStream.Send(msg); err != nil {
|
||||
logger.Debug(ctx, "failed to forward agent message to client", slog.Error(err))
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Proxy client → agent.
|
||||
clientCh := clientStream.Chan()
|
||||
proxyLoop:
|
||||
for {
|
||||
select {
|
||||
case <-api.ctx.Done():
|
||||
break proxyLoop
|
||||
case <-ctx.Done():
|
||||
break proxyLoop
|
||||
case msg, ok := <-clientCh:
|
||||
if !ok {
|
||||
break proxyLoop
|
||||
}
|
||||
if err := agentStream.Send(msg); err != nil {
|
||||
logger.Debug(ctx, "failed to forward client message to agent", slog.Error(err))
|
||||
break proxyLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
_ = clientStream.Close(websocket.StatusGoingAway)
|
||||
}
|
||||
|
||||
// @Summary Archive a chat
|
||||
// @ID archive-chat
|
||||
// @Tags Chats
|
||||
@@ -593,15 +429,7 @@ func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify
|
||||
// active subscribers. Fall back to direct DB for the
|
||||
// simple archive flag — no streaming state is involved.
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
err := api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to archive chat.",
|
||||
@@ -629,15 +457,7 @@ func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify
|
||||
// active subscribers. Fall back to direct DB for the
|
||||
// simple unarchive flag — no streaming state is involved.
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat.ID)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
err := api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to unarchive chat.",
|
||||
@@ -668,7 +488,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content")
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: inputError.Message,
|
||||
@@ -680,11 +500,10 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
sendResult, sendErr := api.chatDaemon.SendMessage(
|
||||
ctx,
|
||||
chatd.SendMessageOptions{
|
||||
ChatID: chatID,
|
||||
Content: contentBlocks,
|
||||
ContentFileIDs: contentFileIDs,
|
||||
ModelConfigID: req.ModelConfigID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
ChatID: chatID,
|
||||
Content: contentBlocks,
|
||||
ModelConfigID: req.ModelConfigID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
},
|
||||
)
|
||||
if sendErr != nil {
|
||||
@@ -743,7 +562,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content")
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: inputError.Message,
|
||||
@@ -756,7 +575,6 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: messageID,
|
||||
Content: contentBlocks,
|
||||
ContentFileIDs: contentFileIDs,
|
||||
})
|
||||
if editErr != nil {
|
||||
switch {
|
||||
@@ -871,6 +689,18 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
var afterMessageID int64
|
||||
if v := r.URL.Query().Get("after_id"); v != "" {
|
||||
var err error
|
||||
@@ -884,31 +714,14 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
_ = sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
})
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
defer cancel()
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
@@ -1001,13 +814,9 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
if updateErr != nil {
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
})
|
||||
return
|
||||
} else {
|
||||
chat = updatedChat
|
||||
}
|
||||
chat = updatedChat
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertChat(chat, nil))
|
||||
@@ -1052,6 +861,198 @@ func (api *API) getChatDiffContents(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, diff)
|
||||
}
|
||||
|
||||
// @Summary Get file content from a chat's linked GitHub repository
|
||||
// @ID get-chat-diff-file-content
|
||||
// @Security CoderSessionToken
|
||||
// @Produce application/octet-stream
|
||||
// @Tags Chats
|
||||
// @Param chat path string true "Chat ID" format(uuid)
|
||||
// @Param path query string true "Repo-relative file path"
|
||||
// @Param ref query string true "Git ref (SHA or branch name)"
|
||||
// @Success 200
|
||||
// @Router /chats/{chat}/diff/file-content [get]
|
||||
//
|
||||
// getChatDiffFileContent proxies a single file's raw content from
|
||||
// the chat's linked GitHub repository. The frontend uses this to
|
||||
// render image diffs for binary files that cannot be shown as text.
|
||||
// It resolves the repository owner/name from the chat's cached diff
|
||||
// reference and streams the raw bytes back from the GitHub Contents
|
||||
// API with a 10 MiB body limit.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) getChatDiffFileContent(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
filePath := r.URL.Query().Get("path")
|
||||
if filePath == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing required query parameter: path.",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Reject absolute paths and path traversal attempts. The path
|
||||
// must be a clean relative path within the repository tree.
|
||||
if strings.HasPrefix(filePath, "/") {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Query parameter 'path' must be a relative file path.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if strings.Contains(filePath, "..") {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Query parameter 'path' must not contain path traversal segments.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ref := r.URL.Query().Get("ref")
|
||||
if ref == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing required query parameter: ref.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
owner, repo, token, err := api.resolveGitHubRepoForChat(ctx, chat)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to resolve repository for chat.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if owner == "" || repo == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Chat does not have a GitHub repository reference.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
contentType, body, err := api.fetchGitHubFileContent(ctx, owner, repo, filePath, ref, token)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadGateway, codersdk.Response{
|
||||
Message: "Failed to fetch file content from GitHub.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", contentType)
|
||||
rw.Header().Set("Cache-Control", "private, max-age=300")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write(body)
|
||||
}
|
||||
|
||||
// resolveGitHubRepoForChat resolves the GitHub owner, repo, and
|
||||
// access token for a chat by looking up its cached diff reference.
|
||||
// Returns empty owner/repo when the chat is not linked to a GitHub
|
||||
// repository.
|
||||
func (api *API) resolveGitHubRepoForChat(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
) (owner, repo, token string, err error) {
|
||||
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return "", "", "", xerrors.Errorf("get cached diff status: %w", err)
|
||||
}
|
||||
|
||||
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
|
||||
if err != nil {
|
||||
return "", "", "", xerrors.Errorf("resolve diff reference: %w", err)
|
||||
}
|
||||
|
||||
if reference.RepositoryRef == nil ||
|
||||
!strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
|
||||
return "", "", "", nil
|
||||
}
|
||||
|
||||
token = api.resolveChatGitHubAccessToken(ctx, chat.OwnerID)
|
||||
|
||||
owner = reference.RepositoryRef.Owner
|
||||
repo = reference.RepositoryRef.Repo
|
||||
if owner == "" || repo == "" {
|
||||
if reference.PullRequestURL != "" {
|
||||
prRef, ok := parseGitHubPullRequestURL(reference.PullRequestURL)
|
||||
if ok {
|
||||
owner = prRef.Owner
|
||||
repo = prRef.Repo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return owner, repo, token, nil
|
||||
}
|
||||
|
||||
// fetchGitHubFileContent fetches raw file content from the GitHub
|
||||
// Contents API. Each path segment is individually URL-escaped to
|
||||
// prevent path traversal. The response body is limited to 10 MiB.
|
||||
func (api *API) fetchGitHubFileContent(
|
||||
ctx context.Context,
|
||||
owner, repo, filePath, ref, token string,
|
||||
) (contentType string, body []byte, err error) {
|
||||
// Escape each path segment individually so directory separators
|
||||
// are preserved but special characters cannot break out.
|
||||
segments := strings.Split(filePath, "/")
|
||||
for i, seg := range segments {
|
||||
segments[i] = url.PathEscape(seg)
|
||||
}
|
||||
safePath := strings.Join(segments, "/")
|
||||
|
||||
requestURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/contents/%s?ref=%s",
|
||||
githubAPIBaseURL, url.PathEscape(owner), url.PathEscape(repo), safePath, url.QueryEscape(ref),
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("create github file content request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.raw+json")
|
||||
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
||||
req.Header.Set("User-Agent", "coder-chat-diff")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
httpClient := api.HTTPClient
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("execute github file content request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
||||
if readErr != nil {
|
||||
return "", nil, xerrors.Errorf(
|
||||
"github file content request failed with status %d",
|
||||
resp.StatusCode,
|
||||
)
|
||||
}
|
||||
return "", nil, xerrors.Errorf(
|
||||
"github file content request failed with status %d: %s",
|
||||
resp.StatusCode,
|
||||
strings.TrimSpace(string(errBody)),
|
||||
)
|
||||
}
|
||||
|
||||
contentType = resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
const maxFileSize = 10 << 20 // 10 MiB
|
||||
body, err = io.ReadAll(io.LimitReader(resp.Body, maxFileSize))
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("read github file content response: %w", err)
|
||||
}
|
||||
|
||||
return contentType, body, nil
|
||||
}
|
||||
|
||||
// chatCreateWorkspace provides workspace creation for the chat
|
||||
// processor. RBAC authorization uses context-based checks via
|
||||
// dbauthz.As rather than fake *http.Request objects.
|
||||
@@ -2228,317 +2229,45 @@ func normalizeChatCompressionThreshold(
|
||||
return threshold, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// maxChatFileSize is the maximum size of a chat file upload (10 MB).
|
||||
maxChatFileSize = 10 << 20
|
||||
// maxChatFileName is the maximum length of an uploaded file name.
|
||||
maxChatFileName = 255
|
||||
)
|
||||
|
||||
// allowedChatFileMIMETypes lists the content types accepted for chat
|
||||
// file uploads. SVG is explicitly excluded because it can contain scripts.
|
||||
var allowedChatFileMIMETypes = map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"image/svg+xml": false, // SVG can contain scripts.
|
||||
}
|
||||
|
||||
var (
|
||||
webpMagicRIFF = []byte("RIFF")
|
||||
webpMagicWEBP = []byte("WEBP")
|
||||
)
|
||||
|
||||
// detectChatFileType detects the MIME type of the given data.
|
||||
// It extends http.DetectContentType with support for WebP, which
|
||||
// Go's standard sniffer does not recognize.
|
||||
func detectChatFileType(data []byte) string {
|
||||
if len(data) >= 12 &&
|
||||
bytes.Equal(data[0:4], webpMagicRIFF) &&
|
||||
bytes.Equal(data[8:12], webpMagicWEBP) {
|
||||
return "image/webp"
|
||||
}
|
||||
return http.DetectContentType(data)
|
||||
}
|
||||
|
||||
func defaultChatSystemPrompt() string {
|
||||
return chatd.DefaultSystemPrompt
|
||||
}
|
||||
|
||||
// @Summary Upload a chat file
|
||||
// @ID upload-chat-file
|
||||
// @Security CoderSessionToken
|
||||
// @Accept application/octet-stream
|
||||
// @Produce json
|
||||
// @Tags Chats
|
||||
// @Param Content-Type header string true "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)"
|
||||
// @Param organization query string true "Organization ID" format(uuid)
|
||||
// @Success 201 {object} codersdk.UploadChatFileResponse
|
||||
// @Failure 400 {object} codersdk.Response
|
||||
// @Failure 401 {object} codersdk.Response
|
||||
// @Failure 413 {object} codersdk.Response
|
||||
// @Failure 500 {object} codersdk.Response
|
||||
// @Router /chats/files [post]
|
||||
func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String())) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
orgIDStr := r.URL.Query().Get("organization")
|
||||
if orgIDStr == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing organization query parameter.",
|
||||
})
|
||||
return
|
||||
}
|
||||
orgID, err := uuid.Parse(orgIDStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid organization ID.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
// Strip parameters (e.g. "image/png; charset=utf-8" → "image/png")
|
||||
// so the allowlist check matches the base media type.
|
||||
if mediaType, _, err := mime.ParseMediaType(contentType); err == nil {
|
||||
contentType = mediaType
|
||||
}
|
||||
|
||||
if allowed, ok := allowedChatFileMIMETypes[contentType]; !ok || !allowed {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unsupported file type.",
|
||||
Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, maxChatFileSize)
|
||||
br := bufio.NewReader(r.Body)
|
||||
|
||||
// Peek at the leading bytes to sniff the real content type
|
||||
// before reading the entire body.
|
||||
peek, peekErr := br.Peek(512)
|
||||
if peekErr != nil && !errors.Is(peekErr, io.EOF) && !errors.Is(peekErr, bufio.ErrBufferFull) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read file from request.",
|
||||
Detail: peekErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the actual content matches a safe image type so that
|
||||
// a client cannot spoof Content-Type to serve active content.
|
||||
detected := detectChatFileType(peek)
|
||||
if allowed, ok := allowedChatFileMIMETypes[detected]; !ok || !allowed {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unsupported file type.",
|
||||
Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Read the full body now that we know the type is valid.
|
||||
data, err := io.ReadAll(br)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "File too large.",
|
||||
Detail: fmt.Sprintf("Maximum file size is %d bytes.", maxChatFileSize),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read file from request.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract filename from Content-Disposition header if provided.
|
||||
var filename string
|
||||
if cd := r.Header.Get("Content-Disposition"); cd != "" {
|
||||
if _, params, err := mime.ParseMediaType(cd); err == nil {
|
||||
filename = params["filename"]
|
||||
if len(filename) > maxChatFileName {
|
||||
// Truncate at rune boundary to avoid splitting
|
||||
// multi-byte UTF-8 characters.
|
||||
var truncated []byte
|
||||
for _, r := range filename {
|
||||
encoded := []byte(string(r))
|
||||
if len(truncated)+len(encoded) > maxChatFileName {
|
||||
break
|
||||
}
|
||||
truncated = append(truncated, encoded...)
|
||||
}
|
||||
filename = string(truncated)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chatFile, err := api.Database.InsertChatFile(ctx, database.InsertChatFileParams{
|
||||
OwnerID: apiKey.UserID,
|
||||
OrganizationID: orgID,
|
||||
Name: filename,
|
||||
Mimetype: detected,
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to save chat file.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.UploadChatFileResponse{
|
||||
ID: chatFile.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get a chat file
|
||||
// @ID get-chat-file
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Param file path string true "File ID" format(uuid)
|
||||
// @Success 200
|
||||
// @Failure 400 {object} codersdk.Response
|
||||
// @Failure 401 {object} codersdk.Response
|
||||
// @Failure 404 {object} codersdk.Response
|
||||
// @Failure 500 {object} codersdk.Response
|
||||
// @Router /chats/files/{file} [get]
|
||||
func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
fileIDStr := chi.URLParam(r, "file")
|
||||
fileID, err := uuid.Parse(fileIDStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file ID.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatFile, err := api.Database.GetChatFileByID(ctx, fileID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat file.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", chatFile.Mimetype)
|
||||
if chatFile.Name != "" {
|
||||
rw.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": chatFile.Name}))
|
||||
} else {
|
||||
rw.Header().Set("Content-Disposition", "inline")
|
||||
}
|
||||
rw.Header().Set("Cache-Control", "private, max-age=31536000, immutable")
|
||||
rw.Header().Set("Content-Length", strconv.Itoa(len(chatFile.Data)))
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write(chatFile.Data)
|
||||
}
|
||||
|
||||
func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) (
|
||||
func createChatInputFromRequest(req codersdk.CreateChatRequest) (
|
||||
[]fantasy.Content,
|
||||
map[int]uuid.UUID,
|
||||
string,
|
||||
*codersdk.Response,
|
||||
) {
|
||||
return createChatInputFromParts(ctx, db, req.Content, "content")
|
||||
return createChatInputFromParts(req.Content, "content")
|
||||
}
|
||||
|
||||
func createChatInputFromParts(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
parts []codersdk.ChatInputPart,
|
||||
fieldName string,
|
||||
) ([]fantasy.Content, map[int]uuid.UUID, string, *codersdk.Response) {
|
||||
) ([]fantasy.Content, string, *codersdk.Response) {
|
||||
if len(parts) == 0 {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
return nil, "", &codersdk.Response{
|
||||
Message: "Content is required.",
|
||||
Detail: "Content cannot be empty.",
|
||||
}
|
||||
}
|
||||
|
||||
content := make([]fantasy.Content, 0, len(parts))
|
||||
fileIDs := make(map[int]uuid.UUID)
|
||||
textParts := make([]string, 0, len(parts))
|
||||
for i, part := range parts {
|
||||
switch strings.ToLower(strings.TrimSpace(string(part.Type))) {
|
||||
case string(codersdk.ChatInputPartTypeText):
|
||||
text := strings.TrimSpace(part.Text)
|
||||
if text == "" {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
return nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i),
|
||||
}
|
||||
}
|
||||
content = append(content, fantasy.TextContent{Text: text})
|
||||
textParts = append(textParts, text)
|
||||
case string(codersdk.ChatInputPartTypeFile):
|
||||
if part.FileID == uuid.Nil {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i),
|
||||
}
|
||||
}
|
||||
// Validate that the file exists and get its media type.
|
||||
// File data is not loaded here; it's resolved at LLM
|
||||
// dispatch time via chatFileResolver.
|
||||
chatFile, err := db.GetChatFileByID(ctx, part.FileID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i),
|
||||
}
|
||||
}
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Internal error.",
|
||||
Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i),
|
||||
}
|
||||
}
|
||||
content = append(content, fantasy.FileContent{
|
||||
MediaType: chatFile.Mimetype,
|
||||
})
|
||||
fileIDs[len(content)-1] = part.FileID
|
||||
case string(codersdk.ChatInputPartTypeFileReference):
|
||||
if part.FileName == "" {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_name cannot be empty for file-reference.", fieldName, i),
|
||||
}
|
||||
}
|
||||
lineRange := fmt.Sprintf("%d", part.StartLine)
|
||||
if part.StartLine != part.EndLine {
|
||||
lineRange = fmt.Sprintf("%d-%d", part.StartLine, part.EndLine)
|
||||
}
|
||||
var sb strings.Builder
|
||||
_, _ = fmt.Fprintf(&sb, "[file-reference] %s:%s", part.FileName, lineRange)
|
||||
if strings.TrimSpace(part.Content) != "" {
|
||||
_, _ = fmt.Fprintf(&sb, "\n```%s\n%s\n```", part.FileName, strings.TrimSpace(part.Content))
|
||||
}
|
||||
text := sb.String()
|
||||
content = append(content, fantasy.TextContent{Text: text})
|
||||
textParts = append(textParts, text)
|
||||
default:
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
return nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf(
|
||||
"%s[%d].type %q is not supported.",
|
||||
@@ -2550,16 +2279,14 @@ func createChatInputFromParts(
|
||||
}
|
||||
}
|
||||
|
||||
// Allow file-only messages. The titleSource may be empty
|
||||
// when only file parts are provided, callers handle this.
|
||||
if len(content) == 0 {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
||||
if titleSource == "" {
|
||||
return nil, "", &codersdk.Response{
|
||||
Message: "Content is required.",
|
||||
Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName),
|
||||
Detail: "Content must include at least one text part.",
|
||||
}
|
||||
}
|
||||
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
||||
return content, fileIDs, titleSource, nil
|
||||
return content, titleSource, nil
|
||||
}
|
||||
|
||||
func chatTitleFromMessage(message string) string {
|
||||
|
||||
+268
-883
File diff suppressed because it is too large
Load Diff
+8
-30
@@ -99,7 +99,6 @@ import (
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/coder/v2/site"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/derpmetrics"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -662,7 +661,6 @@ func New(options *Options) *API {
|
||||
api.SiteHandler, err = site.New(&site.Options{
|
||||
CacheDir: siteCacheDir,
|
||||
Database: options.Database,
|
||||
Authorizer: options.Authorizer,
|
||||
SiteFS: site.FS(),
|
||||
OAuth2Configs: oauthConfigs,
|
||||
DocsURL: options.DeploymentValues.DocsURL.String(),
|
||||
@@ -901,18 +899,17 @@ func New(options *Options) *API {
|
||||
apiRateLimiter := httpmw.RateLimit(options.APIRateLimit, time.Minute)
|
||||
|
||||
// Register DERP on expvar HTTP handler, which we serve below in the router, c.f. expvar.Handler()
|
||||
// These are the metrics the DERP server exposes.
|
||||
// TODO: export via prometheus
|
||||
expDERPOnce.Do(func() {
|
||||
// We need to do this via a global Once because expvar registry is global and panics if we
|
||||
// register multiple times. In production there is only one Coderd and one DERP server per
|
||||
// process, but in testing, we create multiple of both, so the Once protects us from
|
||||
// panicking.
|
||||
if options.DERPServer != nil && expvar.Get("derp") == nil {
|
||||
if options.DERPServer != nil {
|
||||
expvar.Publish("derp", api.DERPServer.ExpVar())
|
||||
}
|
||||
})
|
||||
if options.PrometheusRegistry != nil && options.DERPServer != nil {
|
||||
options.PrometheusRegistry.MustRegister(derpmetrics.NewDERPExpvarCollector(options.DERPServer))
|
||||
}
|
||||
cors := httpmw.Cors(options.DeploymentValues.Dangerous.AllowAllCors.Value())
|
||||
prometheusMW := httpmw.Prometheus(options.PrometheusRegistry)
|
||||
|
||||
@@ -927,16 +924,6 @@ func New(options *Options) *API {
|
||||
loggermw.Logger(api.Logger),
|
||||
singleSlashMW,
|
||||
rolestore.CustomRoleMW,
|
||||
// Validate API key on every request (if present) and store
|
||||
// the result in context. The rate limiter reads this to key
|
||||
// by user ID, and downstream ExtractAPIKeyMW reuses it to
|
||||
// avoid redundant DB lookups. Never rejects requests.
|
||||
httpmw.PrecheckAPIKey(httpmw.ValidateAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.Sessions.DisableExpiryRefresh.Value(),
|
||||
Logger: options.Logger,
|
||||
}),
|
||||
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
|
||||
prometheusMW,
|
||||
// Build-Version is helpful for debugging.
|
||||
@@ -1085,6 +1072,8 @@ func New(options *Options) *API {
|
||||
|
||||
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
|
||||
r.Use(
|
||||
// Specific routes can specify different limits, but every rate
|
||||
// limit must be configurable by the admin.
|
||||
apiRateLimiter,
|
||||
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
|
||||
)
|
||||
@@ -1122,11 +1111,6 @@ func New(options *Options) *API {
|
||||
r.Post("/", api.postChats)
|
||||
r.Get("/models", api.listChatModels)
|
||||
r.Get("/watch", api.watchChats)
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
|
||||
r.Post("/", api.postChatFile)
|
||||
r.Get("/{file}", api.chatFileByID)
|
||||
})
|
||||
r.Route("/providers", func(r chi.Router) {
|
||||
r.Get("/", api.listChatProviders)
|
||||
r.Post("/", api.createChatProvider)
|
||||
@@ -1146,7 +1130,6 @@ func New(options *Options) *API {
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
r.Get("/git/watch", api.watchChatGit)
|
||||
r.Post("/archive", api.archiveChat)
|
||||
r.Post("/unarchive", api.unarchiveChat)
|
||||
r.Post("/messages", api.postChatMessages)
|
||||
@@ -1155,6 +1138,7 @@ func New(options *Options) *API {
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Get("/diff-status", api.getChatDiffStatus)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Get("/diff/file-content", api.getChatDiffFileContent)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
r.Delete("/", api.deleteChatQueuedMessage)
|
||||
r.Post("/promote", api.promoteChatQueuedMessage)
|
||||
@@ -1177,6 +1161,8 @@ func New(options *Options) *API {
|
||||
|
||||
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
|
||||
r.Use(
|
||||
// Specific routes can specify different limits, but every rate
|
||||
// limit must be configurable by the admin.
|
||||
apiRateLimiter,
|
||||
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
|
||||
)
|
||||
@@ -1854,14 +1840,6 @@ func New(options *Options) *API {
|
||||
"parsing additional CSP headers", slog.Error(cspParseErrors))
|
||||
}
|
||||
|
||||
// Add blob: to img-src for chat file attachment previews when
|
||||
// the agents experiment is enabled.
|
||||
if api.Experiments.Enabled(codersdk.ExperimentAgents) {
|
||||
additionalCSPHeaders[httpmw.CSPDirectiveImgSrc] = append(
|
||||
additionalCSPHeaders[httpmw.CSPDirectiveImgSrc], "blob:",
|
||||
)
|
||||
}
|
||||
|
||||
// Add CSP headers to all static assets and pages. CSP headers only affect
|
||||
// browsers, so these don't make sense on api routes.
|
||||
cspMW := httpmw.CSPHeaders(
|
||||
|
||||
@@ -390,117 +390,3 @@ func TestCSRFExempt(t *testing.T) {
|
||||
require.NotContains(t, string(data), "CSRF")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDERPMetrics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, api := coderdtest.NewWithAPI(t, nil)
|
||||
|
||||
require.NotNil(t, api.Options.DERPServer, "DERP server should be configured")
|
||||
require.NotNil(t, api.Options.PrometheusRegistry, "Prometheus registry should be configured")
|
||||
|
||||
// The registry is created internally by coderd. Gather from it
|
||||
// to verify DERP metrics were registered during startup.
|
||||
metrics, err := api.Options.PrometheusRegistry.Gather()
|
||||
require.NoError(t, err)
|
||||
|
||||
names := make(map[string]struct{})
|
||||
for _, m := range metrics {
|
||||
names[m.GetName()] = struct{}{}
|
||||
}
|
||||
|
||||
assert.Contains(t, names, "coder_derp_server_connections",
|
||||
"expected coder_derp_server_connections to be registered")
|
||||
assert.Contains(t, names, "coder_derp_server_bytes_received_total",
|
||||
"expected coder_derp_server_bytes_received_total to be registered")
|
||||
assert.Contains(t, names, "coder_derp_server_packets_dropped_reason_total",
|
||||
"expected coder_derp_server_packets_dropped_reason_total to be registered")
|
||||
}
|
||||
|
||||
// TestRateLimitByUser verifies that rate limiting keys by user ID when
|
||||
// an authenticated session is present, rather than falling back to IP.
|
||||
// This is a regression test for https://github.com/coder/coder/issues/20857
|
||||
func TestRateLimitByUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rateLimit = 5
|
||||
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{
|
||||
APIRateLimit: rateLimit,
|
||||
})
|
||||
firstUser := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
|
||||
t.Run("HitsLimit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Make rateLimit requests — they should all succeed.
|
||||
for i := 0; i < rateLimit; i++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode,
|
||||
"request %d should succeed", i+1)
|
||||
}
|
||||
|
||||
// The next request should be rate-limited.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode,
|
||||
"request should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("BypassOwner", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Owner with bypass header should not be rate-limited.
|
||||
for i := 0; i < rateLimit+5; i++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode,
|
||||
"owner bypass request %d should succeed", i+1)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MemberCannotBypass", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// A member requesting the bypass header should be rejected
|
||||
// with 428 Precondition Required — only owners may bypass.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
memberClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, memberClient.SessionToken())
|
||||
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||
|
||||
resp, err := memberClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode,
|
||||
"member should not be able to bypass rate limit")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,9 +17,9 @@ const (
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
)
|
||||
|
||||
@@ -1156,7 +1156,9 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe
|
||||
}
|
||||
|
||||
var rawBlocks []json.RawMessage
|
||||
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
|
||||
if role == string(fantasy.MessageRoleAssistant) {
|
||||
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
|
||||
}
|
||||
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(content))
|
||||
for i, block := range content {
|
||||
@@ -1164,20 +1166,10 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe
|
||||
if part.Type == "" {
|
||||
continue
|
||||
}
|
||||
if i < len(rawBlocks) {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeReasoning:
|
||||
if part.Type == codersdk.ChatMessagePartTypeReasoning {
|
||||
part.Title = ""
|
||||
if i < len(rawBlocks) {
|
||||
part.Title = reasoningStoredTitle(rawBlocks[i])
|
||||
case codersdk.ChatMessagePartTypeFile:
|
||||
if fid, err := chatprompt.ExtractFileID(rawBlocks[i]); err == nil {
|
||||
part.FileID = uuid.NullUUID{UUID: fid, Valid: true}
|
||||
}
|
||||
// When a file_id is present, omit inline data
|
||||
// from the response. Clients fetch content via
|
||||
// the GET /chats/files/{id} endpoint instead.
|
||||
if part.FileID.Valid {
|
||||
part.Data = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
@@ -2457,30 +2457,6 @@ func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uu
|
||||
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
|
||||
file, err := q.db.GetChatFileByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatFile{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, file); err != nil {
|
||||
return database.ChatFile{}, err
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
files, err := q.db.GetChatFilesByIDs(ctx, ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, f := range files {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
// ChatMessages are authorized through their parent Chat.
|
||||
// We need to fetch the message first to get its chat_id.
|
||||
@@ -3433,7 +3409,12 @@ func (q *querier) GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (databa
|
||||
return database.TaskSnapshot{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, task.RBACObject()); err != nil {
|
||||
obj := rbac.ResourceTask.
|
||||
WithID(task.ID).
|
||||
WithOwner(task.OwnerID.String()).
|
||||
InOrg(task.OrganizationID)
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
return database.TaskSnapshot{}, err
|
||||
}
|
||||
|
||||
@@ -4515,11 +4496,6 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams)
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
// Authorize create on chat resource scoped to the owner and org.
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
// Authorize create on the parent chat (using update permission).
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -6659,7 +6635,12 @@ func (q *querier) UpsertTaskSnapshot(ctx context.Context, arg database.UpsertTas
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, task.RBACObject()); err != nil {
|
||||
obj := rbac.ResourceTask.
|
||||
WithID(task.ID).
|
||||
WithOwner(task.OwnerID.String()).
|
||||
InOrg(task.OrganizationID)
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -463,16 +463,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).
|
||||
Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB})
|
||||
}))
|
||||
s.Run("GetChatFileByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
file := testutil.Fake(s.T(), faker, database.ChatFile{})
|
||||
dbm.EXPECT().GetChatFileByID(gomock.Any(), file.ID).Return(file, nil).AnyTimes()
|
||||
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(file)
|
||||
}))
|
||||
s.Run("GetChatFilesByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
file := testutil.Fake(s.T(), faker, database.ChatFile{})
|
||||
dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes()
|
||||
check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file})
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -589,12 +579,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
|
||||
}))
|
||||
s.Run("InsertChatFile", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatFileParams{})
|
||||
file := testutil.Fake(s.T(), faker, database.InsertChatFileRow{OwnerID: arg.OwnerID, OrganizationID: arg.OrganizationID})
|
||||
dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file)
|
||||
}))
|
||||
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
|
||||
|
||||
@@ -1595,7 +1595,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
Client: seed.Client,
|
||||
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
|
||||
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
|
||||
ClientSessionID: seed.ClientSessionID,
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
|
||||
@@ -1007,22 +1007,6 @@ func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, cha
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFileByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatFileByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFilesByIDs(ctx, ids)
|
||||
m.queryLatencies.WithLabelValues("GetChatFilesByIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFilesByIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageByID(ctx, id)
|
||||
@@ -2959,14 +2943,6 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatFile(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatFile").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatFile").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatMessage(ctx, arg)
|
||||
|
||||
@@ -1837,36 +1837,6 @@ func (mr *MockStoreMockRecorder) GetChatDiffStatusesByChatIDs(ctx, chatIds any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusesByChatIDs", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusesByChatIDs), ctx, chatIds)
|
||||
}
|
||||
|
||||
// GetChatFileByID mocks base method.
|
||||
func (m *MockStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatFileByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatFile)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatFileByID indicates an expected call of GetChatFileByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatFilesByIDs mocks base method.
|
||||
func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatFilesByIDs", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.ChatFile)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatFilesByIDs indicates an expected call of GetChatFilesByIDs.
|
||||
func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetChatMessageByID mocks base method.
|
||||
func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5541,21 +5511,6 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatFile mocks base method.
|
||||
func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatFile", ctx, arg)
|
||||
ret0, _ := ret[0].(database.InsertChatFileRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatFile indicates an expected call of InsertChatFile.
|
||||
func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatMessage mocks base method.
|
||||
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+27
-58
@@ -1046,8 +1046,7 @@ CREATE TABLE aibridge_interceptions (
|
||||
api_key_id text,
|
||||
client character varying(64) DEFAULT 'Unknown'::character varying,
|
||||
thread_parent_id uuid,
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256)
|
||||
thread_root_id uuid
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1058,8 +1057,6 @@ COMMENT ON COLUMN aibridge_interceptions.thread_parent_id IS 'The interception w
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interception of the thread that this interception belongs to.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
|
||||
|
||||
CREATE TABLE aibridge_token_usages (
|
||||
id uuid NOT NULL,
|
||||
interception_id uuid NOT NULL,
|
||||
@@ -1190,16 +1187,6 @@ CREATE TABLE chat_diff_statuses (
|
||||
git_remote_origin text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_files (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
organization_id uuid NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
name text DEFAULT ''::text NOT NULL,
|
||||
mimetype text NOT NULL,
|
||||
data bytea NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
id bigint NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
@@ -2107,31 +2094,6 @@ CREATE TABLE workspace_builds (
|
||||
CONSTRAINT workspace_builds_deadline_below_max_deadline CHECK ((((deadline <> '0001-01-01 00:00:00+00'::timestamp with time zone) AND (deadline <= max_deadline)) OR (max_deadline = '0001-01-01 00:00:00+00'::timestamp with time zone)))
|
||||
);
|
||||
|
||||
CREATE TABLE workspaces (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
organization_id uuid NOT NULL,
|
||||
template_id uuid NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
autostart_schedule text,
|
||||
ttl bigint,
|
||||
last_used_at timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
dormant_at timestamp with time zone,
|
||||
deleting_at timestamp with time zone,
|
||||
automatic_updates automatic_updates DEFAULT 'never'::automatic_updates NOT NULL,
|
||||
favorite boolean DEFAULT false NOT NULL,
|
||||
next_start_at timestamp with time zone,
|
||||
group_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
user_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
CONSTRAINT group_acl_is_object CHECK ((jsonb_typeof(group_acl) = 'object'::text)),
|
||||
CONSTRAINT user_acl_is_object CHECK ((jsonb_typeof(user_acl) = 'object'::text))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN workspaces.favorite IS 'Favorite is true if the workspace owner has favorited the workspace.';
|
||||
|
||||
CREATE VIEW tasks_with_status AS
|
||||
SELECT tasks.id,
|
||||
tasks.organization_id,
|
||||
@@ -2144,8 +2106,6 @@ CREATE VIEW tasks_with_status AS
|
||||
tasks.created_at,
|
||||
tasks.deleted_at,
|
||||
tasks.display_name,
|
||||
COALESCE(workspaces.group_acl, '{}'::jsonb) AS workspace_group_acl,
|
||||
COALESCE(workspaces.user_acl, '{}'::jsonb) AS workspace_user_acl,
|
||||
CASE
|
||||
WHEN (tasks.workspace_id IS NULL) THEN 'pending'::task_status
|
||||
WHEN (build_status.status <> 'active'::task_status) THEN build_status.status
|
||||
@@ -2161,8 +2121,7 @@ CREATE VIEW tasks_with_status AS
|
||||
task_owner.owner_username,
|
||||
task_owner.owner_name,
|
||||
task_owner.owner_avatar_url
|
||||
FROM (((((((((tasks
|
||||
LEFT JOIN workspaces ON ((workspaces.id = tasks.workspace_id)))
|
||||
FROM ((((((((tasks
|
||||
CROSS JOIN LATERAL ( SELECT vu.username AS owner_username,
|
||||
vu.name AS owner_name,
|
||||
vu.avatar_url AS owner_avatar_url
|
||||
@@ -2905,6 +2864,31 @@ CREATE VIEW workspace_build_with_user AS
|
||||
|
||||
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
|
||||
|
||||
CREATE TABLE workspaces (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
organization_id uuid NOT NULL,
|
||||
template_id uuid NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
autostart_schedule text,
|
||||
ttl bigint,
|
||||
last_used_at timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
dormant_at timestamp with time zone,
|
||||
deleting_at timestamp with time zone,
|
||||
automatic_updates automatic_updates DEFAULT 'never'::automatic_updates NOT NULL,
|
||||
favorite boolean DEFAULT false NOT NULL,
|
||||
next_start_at timestamp with time zone,
|
||||
group_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
user_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
CONSTRAINT group_acl_is_object CHECK ((jsonb_typeof(group_acl) = 'object'::text)),
|
||||
CONSTRAINT user_acl_is_object CHECK ((jsonb_typeof(user_acl) = 'object'::text))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN workspaces.favorite IS 'Favorite is true if the workspace owner has favorited the workspace.';
|
||||
|
||||
CREATE VIEW workspace_latest_builds AS
|
||||
SELECT latest_build.id,
|
||||
latest_build.workspace_id,
|
||||
@@ -3150,9 +3134,6 @@ ALTER TABLE ONLY boundary_usage_stats
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3466,8 +3447,6 @@ CREATE INDEX idx_agent_stats_user_id ON workspace_agent_stats USING btree (user_
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_client ON aibridge_interceptions USING btree (client);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_client_session_id ON aibridge_interceptions USING btree (client_session_id) WHERE (client_session_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_initiator_id ON aibridge_interceptions USING btree (initiator_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING btree (model);
|
||||
@@ -3508,10 +3487,6 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id);
|
||||
|
||||
CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at);
|
||||
@@ -3791,12 +3766,6 @@ ALTER TABLE ONLY api_keys
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ const (
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
|
||||
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
|
||||
+16
-22
@@ -22,12 +22,8 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
|
||||
# The logic below depends on the exact version being correct :(
|
||||
sqlc generate
|
||||
|
||||
# Work directory for formatting before atomic replacement of
|
||||
# generated files, ensuring the source tree is never left in a
|
||||
# partially written state.
|
||||
mkdir -p ../../_gen
|
||||
workdir=$(mktemp -d ../../_gen/.dbgen.XXXXXX)
|
||||
trap 'rm -rf "$workdir"' EXIT
|
||||
tmpfile=$(mktemp "${TMPDIR:-/tmp}/queries.sql.go.XXXXXX")
|
||||
trap 'rm -f "$tmpfile"' EXIT
|
||||
|
||||
first=true
|
||||
files=$(find ./queries/ -type f -name "*.sql.go" | LC_ALL=C sort)
|
||||
@@ -42,34 +38,32 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
|
||||
|
||||
# Copy the header from the first file only, ignoring the source comment.
|
||||
if $first; then
|
||||
head -n 6 <"$fi" | grep -v "source" >"$workdir/queries.sql.go"
|
||||
head -n 6 <"$fi" | grep -v "source" >"$tmpfile"
|
||||
first=false
|
||||
fi
|
||||
|
||||
# Append the file past the imports section into queries.sql.go.
|
||||
tail -n "+$cut" <"$fi" >>"$workdir/queries.sql.go"
|
||||
tail -n "+$cut" <"$fi" >>"$tmpfile"
|
||||
done
|
||||
|
||||
# Move sqlc outputs into workdir for formatting.
|
||||
mv queries/querier.go "$workdir/querier.go"
|
||||
mv queries/models.go "$workdir/models.go"
|
||||
# Atomically replace the target file.
|
||||
mv "$tmpfile" queries.sql.go
|
||||
|
||||
# Move the files we want.
|
||||
mv queries/querier.go .
|
||||
mv queries/models.go .
|
||||
|
||||
# Remove temporary go files.
|
||||
rm -f queries/*.go
|
||||
|
||||
# Fix struct/interface names in the workdir (not the source tree).
|
||||
gofmt -w -r 'Querier -> sqlcQuerier' -- "$workdir"/*.go
|
||||
gofmt -w -r 'Queries -> sqlQuerier' -- "$workdir"/*.go
|
||||
# Fix struct/interface names.
|
||||
gofmt -w -r 'Querier -> sqlcQuerier' -- *.go
|
||||
gofmt -w -r 'Queries -> sqlQuerier' -- *.go
|
||||
|
||||
# Ensure correct imports exist. Modules must all be downloaded so we
|
||||
# get correct suggestions.
|
||||
# Ensure correct imports exist. Modules must all be downloaded so we get correct
|
||||
# suggestions.
|
||||
go mod download
|
||||
go tool golang.org/x/tools/cmd/goimports -w "$workdir/queries.sql.go"
|
||||
|
||||
# Atomically replace all three target files.
|
||||
mv "$workdir/queries.sql.go" queries.sql.go
|
||||
mv "$workdir/querier.go" querier.go
|
||||
mv "$workdir/models.go" models.go
|
||||
go tool golang.org/x/tools/cmd/goimports -w queries.sql.go
|
||||
|
||||
go run ../../scripts/dbgen
|
||||
# This will error if a view is broken. This is in it's own package to avoid
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
-- Fix task status logic: pending provisioner job should give pending task status, not initializing.
|
||||
-- A task is pending when the provisioner hasn't picked up the job yet.
|
||||
-- A task is initializing when the provisioner is actively running the job.
|
||||
DROP VIEW IF EXISTS tasks_with_status;
|
||||
|
||||
CREATE VIEW
|
||||
tasks_with_status
|
||||
AS
|
||||
SELECT
|
||||
tasks.*,
|
||||
-- Combine component statuses with precedence: build -> agent -> app.
|
||||
CASE
|
||||
WHEN tasks.workspace_id IS NULL THEN 'pending'::task_status
|
||||
WHEN build_status.status != 'active' THEN build_status.status::task_status
|
||||
WHEN agent_status.status != 'active' THEN agent_status.status::task_status
|
||||
ELSE app_status.status::task_status
|
||||
END AS status,
|
||||
-- Attach debug information for troubleshooting status.
|
||||
jsonb_build_object(
|
||||
'build', jsonb_build_object(
|
||||
'transition', latest_build_raw.transition,
|
||||
'job_status', latest_build_raw.job_status,
|
||||
'computed', build_status.status
|
||||
),
|
||||
'agent', jsonb_build_object(
|
||||
'lifecycle_state', agent_raw.lifecycle_state,
|
||||
'computed', agent_status.status
|
||||
),
|
||||
'app', jsonb_build_object(
|
||||
'health', app_raw.health,
|
||||
'computed', app_status.status
|
||||
)
|
||||
) AS status_debug,
|
||||
task_app.*,
|
||||
agent_raw.lifecycle_state AS workspace_agent_lifecycle_state,
|
||||
app_raw.health AS workspace_app_health,
|
||||
task_owner.*
|
||||
FROM
|
||||
tasks
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
vu.username AS owner_username,
|
||||
vu.name AS owner_name,
|
||||
vu.avatar_url AS owner_avatar_url
|
||||
FROM
|
||||
visible_users vu
|
||||
WHERE
|
||||
vu.id = tasks.owner_id
|
||||
) task_owner
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
task_app.workspace_build_number,
|
||||
task_app.workspace_agent_id,
|
||||
task_app.workspace_app_id
|
||||
FROM
|
||||
task_workspace_apps task_app
|
||||
WHERE
|
||||
task_id = tasks.id
|
||||
ORDER BY
|
||||
task_app.workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) task_app ON TRUE
|
||||
|
||||
-- Join the raw data for computing task status.
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_build.transition,
|
||||
provisioner_job.job_status,
|
||||
workspace_build.job_id
|
||||
FROM
|
||||
workspace_builds workspace_build
|
||||
JOIN
|
||||
provisioner_jobs provisioner_job
|
||||
ON provisioner_job.id = workspace_build.job_id
|
||||
WHERE
|
||||
workspace_build.workspace_id = tasks.workspace_id
|
||||
AND workspace_build.build_number = task_app.workspace_build_number
|
||||
) latest_build_raw ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_agent.lifecycle_state
|
||||
FROM
|
||||
workspace_agents workspace_agent
|
||||
WHERE
|
||||
workspace_agent.id = task_app.workspace_agent_id
|
||||
) agent_raw ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_app.health
|
||||
FROM
|
||||
workspace_apps workspace_app
|
||||
WHERE
|
||||
workspace_app.id = task_app.workspace_app_id
|
||||
) app_raw ON TRUE
|
||||
|
||||
-- Compute the status for each component.
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
CASE
|
||||
WHEN latest_build_raw.job_status IS NULL THEN 'pending'::task_status
|
||||
WHEN latest_build_raw.job_status IN ('failed', 'canceling', 'canceled') THEN 'error'::task_status
|
||||
WHEN
|
||||
latest_build_raw.transition IN ('stop', 'delete')
|
||||
AND latest_build_raw.job_status = 'succeeded' THEN 'paused'::task_status
|
||||
-- Job is pending (not picked up by provisioner yet).
|
||||
WHEN
|
||||
latest_build_raw.transition = 'start'
|
||||
AND latest_build_raw.job_status = 'pending' THEN 'pending'::task_status
|
||||
-- Job is running or done, defer to agent/app status.
|
||||
WHEN
|
||||
latest_build_raw.transition = 'start'
|
||||
AND latest_build_raw.job_status IN ('running', 'succeeded') THEN 'active'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status
|
||||
) build_status
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
CASE
|
||||
-- No agent or connecting.
|
||||
WHEN
|
||||
agent_raw.lifecycle_state IS NULL
|
||||
OR agent_raw.lifecycle_state IN ('created', 'starting') THEN 'initializing'::task_status
|
||||
-- Agent is running, defer to app status.
|
||||
-- NOTE(mafredri): The start_error/start_timeout states means connected, but some startup script failed.
|
||||
-- This may or may not affect the task status but this has to be caught by app health check.
|
||||
WHEN agent_raw.lifecycle_state IN ('ready', 'start_timeout', 'start_error') THEN 'active'::task_status
|
||||
-- If the agent is shutting down or turned off, this is an unknown state because we would expect a stop
|
||||
-- build to be running.
|
||||
-- This is essentially equal to: `IN ('shutting_down', 'shutdown_timeout', 'shutdown_error', 'off')`,
|
||||
-- but we cannot use them because the values were added in a migration.
|
||||
WHEN agent_raw.lifecycle_state NOT IN ('created', 'starting', 'ready', 'start_timeout', 'start_error') THEN 'unknown'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status
|
||||
) agent_status
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
CASE
|
||||
WHEN app_raw.health = 'initializing' THEN 'initializing'::task_status
|
||||
WHEN app_raw.health = 'unhealthy' THEN 'error'::task_status
|
||||
WHEN app_raw.health IN ('healthy', 'disabled') THEN 'active'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status
|
||||
) app_status
|
||||
WHERE
|
||||
tasks.deleted_at IS NULL;
|
||||
@@ -1,151 +0,0 @@
|
||||
-- Fix task status logic: pending provisioner job should give pending task status, not initializing.
|
||||
-- A task is pending when the provisioner hasn't picked up the job yet.
|
||||
-- A task is initializing when the provisioner is actively running the job.
|
||||
DROP VIEW IF EXISTS tasks_with_status;
|
||||
|
||||
CREATE VIEW
|
||||
tasks_with_status
|
||||
AS
|
||||
SELECT
|
||||
tasks.*,
|
||||
coalesce(workspaces.group_acl, '{}'::jsonb) as workspace_group_acl,
|
||||
coalesce(workspaces.user_acl, '{}'::jsonb) as workspace_user_acl,
|
||||
-- Combine component statuses with precedence: build -> agent -> app.
|
||||
CASE
|
||||
WHEN tasks.workspace_id IS NULL THEN 'pending'::task_status
|
||||
WHEN build_status.status != 'active' THEN build_status.status::task_status
|
||||
WHEN agent_status.status != 'active' THEN agent_status.status::task_status
|
||||
ELSE app_status.status::task_status
|
||||
END AS status,
|
||||
-- Attach debug information for troubleshooting status.
|
||||
jsonb_build_object(
|
||||
'build', jsonb_build_object(
|
||||
'transition', latest_build_raw.transition,
|
||||
'job_status', latest_build_raw.job_status,
|
||||
'computed', build_status.status
|
||||
),
|
||||
'agent', jsonb_build_object(
|
||||
'lifecycle_state', agent_raw.lifecycle_state,
|
||||
'computed', agent_status.status
|
||||
),
|
||||
'app', jsonb_build_object(
|
||||
'health', app_raw.health,
|
||||
'computed', app_status.status
|
||||
)
|
||||
) AS status_debug,
|
||||
task_app.*,
|
||||
agent_raw.lifecycle_state AS workspace_agent_lifecycle_state,
|
||||
app_raw.health AS workspace_app_health,
|
||||
task_owner.*
|
||||
FROM
|
||||
tasks
|
||||
|
||||
LEFT JOIN
|
||||
workspaces ON workspaces.id = tasks.workspace_id
|
||||
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
vu.username AS owner_username,
|
||||
vu.name AS owner_name,
|
||||
vu.avatar_url AS owner_avatar_url
|
||||
FROM
|
||||
visible_users vu
|
||||
WHERE
|
||||
vu.id = tasks.owner_id
|
||||
) task_owner
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
task_app.workspace_build_number,
|
||||
task_app.workspace_agent_id,
|
||||
task_app.workspace_app_id
|
||||
FROM
|
||||
task_workspace_apps task_app
|
||||
WHERE
|
||||
task_id = tasks.id
|
||||
ORDER BY
|
||||
task_app.workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) task_app ON TRUE
|
||||
|
||||
-- Join the raw data for computing task status.
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_build.transition,
|
||||
provisioner_job.job_status,
|
||||
workspace_build.job_id
|
||||
FROM
|
||||
workspace_builds workspace_build
|
||||
JOIN
|
||||
provisioner_jobs provisioner_job
|
||||
ON provisioner_job.id = workspace_build.job_id
|
||||
WHERE
|
||||
workspace_build.workspace_id = tasks.workspace_id
|
||||
AND workspace_build.build_number = task_app.workspace_build_number
|
||||
) latest_build_raw ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_agent.lifecycle_state
|
||||
FROM
|
||||
workspace_agents workspace_agent
|
||||
WHERE
|
||||
workspace_agent.id = task_app.workspace_agent_id
|
||||
) agent_raw ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_app.health
|
||||
FROM
|
||||
workspace_apps workspace_app
|
||||
WHERE
|
||||
workspace_app.id = task_app.workspace_app_id
|
||||
) app_raw ON TRUE
|
||||
|
||||
-- Compute the status for each component.
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
CASE
|
||||
WHEN latest_build_raw.job_status IS NULL THEN 'pending'::task_status
|
||||
WHEN latest_build_raw.job_status IN ('failed', 'canceling', 'canceled') THEN 'error'::task_status
|
||||
WHEN
|
||||
latest_build_raw.transition IN ('stop', 'delete')
|
||||
AND latest_build_raw.job_status = 'succeeded' THEN 'paused'::task_status
|
||||
-- Job is pending (not picked up by provisioner yet).
|
||||
WHEN
|
||||
latest_build_raw.transition = 'start'
|
||||
AND latest_build_raw.job_status = 'pending' THEN 'pending'::task_status
|
||||
-- Job is running or done, defer to agent/app status.
|
||||
WHEN
|
||||
latest_build_raw.transition = 'start'
|
||||
AND latest_build_raw.job_status IN ('running', 'succeeded') THEN 'active'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status
|
||||
) build_status
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
CASE
|
||||
-- No agent or connecting.
|
||||
WHEN
|
||||
agent_raw.lifecycle_state IS NULL
|
||||
OR agent_raw.lifecycle_state IN ('created', 'starting') THEN 'initializing'::task_status
|
||||
-- Agent is running, defer to app status.
|
||||
-- NOTE(mafredri): The start_error/start_timeout states means connected, but some startup script failed.
|
||||
-- This may or may not affect the task status but this has to be caught by app health check.
|
||||
WHEN agent_raw.lifecycle_state IN ('ready', 'start_timeout', 'start_error') THEN 'active'::task_status
|
||||
-- If the agent is shutting down or turned off, this is an unknown state because we would expect a stop
|
||||
-- build to be running.
|
||||
-- This is essentially equal to: `IN ('shutting_down', 'shutdown_timeout', 'shutdown_error', 'off')`,
|
||||
-- but we cannot use them because the values were added in a migration.
|
||||
WHEN agent_raw.lifecycle_state NOT IN ('created', 'starting', 'ready', 'start_timeout', 'start_error') THEN 'unknown'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status
|
||||
) agent_status
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
CASE
|
||||
WHEN app_raw.health = 'initializing' THEN 'initializing'::task_status
|
||||
WHEN app_raw.health = 'unhealthy' THEN 'error'::task_status
|
||||
WHEN app_raw.health IN ('healthy', 'disabled') THEN 'active'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status
|
||||
) app_status
|
||||
WHERE
|
||||
tasks.deleted_at IS NULL;
|
||||
@@ -1,4 +0,0 @@
|
||||
DROP INDEX IF EXISTS idx_aibridge_interceptions_client_session_id;
|
||||
|
||||
ALTER TABLE aibridge_interceptions
|
||||
DROP COLUMN client_session_id;
|
||||
@@ -1,7 +0,0 @@
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN client_session_id VARCHAR(256) NULL;
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_client_session_id ON aibridge_interceptions (client_session_id)
|
||||
WHERE client_session_id IS NOT NULL;
|
||||
@@ -1,2 +0,0 @@
|
||||
DROP INDEX IF EXISTS idx_chat_files_org;
|
||||
DROP TABLE IF EXISTS chat_files;
|
||||
@@ -1,12 +0,0 @@
|
||||
CREATE TABLE chat_files (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
name TEXT NOT NULL DEFAULT '',
|
||||
mimetype TEXT NOT NULL,
|
||||
data BYTEA NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_files_owner ON chat_files(owner_id);
|
||||
CREATE INDEX idx_chat_files_org ON chat_files(organization_id);
|
||||
@@ -1,13 +0,0 @@
|
||||
INSERT INTO chat_files (id, owner_id, organization_id, created_at, name, mimetype, data)
|
||||
SELECT
|
||||
'00000000-0000-0000-0000-000000000099',
|
||||
u.id,
|
||||
om.organization_id,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'test.png',
|
||||
'image/png',
|
||||
E'\\x89504E47'
|
||||
FROM users u
|
||||
JOIN organization_members om ON om.user_id = u.id
|
||||
ORDER BY u.created_at, u.id
|
||||
LIMIT 1;
|
||||
@@ -155,33 +155,20 @@ func (t Task) TaskTable() TaskTable {
|
||||
}
|
||||
|
||||
func (t Task) RBACObject() rbac.Object {
|
||||
obj := rbac.ResourceTask.
|
||||
return t.TaskTable().RBACObject()
|
||||
}
|
||||
|
||||
func (t TaskTable) RBACObject() rbac.Object {
|
||||
return rbac.ResourceTask.
|
||||
WithID(t.ID).
|
||||
WithOwner(t.OwnerID.String()).
|
||||
InOrg(t.OrganizationID)
|
||||
|
||||
if rbac.WorkspaceACLDisabled() {
|
||||
return obj
|
||||
}
|
||||
|
||||
if t.WorkspaceGroupACL != nil {
|
||||
obj = obj.WithGroupACL(t.WorkspaceGroupACL.RBACACL())
|
||||
}
|
||||
if t.WorkspaceUserACL != nil {
|
||||
obj = obj.WithACLUserList(t.WorkspaceUserACL.RBACACL())
|
||||
}
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
func (c Chat) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String())
|
||||
}
|
||||
|
||||
func (c ChatFile) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
|
||||
switch s {
|
||||
case ApiKeyScopeCoderAll:
|
||||
|
||||
@@ -815,7 +815,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.Client,
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
|
||||
@@ -3793,8 +3793,6 @@ type AIBridgeInterception struct {
|
||||
ThreadParentID uuid.NullUUID `db:"thread_parent_id" json:"thread_parent_id"`
|
||||
// The root interception of the thread that this interception belongs to.
|
||||
ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"`
|
||||
// The session ID supplied by the client (optional and not universally supported).
|
||||
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
|
||||
}
|
||||
|
||||
// Audit log of tokens used by intercepted requests in AI Bridge
|
||||
@@ -3926,16 +3924,6 @@ type ChatDiffStatus struct {
|
||||
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
|
||||
}
|
||||
|
||||
type ChatFile struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
Data []byte `db:"data" json:"data"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
@@ -4506,8 +4494,6 @@ type Task struct {
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
DeletedAt sql.NullTime `db:"deleted_at" json:"deleted_at"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
WorkspaceGroupACL WorkspaceACL `db:"workspace_group_acl" json:"workspace_group_acl"`
|
||||
WorkspaceUserACL WorkspaceACL `db:"workspace_user_acl" json:"workspace_user_acl"`
|
||||
Status TaskStatus `db:"status" json:"status"`
|
||||
StatusDebug json.RawMessage `db:"status_debug" json:"status_debug"`
|
||||
WorkspaceBuildNumber sql.NullInt32 `db:"workspace_build_number" json:"workspace_build_number"`
|
||||
|
||||
@@ -218,8 +218,6 @@ type sqlcQuerier interface {
|
||||
GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
@@ -603,7 +601,6 @@ type sqlcQuerier interface {
|
||||
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
|
||||
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
|
||||
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
|
||||
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
|
||||
InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error)
|
||||
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
|
||||
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
|
||||
|
||||
@@ -3867,37 +3867,6 @@ func TestGetProvisionerJobsByIDsWithQueuePosition(t *testing.T) {
|
||||
queueSizes: nil, // TODO(yevhenii): should it be empty array instead?
|
||||
queuePositions: nil,
|
||||
},
|
||||
// Many daemons with identical tags should produce same results as one.
|
||||
{
|
||||
name: "duplicate-daemons-same-tags",
|
||||
jobTags: []database.StringMap{
|
||||
{"a": "1"},
|
||||
{"a": "1", "b": "2"},
|
||||
},
|
||||
daemonTags: []database.StringMap{
|
||||
{"a": "1", "b": "2"},
|
||||
{"a": "1", "b": "2"},
|
||||
{"a": "1", "b": "2"},
|
||||
},
|
||||
queueSizes: []int64{2, 2},
|
||||
queuePositions: []int64{1, 2},
|
||||
},
|
||||
// Jobs that don't match any queried job's daemon should still
|
||||
// have correct queue positions.
|
||||
{
|
||||
name: "irrelevant-daemons-filtered",
|
||||
jobTags: []database.StringMap{
|
||||
{"a": "1"},
|
||||
{"x": "9"},
|
||||
},
|
||||
daemonTags: []database.StringMap{
|
||||
{"a": "1"},
|
||||
{"x": "9"},
|
||||
},
|
||||
queueSizes: []int64{1},
|
||||
queuePositions: []int64{1},
|
||||
skipJobIDs: map[int]struct{}{1: {}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -4223,51 +4192,6 @@ func TestGetProvisionerJobsByIDsWithQueuePosition_OrderValidation(t *testing.T)
|
||||
assert.EqualValues(t, []int64{1, 2, 3, 4, 5, 6}, queuePositions, "expected queue positions to be set correctly")
|
||||
}
|
||||
|
||||
func TestGetProvisionerJobsByIDsWithQueuePosition_DuplicateDaemons(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
now := dbtime.Now()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Create 3 pending jobs with the same tags.
|
||||
jobs := make([]database.ProvisionerJob, 3)
|
||||
for i := range jobs {
|
||||
jobs[i] = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
CreatedAt: now.Add(-time.Duration(3-i) * time.Minute),
|
||||
Tags: database.StringMap{"scope": "organization", "owner": ""},
|
||||
})
|
||||
}
|
||||
|
||||
// Create 50 daemons with identical tags (simulates scale).
|
||||
for i := range 50 {
|
||||
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
||||
Name: fmt.Sprintf("daemon_%d", i),
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
Tags: database.StringMap{"scope": "organization", "owner": ""},
|
||||
})
|
||||
}
|
||||
|
||||
jobIDs := make([]uuid.UUID, len(jobs))
|
||||
for i, j := range jobs {
|
||||
jobIDs[i] = j.ID
|
||||
}
|
||||
|
||||
results, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx,
|
||||
database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
||||
IDs: jobIDs,
|
||||
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 3)
|
||||
|
||||
// All daemons have identical tags, so queue should be same as
|
||||
// if there were just one daemon.
|
||||
for i, r := range results {
|
||||
assert.Equal(t, int64(3), r.QueueSize, "job %d queue size", i)
|
||||
assert.Equal(t, int64(i+1), r.QueuePosition, "job %d queue position", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupRemovalTrigger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -8489,7 +8413,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
|
||||
})
|
||||
|
||||
// Agent should still authenticate during stop build execution.
|
||||
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
||||
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent.AuthToken)
|
||||
require.NoError(t, err, "agent should authenticate during stop build execution")
|
||||
require.Equal(t, agent.ID, row.WorkspaceAgent.ID)
|
||||
require.Equal(t, startBuild.ID, row.WorkspaceBuild.ID, "should return start build, not stop build")
|
||||
@@ -8547,7 +8471,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
|
||||
})
|
||||
|
||||
// Agent should NOT authenticate after stop job completes.
|
||||
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
||||
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent.AuthToken)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows, "agent should not authenticate after stop job completes")
|
||||
})
|
||||
|
||||
@@ -8601,7 +8525,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
|
||||
})
|
||||
|
||||
// Agent should NOT authenticate (start build failed).
|
||||
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
||||
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent.AuthToken)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows, "agent from failed start build should not authenticate")
|
||||
})
|
||||
|
||||
@@ -8656,7 +8580,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
|
||||
})
|
||||
|
||||
// Agent should authenticate during pending stop build.
|
||||
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
||||
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent.AuthToken)
|
||||
require.NoError(t, err, "agent should authenticate during pending stop build")
|
||||
require.Equal(t, agent.ID, row.WorkspaceAgent.ID)
|
||||
require.Equal(t, startBuild.ID, row.WorkspaceBuild.ID, "should return start build")
|
||||
@@ -8753,13 +8677,13 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
|
||||
})
|
||||
|
||||
// Agent from build 3 should authenticate.
|
||||
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent2.AuthToken)
|
||||
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent2.AuthToken)
|
||||
require.NoError(t, err, "agent from most recent start should authenticate during stop")
|
||||
require.Equal(t, agent2.ID, row.WorkspaceAgent.ID)
|
||||
require.Equal(t, startBuild2.ID, row.WorkspaceBuild.ID)
|
||||
|
||||
// Agent from build 1 should NOT authenticate.
|
||||
_, err = db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent1.AuthToken)
|
||||
_, err = db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent1.AuthToken)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows, "agent from old cycle should not authenticate")
|
||||
})
|
||||
|
||||
@@ -8813,7 +8737,7 @@ func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *te
|
||||
})
|
||||
|
||||
// Agent from build 1 should NOT authenticate (latest is not STOP).
|
||||
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent1.AuthToken)
|
||||
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agent1.AuthToken)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows, "agent should not authenticate when latest build is not STOP")
|
||||
})
|
||||
}
|
||||
@@ -8917,202 +8841,3 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// This test exercises a complex CTE query for prompt
|
||||
// reconstruction after compaction. It requires Postgres.
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Helper: create a chat model config (required FK for chats).
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
// A chat_providers row is required as a FK for model configs.
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newChat := func(t *testing.T) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-chat-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
insertMsg := func(
|
||||
t *testing.T,
|
||||
chatID uuid.UUID,
|
||||
role string,
|
||||
vis database.ChatMessageVisibility,
|
||||
compressed bool,
|
||||
content string,
|
||||
) database.ChatMessage {
|
||||
t.Helper()
|
||||
msg, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: role,
|
||||
Visibility: vis,
|
||||
Compressed: sql.NullBool{Bool: compressed, Valid: true},
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`"` + content + `"`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return msg
|
||||
}
|
||||
|
||||
msgIDs := func(msgs []database.ChatMessage) []int64 {
|
||||
ids := make([]int64, len(msgs))
|
||||
for i, m := range msgs {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
t.Run("NoCompaction", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
sys := insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
usr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "hello")
|
||||
ast := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "hi there")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{sys.ID, usr.ID, ast.ID}, msgIDs(got))
|
||||
})
|
||||
|
||||
t.Run("UserOnlyVisibilityExcluded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
// Messages with visibility=user should NOT appear in the
|
||||
// prompt (they are only for the UI).
|
||||
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityUser, false, "user-only msg")
|
||||
usr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "hello")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
for _, m := range got {
|
||||
require.NotEqual(t, database.ChatMessageVisibilityUser, m.Visibility,
|
||||
"visibility=user messages should not appear in the prompt")
|
||||
}
|
||||
require.Contains(t, msgIDs(got), usr.ID)
|
||||
})
|
||||
|
||||
t.Run("AfterCompaction", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
// Pre-compaction conversation.
|
||||
sys := insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
preUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "old question")
|
||||
preAsst := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "old answer")
|
||||
|
||||
// Compaction messages:
|
||||
// 1. Summary (role=user, visibility=model, compressed=true).
|
||||
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "compaction summary")
|
||||
// 2. Compressed assistant tool-call (visibility=user).
|
||||
insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityUser, true, "tool call")
|
||||
// 3. Compressed tool result (visibility=both).
|
||||
insertMsg(t, chat.ID, "tool", database.ChatMessageVisibilityBoth, true, "tool result")
|
||||
|
||||
// Post-compaction messages.
|
||||
postUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "new question")
|
||||
postAsst := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "new answer")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotIDs := msgIDs(got)
|
||||
|
||||
// Must include: system prompt, summary, post-compaction.
|
||||
require.Contains(t, gotIDs, sys.ID, "system prompt must be included")
|
||||
require.Contains(t, gotIDs, summary.ID, "compaction summary must be included")
|
||||
require.Contains(t, gotIDs, postUser.ID, "post-compaction user msg must be included")
|
||||
require.Contains(t, gotIDs, postAsst.ID, "post-compaction assistant msg must be included")
|
||||
|
||||
// Must exclude: pre-compaction non-system messages.
|
||||
require.NotContains(t, gotIDs, preUser.ID, "pre-compaction user msg must be excluded")
|
||||
require.NotContains(t, gotIDs, preAsst.ID, "pre-compaction assistant msg must be excluded")
|
||||
|
||||
// Verify ordering.
|
||||
require.Equal(t, []int64{sys.ID, summary.ID, postUser.ID, postAsst.ID}, gotIDs)
|
||||
})
|
||||
|
||||
t.Run("AfterCompactionSummaryIsUserRole", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
// After compaction the summary must appear as role=user so
|
||||
// that LLM APIs (e.g. Anthropic) see at least one
|
||||
// non-system message in the prompt.
|
||||
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "summary text")
|
||||
newUsr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "new question")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
hasNonSystem := false
|
||||
for _, m := range got {
|
||||
if m.Role != "system" {
|
||||
hasNonSystem = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasNonSystem,
|
||||
"prompt must contain at least one non-system message after compaction")
|
||||
require.Contains(t, msgIDs(got), summary.ID)
|
||||
require.Contains(t, msgIDs(got), newUsr.ID)
|
||||
})
|
||||
|
||||
t.Run("CompressedToolResultNotPickedAsSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
// The CTE uses visibility='model' (exact match). If it
|
||||
// used IN ('model','both'), the compressed tool result
|
||||
// (visibility=both) would be picked as the "summary"
|
||||
// instead of the actual summary.
|
||||
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "real summary")
|
||||
compressedTool := insertMsg(t, chat.ID, "tool", database.ChatMessageVisibilityBoth, true, "tool result")
|
||||
postUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "follow-up")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotIDs := msgIDs(got)
|
||||
require.Contains(t, gotIDs, summary.ID, "real summary must be included")
|
||||
require.NotContains(t, gotIDs, compressedTool.ID,
|
||||
"compressed tool result must not be included")
|
||||
require.Contains(t, gotIDs, postUser.ID)
|
||||
})
|
||||
}
|
||||
|
||||
+21
-143
@@ -378,7 +378,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti
|
||||
|
||||
const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one
|
||||
SELECT
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
@@ -400,7 +400,6 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU
|
||||
&i.Client,
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -435,7 +434,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Cont
|
||||
|
||||
const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many
|
||||
SELECT
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
`
|
||||
@@ -461,7 +460,6 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn
|
||||
&i.Client,
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -608,11 +606,11 @@ func (q *sqlQuerier) GetAIBridgeUserPromptsByInterceptionID(ctx context.Context,
|
||||
|
||||
const insertAIBridgeInterception = `-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, thread_parent_id, thread_root_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9, $10::uuid, $11::uuid
|
||||
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9::uuid, $10::uuid
|
||||
)
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
|
||||
`
|
||||
|
||||
type InsertAIBridgeInterceptionParams struct {
|
||||
@@ -624,7 +622,6 @@ type InsertAIBridgeInterceptionParams struct {
|
||||
Metadata json.RawMessage `db:"metadata" json:"metadata"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
Client sql.NullString `db:"client" json:"client"`
|
||||
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
|
||||
ThreadParentInterceptionID uuid.NullUUID `db:"thread_parent_interception_id" json:"thread_parent_interception_id"`
|
||||
ThreadRootInterceptionID uuid.NullUUID `db:"thread_root_interception_id" json:"thread_root_interception_id"`
|
||||
}
|
||||
@@ -639,7 +636,6 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
|
||||
arg.Metadata,
|
||||
arg.StartedAt,
|
||||
arg.Client,
|
||||
arg.ClientSessionID,
|
||||
arg.ThreadParentInterceptionID,
|
||||
arg.ThreadRootInterceptionID,
|
||||
)
|
||||
@@ -656,7 +652,6 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
|
||||
&i.Client,
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -798,7 +793,7 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB
|
||||
|
||||
const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many
|
||||
SELECT
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id,
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id,
|
||||
visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
@@ -909,7 +904,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
|
||||
&i.AIBridgeInterception.Client,
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -1170,7 +1164,7 @@ UPDATE aibridge_interceptions
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
|
||||
`
|
||||
|
||||
type UpdateAIBridgeInterceptionEndedParams struct {
|
||||
@@ -1193,7 +1187,6 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up
|
||||
&i.Client,
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -2214,103 +2207,6 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou
|
||||
return new_period, err
|
||||
}
|
||||
|
||||
const getChatFileByID = `-- name: GetChatFileByID :one
|
||||
SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatFileByID, id)
|
||||
var i ChatFile
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.CreatedAt,
|
||||
&i.Name,
|
||||
&i.Mimetype,
|
||||
&i.Data,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatFilesByIDs = `-- name: GetChatFilesByIDs :many
|
||||
SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = ANY($1::uuid[])
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatFilesByIDs, pq.Array(ids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ChatFile
|
||||
for rows.Next() {
|
||||
var i ChatFile
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.CreatedAt,
|
||||
&i.Name,
|
||||
&i.Mimetype,
|
||||
&i.Data,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertChatFile = `-- name: InsertChatFile :one
|
||||
INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data)
|
||||
VALUES ($1::uuid, $2::uuid, $3::text, $4::text, $5::bytea)
|
||||
RETURNING id, owner_id, organization_id, created_at, name, mimetype
|
||||
`
|
||||
|
||||
type InsertChatFileParams struct {
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
Data []byte `db:"data" json:"data"`
|
||||
}
|
||||
|
||||
type InsertChatFileRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertChatFile,
|
||||
arg.OwnerID,
|
||||
arg.OrganizationID,
|
||||
arg.Name,
|
||||
arg.Mimetype,
|
||||
arg.Data,
|
||||
)
|
||||
var i InsertChatFileRow
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.CreatedAt,
|
||||
&i.Name,
|
||||
&i.Mimetype,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteChatModelConfigByID = `-- name: DeleteChatModelConfigByID :exec
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
@@ -3321,8 +3217,9 @@ WITH latest_compressed_summary AS (
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = $1::uuid
|
||||
AND role = 'system'
|
||||
AND visibility IN ('model', 'both')
|
||||
AND compressed = TRUE
|
||||
AND visibility = 'model'
|
||||
ORDER BY
|
||||
created_at DESC,
|
||||
id DESC
|
||||
@@ -12654,7 +12551,7 @@ const getProvisionerJobsByIDsWithQueuePosition = `-- name: GetProvisionerJobsByI
|
||||
WITH filtered_provisioner_jobs AS (
|
||||
-- Step 1: Filter provisioner_jobs
|
||||
SELECT
|
||||
id, created_at, tags
|
||||
id, created_at
|
||||
FROM
|
||||
provisioner_jobs
|
||||
WHERE
|
||||
@@ -12669,32 +12566,21 @@ pending_jobs AS (
|
||||
WHERE
|
||||
job_status = 'pending'
|
||||
),
|
||||
unique_daemon_tags AS (
|
||||
SELECT DISTINCT tags FROM provisioner_daemons pd
|
||||
WHERE pd.last_seen_at IS NOT NULL
|
||||
AND pd.last_seen_at >= (NOW() - ($2::bigint || ' ms')::interval)
|
||||
),
|
||||
relevant_daemon_tags AS (
|
||||
SELECT udt.tags
|
||||
FROM unique_daemon_tags udt
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM filtered_provisioner_jobs fpj
|
||||
WHERE provisioner_tagset_contains(udt.tags, fpj.tags)
|
||||
)
|
||||
online_provisioner_daemons AS (
|
||||
SELECT id, tags FROM provisioner_daemons pd
|
||||
WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - ($2::bigint || ' ms')::interval)
|
||||
),
|
||||
ranked_jobs AS (
|
||||
-- Step 3: Rank only pending jobs based on provisioner availability
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.created_at,
|
||||
ROW_NUMBER() OVER (PARTITION BY rdt.tags ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
|
||||
COUNT(*) OVER (PARTITION BY rdt.tags) AS queue_size
|
||||
ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
|
||||
COUNT(*) OVER (PARTITION BY opd.id) AS queue_size
|
||||
FROM
|
||||
pending_jobs pj
|
||||
INNER JOIN
|
||||
relevant_daemon_tags rdt
|
||||
ON
|
||||
provisioner_tagset_contains(rdt.tags, pj.tags)
|
||||
INNER JOIN online_provisioner_daemons opd
|
||||
ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set
|
||||
),
|
||||
final_jobs AS (
|
||||
-- Step 4: Compute best queue position and max queue size per job
|
||||
@@ -15310,7 +15196,7 @@ func (q *sqlQuerier) DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid
|
||||
}
|
||||
|
||||
const getTaskByID = `-- name: GetTaskByID :one
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, workspace_group_acl, workspace_user_acl, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE id = $1::uuid
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error) {
|
||||
@@ -15328,8 +15214,6 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error
|
||||
&i.CreatedAt,
|
||||
&i.DeletedAt,
|
||||
&i.DisplayName,
|
||||
&i.WorkspaceGroupACL,
|
||||
&i.WorkspaceUserACL,
|
||||
&i.Status,
|
||||
&i.StatusDebug,
|
||||
&i.WorkspaceBuildNumber,
|
||||
@@ -15345,7 +15229,7 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error
|
||||
}
|
||||
|
||||
const getTaskByOwnerIDAndName = `-- name: GetTaskByOwnerIDAndName :one
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, workspace_group_acl, workspace_user_acl, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status
|
||||
WHERE
|
||||
owner_id = $1::uuid
|
||||
AND deleted_at IS NULL
|
||||
@@ -15372,8 +15256,6 @@ func (q *sqlQuerier) GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByO
|
||||
&i.CreatedAt,
|
||||
&i.DeletedAt,
|
||||
&i.DisplayName,
|
||||
&i.WorkspaceGroupACL,
|
||||
&i.WorkspaceUserACL,
|
||||
&i.Status,
|
||||
&i.StatusDebug,
|
||||
&i.WorkspaceBuildNumber,
|
||||
@@ -15389,7 +15271,7 @@ func (q *sqlQuerier) GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByO
|
||||
}
|
||||
|
||||
const getTaskByWorkspaceID = `-- name: GetTaskByWorkspaceID :one
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, workspace_group_acl, workspace_user_acl, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error) {
|
||||
@@ -15407,8 +15289,6 @@ func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.
|
||||
&i.CreatedAt,
|
||||
&i.DeletedAt,
|
||||
&i.DisplayName,
|
||||
&i.WorkspaceGroupACL,
|
||||
&i.WorkspaceUserACL,
|
||||
&i.Status,
|
||||
&i.StatusDebug,
|
||||
&i.WorkspaceBuildNumber,
|
||||
@@ -15688,7 +15568,7 @@ func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (Task
|
||||
}
|
||||
|
||||
const listTasks = `-- name: ListTasks :many
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, workspace_group_acl, workspace_user_acl, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status tws
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, display_name, status, status_debug, workspace_build_number, workspace_agent_id, workspace_app_id, workspace_agent_lifecycle_state, workspace_app_health, owner_username, owner_name, owner_avatar_url FROM tasks_with_status tws
|
||||
WHERE tws.deleted_at IS NULL
|
||||
AND CASE WHEN $1::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.owner_id = $1::UUID ELSE TRUE END
|
||||
AND CASE WHEN $2::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.organization_id = $2::UUID ELSE TRUE END
|
||||
@@ -15723,8 +15603,6 @@ func (q *sqlQuerier) ListTasks(ctx context.Context, arg ListTasksParams) ([]Task
|
||||
&i.CreatedAt,
|
||||
&i.DeletedAt,
|
||||
&i.DisplayName,
|
||||
&i.WorkspaceGroupACL,
|
||||
&i.WorkspaceUserACL,
|
||||
&i.Status,
|
||||
&i.StatusDebug,
|
||||
&i.WorkspaceBuildNumber,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, thread_parent_id, thread_root_id
|
||||
) VALUES (
|
||||
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
-- name: InsertChatFile :one
|
||||
INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data)
|
||||
VALUES (@owner_id::uuid, @organization_id::uuid, @name::text, @mimetype::text, @data::bytea)
|
||||
RETURNING id, owner_id, organization_id, created_at, name, mimetype;
|
||||
|
||||
-- name: GetChatFileByID :one
|
||||
SELECT * FROM chat_files WHERE id = @id::uuid;
|
||||
|
||||
-- name: GetChatFilesByIDs :many
|
||||
SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]);
|
||||
@@ -54,8 +54,9 @@ WITH latest_compressed_summary AS (
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND role = 'system'
|
||||
AND visibility IN ('model', 'both')
|
||||
AND compressed = TRUE
|
||||
AND visibility = 'model'
|
||||
ORDER BY
|
||||
created_at DESC,
|
||||
id DESC
|
||||
|
||||
@@ -79,7 +79,7 @@ WHERE
|
||||
WITH filtered_provisioner_jobs AS (
|
||||
-- Step 1: Filter provisioner_jobs
|
||||
SELECT
|
||||
id, created_at, tags
|
||||
id, created_at
|
||||
FROM
|
||||
provisioner_jobs
|
||||
WHERE
|
||||
@@ -94,32 +94,21 @@ pending_jobs AS (
|
||||
WHERE
|
||||
job_status = 'pending'
|
||||
),
|
||||
unique_daemon_tags AS (
|
||||
SELECT DISTINCT tags FROM provisioner_daemons pd
|
||||
WHERE pd.last_seen_at IS NOT NULL
|
||||
AND pd.last_seen_at >= (NOW() - (@stale_interval_ms::bigint || ' ms')::interval)
|
||||
),
|
||||
relevant_daemon_tags AS (
|
||||
SELECT udt.tags
|
||||
FROM unique_daemon_tags udt
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM filtered_provisioner_jobs fpj
|
||||
WHERE provisioner_tagset_contains(udt.tags, fpj.tags)
|
||||
)
|
||||
online_provisioner_daemons AS (
|
||||
SELECT id, tags FROM provisioner_daemons pd
|
||||
WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - (@stale_interval_ms::bigint || ' ms')::interval)
|
||||
),
|
||||
ranked_jobs AS (
|
||||
-- Step 3: Rank only pending jobs based on provisioner availability
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.created_at,
|
||||
ROW_NUMBER() OVER (PARTITION BY rdt.tags ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
|
||||
COUNT(*) OVER (PARTITION BY rdt.tags) AS queue_size
|
||||
ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
|
||||
COUNT(*) OVER (PARTITION BY opd.id) AS queue_size
|
||||
FROM
|
||||
pending_jobs pj
|
||||
INNER JOIN
|
||||
relevant_daemon_tags rdt
|
||||
ON
|
||||
provisioner_tagset_contains(rdt.tags, pj.tags)
|
||||
INNER JOIN online_provisioner_daemons opd
|
||||
ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set
|
||||
),
|
||||
final_jobs AS (
|
||||
-- Step 4: Compute best queue position and max queue size per job
|
||||
|
||||
@@ -82,12 +82,6 @@ sql:
|
||||
- column: "template_usage_stats.app_usage_mins"
|
||||
go_type:
|
||||
type: "StringMapOfInt"
|
||||
- column: "tasks_with_status.workspace_user_acl"
|
||||
go_type:
|
||||
type: "WorkspaceACL"
|
||||
- column: "tasks_with_status.workspace_group_acl"
|
||||
go_type:
|
||||
type: "WorkspaceACL"
|
||||
- column: "workspaces.user_acl"
|
||||
go_type:
|
||||
type: "WorkspaceACL"
|
||||
@@ -192,8 +186,6 @@ sql:
|
||||
jwt: JWT
|
||||
user_acl: UserACL
|
||||
group_acl: GroupACL
|
||||
workspace_user_acl: WorkspaceUserACL
|
||||
workspace_group_acl: WorkspaceGroupACL
|
||||
user_acl_display_info: UserACLDisplayInfo
|
||||
group_acl_display_info: GroupACLDisplayInfo
|
||||
troubleshooting_url: TroubleshootingURL
|
||||
|
||||
@@ -15,7 +15,6 @@ const (
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
@@ -336,7 +337,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
require.Equal(t, 1, validateCalls, "token is validated")
|
||||
require.Equal(t, 1, refreshCalls, "token is refreshed")
|
||||
require.NotEqualf(t, link.OAuthAccessToken, updated.OAuthAccessToken, "token is updated")
|
||||
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
|
||||
dbLink, err := db.GetExternalAuthLink(dbauthz.AsSystemRestricted(context.Background()), database.GetExternalAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
})
|
||||
|
||||
+194
-391
@@ -30,57 +30,7 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type (
|
||||
apiKeyContextKey struct{}
|
||||
apiKeyPrecheckedContextKey struct{}
|
||||
)
|
||||
|
||||
// ValidateAPIKeyConfig holds the settings needed for API key
|
||||
// validation at the top of the request lifecycle. Unlike
|
||||
// ExtractAPIKeyConfig it omits route-specific fields
|
||||
// (RedirectToLogin, Optional, ActivateDormantUser, etc.).
|
||||
type ValidateAPIKeyConfig struct {
|
||||
DB database.Store
|
||||
OAuth2Configs *OAuth2Configs
|
||||
DisableSessionExpiryRefresh bool
|
||||
// SessionTokenFunc overrides how the API token is extracted
|
||||
// from the request. Nil uses the default (cookie/header).
|
||||
SessionTokenFunc func(*http.Request) string
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
// ValidateAPIKeyResult is the outcome of successful validation.
|
||||
type ValidateAPIKeyResult struct {
|
||||
Key database.APIKey
|
||||
Subject rbac.Subject
|
||||
UserStatus database.UserStatus
|
||||
}
|
||||
|
||||
// ValidateAPIKeyError represents a validation failure with enough
|
||||
// context for downstream middlewares to decide how to respond.
|
||||
type ValidateAPIKeyError struct {
|
||||
Code int
|
||||
Response codersdk.Response
|
||||
// Hard is true for server errors and active failures (5xx,
|
||||
// OAuth refresh failures) that must be surfaced even on
|
||||
// optional-auth routes. Soft errors (missing/expired token)
|
||||
// may be swallowed on optional routes.
|
||||
Hard bool
|
||||
}
|
||||
|
||||
func (e *ValidateAPIKeyError) Error() string {
|
||||
return e.Response.Message
|
||||
}
|
||||
|
||||
// APIKeyPrechecked stores the result of top-level API key
|
||||
// validation performed by PrecheckAPIKey. It distinguishes
|
||||
// two states:
|
||||
// - Validation failed (including no token): Result == nil && Err != nil
|
||||
// - Validation passed: Result != nil && Err == nil
|
||||
type APIKeyPrechecked struct {
|
||||
Result *ValidateAPIKeyResult
|
||||
Err *ValidateAPIKeyError
|
||||
}
|
||||
type apiKeyContextKey struct{}
|
||||
|
||||
// APIKeyOptional may return an API key from the ExtractAPIKey handler.
|
||||
func APIKeyOptional(r *http.Request) (database.APIKey, bool) {
|
||||
@@ -199,298 +149,6 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// PrecheckAPIKey extracts and fully validates the API key on every
|
||||
// request (if present) and stores the result in context. It never
|
||||
// writes error responses and always calls next.
|
||||
//
|
||||
// The rate limiter reads the stored result to key by user ID and
|
||||
// check the Owner bypass header. Downstream ExtractAPIKeyMW reads
|
||||
// it to avoid redundant DB lookups and validation.
|
||||
func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Already prechecked (shouldn't happen, but guard).
|
||||
if _, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
|
||||
next.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
result, valErr := ValidateAPIKey(ctx, cfg, r)
|
||||
|
||||
prechecked := APIKeyPrechecked{
|
||||
Result: result,
|
||||
Err: valErr,
|
||||
}
|
||||
ctx = context.WithValue(ctx, apiKeyPrecheckedContextKey{}, prechecked)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAPIKey extracts and validates the API key from the
|
||||
// request. It performs all security-critical checks:
|
||||
// - Token extraction and parsing
|
||||
// - Database lookup + secret hash validation
|
||||
// - Expiry check
|
||||
// - OIDC/OAuth token refresh (if applicable)
|
||||
// - API key LastUsed / ExpiresAt DB updates
|
||||
// - User role lookup (UserRBACSubject)
|
||||
//
|
||||
// It does NOT:
|
||||
// - Write HTTP error responses
|
||||
// - Activate dormant users (route-specific)
|
||||
// - Redirect to login (route-specific)
|
||||
// - Check OAuth2 audience (route-specific, depends on AccessURL)
|
||||
// - Set PostAuth headers (route-specific)
|
||||
// - Check user active status (route-specific, depends on dormant activation)
|
||||
//
|
||||
// Returns (result, nil) on success or (nil, error) on failure.
|
||||
func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) {
|
||||
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if !ok {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: resp,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the API key ID for all requests that have a valid key
|
||||
// format and secret, regardless of whether subsequent validation
|
||||
// (expiry, user status, etc.) succeeds.
|
||||
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
|
||||
rl.WithFields(slog.F("api_key_id", key.ID))
|
||||
}
|
||||
|
||||
now := dbtime.Now()
|
||||
if key.ExpiresAt.Before(now) {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh OIDC/GitHub tokens if applicable.
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "You must re-authenticate with the login provider.",
|
||||
},
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: "A database error occurred",
|
||||
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
// Check if the OAuth token is expired.
|
||||
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
|
||||
if cfg.OAuth2Configs.IsZero() {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
var friendlyName string
|
||||
var oauthConfig promoauth.OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = cfg.OAuth2Configs.Github
|
||||
friendlyName = "GitHub"
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = cfg.OAuth2Configs.OIDC
|
||||
friendlyName = "OpenID Connect"
|
||||
default:
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
if oauthConfig == nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Soft error: session expired naturally with no
|
||||
// refresh token. Optional-auth routes treat this as
|
||||
// unauthenticated.
|
||||
if link.OAuthRefreshToken == "" {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// We have a refresh token, so let's try it.
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
// Hard error: we actively tried to refresh and the
|
||||
// provider rejected it — surface even on optional-auth
|
||||
// routes.
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
|
||||
friendlyName),
|
||||
Detail: err.Error(),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
//nolint:gocritic // system needs to update user link
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
Claims: link.Claims,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update LastUsed and session expiry.
|
||||
changed := false
|
||||
if now.Sub(key.LastUsed) > time.Hour {
|
||||
key.LastUsed = now
|
||||
remoteIP := net.ParseIP(r.RemoteAddr)
|
||||
if remoteIP == nil {
|
||||
remoteIP = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
bitlen := len(remoteIP) * 8
|
||||
key.IPAddress = pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: remoteIP,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if !cfg.DisableSessionExpiryRefresh {
|
||||
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
|
||||
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
|
||||
key.ExpiresAt = now.Add(apiKeyLifetime)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
//nolint:gocritic // System needs to update API Key LastUsed
|
||||
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocritic // system needs to update user last seen at
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
|
||||
ID: key.UserID,
|
||||
LastSeenAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch user roles.
|
||||
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &ValidateAPIKeyResult{
|
||||
Key: *key,
|
||||
Subject: actor,
|
||||
UserStatus: userStatus,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
|
||||
tokenFunc := APITokenFromRequest
|
||||
if sessionTokenFunc != nil {
|
||||
@@ -582,60 +240,29 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
// --- Consume prechecked result if available ---
|
||||
// Skip prechecked data when cfg has a custom SessionTokenFunc,
|
||||
// because the precheck used the default token extraction and may
|
||||
// have validated a different token (e.g. workspace app token
|
||||
// issuance in workspaceapps/db.go).
|
||||
var key *database.APIKey
|
||||
var actor rbac.Subject
|
||||
var userStatus database.UserStatus
|
||||
var skipValidation bool
|
||||
|
||||
if cfg.SessionTokenFunc == nil {
|
||||
if pc, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
|
||||
if pc.Err != nil {
|
||||
// Validation failed at the top level (includes
|
||||
// "no token provided").
|
||||
if pc.Err.Hard {
|
||||
return write(pc.Err.Code, pc.Err.Response)
|
||||
}
|
||||
return optionalWrite(pc.Err.Code, pc.Err.Response)
|
||||
}
|
||||
// Valid — use prechecked data, skip to route-specific logic.
|
||||
key = &pc.Result.Key
|
||||
actor = pc.Result.Subject
|
||||
userStatus = pc.Result.UserStatus
|
||||
skipValidation = true
|
||||
}
|
||||
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if !ok {
|
||||
return optionalWrite(http.StatusUnauthorized, resp)
|
||||
}
|
||||
|
||||
if !skipValidation {
|
||||
// Full validation path (no prechecked result or custom token func).
|
||||
result, valErr := ValidateAPIKey(ctx, ValidateAPIKeyConfig{
|
||||
DB: cfg.DB,
|
||||
OAuth2Configs: cfg.OAuth2Configs,
|
||||
DisableSessionExpiryRefresh: cfg.DisableSessionExpiryRefresh,
|
||||
SessionTokenFunc: cfg.SessionTokenFunc,
|
||||
Logger: cfg.Logger,
|
||||
}, r)
|
||||
if valErr != nil {
|
||||
if valErr.Hard {
|
||||
return write(valErr.Code, valErr.Response)
|
||||
}
|
||||
return optionalWrite(valErr.Code, valErr.Response)
|
||||
}
|
||||
key = &result.Key
|
||||
actor = result.Subject
|
||||
userStatus = result.UserStatus
|
||||
// Log the API key ID for all requests that have a valid key format and secret,
|
||||
// regardless of whether subsequent validation (expiry, user status, etc.) succeeds.
|
||||
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
|
||||
rl.WithFields(slog.F("api_key_id", key.ID))
|
||||
}
|
||||
|
||||
// --- Route-specific logic (always runs) ---
|
||||
now := dbtime.Now()
|
||||
if key.ExpiresAt.Before(now) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
})
|
||||
}
|
||||
|
||||
// Validate OAuth2 provider app token audience (RFC 8707) if applicable.
|
||||
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
|
||||
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
|
||||
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
|
||||
// Log the detailed error for debugging but don't expose it to the client.
|
||||
// Log the detailed error for debugging but don't expose it to the client
|
||||
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
|
||||
return optionalWrite(http.StatusForbidden, codersdk.Response{
|
||||
Message: "Token audience validation failed",
|
||||
@@ -643,7 +270,183 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
}
|
||||
}
|
||||
|
||||
// Dormant activation (config-dependent).
|
||||
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
|
||||
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
|
||||
// refreshing the OIDC token.
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
var err error
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "You must re-authenticate with the login provider.",
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred",
|
||||
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
|
||||
})
|
||||
}
|
||||
// Check if the OAuth token is expired
|
||||
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
|
||||
if cfg.OAuth2Configs.IsZero() {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
var friendlyName string
|
||||
var oauthConfig promoauth.OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = cfg.OAuth2Configs.Github
|
||||
friendlyName = "GitHub"
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = cfg.OAuth2Configs.OIDC
|
||||
friendlyName = "OpenID Connect"
|
||||
default:
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
// It's possible for cfg.OAuth2Configs to be non-nil, but still
|
||||
// missing this type. For example, if a user logged in with GitHub,
|
||||
// but the administrator later removed GitHub and replaced it with
|
||||
// OIDC.
|
||||
if oauthConfig == nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
if link.OAuthRefreshToken == "" {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
|
||||
})
|
||||
}
|
||||
// We have a refresh token, so let's try it
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
return write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
|
||||
friendlyName),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
//nolint:gocritic // system needs to update user link
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
Claims: link.Claims,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tracks if the API key has properties updated
|
||||
changed := false
|
||||
|
||||
// Only update LastUsed once an hour to prevent database spam.
|
||||
if now.Sub(key.LastUsed) > time.Hour {
|
||||
key.LastUsed = now
|
||||
remoteIP := net.ParseIP(r.RemoteAddr)
|
||||
if remoteIP == nil {
|
||||
remoteIP = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
bitlen := len(remoteIP) * 8
|
||||
key.IPAddress = pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: remoteIP,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
// Only update the ExpiresAt once an hour to prevent database spam.
|
||||
// We extend the ExpiresAt to reduce re-authentication.
|
||||
if !cfg.DisableSessionExpiryRefresh {
|
||||
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
|
||||
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
|
||||
key.ExpiresAt = now.Add(apiKeyLifetime)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
//nolint:gocritic // System needs to update API Key LastUsed
|
||||
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
|
||||
// We only want to update this occasionally to reduce DB write
|
||||
// load. We update alongside the UserLink and APIKey since it's
|
||||
// easier on the DB to colocate writes.
|
||||
//nolint:gocritic // system needs to update user last seen at
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
|
||||
ID: key.UserID,
|
||||
LastSeenAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// If the key is valid, we also fetch the user roles and status.
|
||||
// The roles are used for RBAC authorize checks, and the status
|
||||
// is to block 'suspended' users from accessing the platform.
|
||||
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
|
||||
if err != nil {
|
||||
return write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
|
||||
})
|
||||
}
|
||||
|
||||
if userStatus == database.UserStatusDormant && cfg.ActivateDormantUser != nil {
|
||||
id, _ := uuid.Parse(actor.ID)
|
||||
user, err := cfg.ActivateDormantUser(ctx, database.User{
|
||||
|
||||
+15
-36
@@ -32,56 +32,35 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler
|
||||
count,
|
||||
window,
|
||||
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
|
||||
// Identify the caller. We check two sources:
|
||||
//
|
||||
// 1. apiKeyPrecheckedContextKey — set by PrecheckAPIKey
|
||||
// at the root of the router. Only fully validated
|
||||
// keys are used.
|
||||
// 2. apiKeyContextKey — set by ExtractAPIKeyMW if it
|
||||
// has already run (e.g. unit tests, workspace-app
|
||||
// routes that don't go through PrecheckAPIKey).
|
||||
//
|
||||
// If neither is present, fall back to IP.
|
||||
var userID string
|
||||
var subject *rbac.Subject
|
||||
|
||||
if pc, ok := r.Context().Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok && pc.Result != nil {
|
||||
userID = pc.Result.Key.UserID.String()
|
||||
subject = &pc.Result.Subject
|
||||
} else if ak, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey); ok {
|
||||
userID = ak.UserID.String()
|
||||
if auth, ok := UserAuthorizationOptional(r.Context()); ok {
|
||||
subject = &auth
|
||||
}
|
||||
} else {
|
||||
// Prioritize by user, but fallback to IP.
|
||||
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
|
||||
if !ok {
|
||||
return httprate.KeyByIP(r)
|
||||
}
|
||||
|
||||
if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok {
|
||||
// No bypass attempt, just rate limit by user.
|
||||
return userID, nil
|
||||
// No bypass attempt, just ratelimit.
|
||||
return apiKey.UserID.String(), nil
|
||||
}
|
||||
|
||||
// Allow Owner to bypass rate limiting for load tests
|
||||
// and automation. We avoid using rbac.Authorizer since
|
||||
// rego is CPU-intensive and undermines the
|
||||
// DoS-prevention goal of the rate limiter.
|
||||
if subject == nil {
|
||||
// Can't verify roles — rate limit normally.
|
||||
return userID, nil
|
||||
}
|
||||
for _, role := range subject.SafeRoleNames() {
|
||||
// and automation.
|
||||
auth := UserAuthorization(r.Context())
|
||||
|
||||
// We avoid using rbac.Authorizer since rego is CPU-intensive
|
||||
// and undermines the DoS-prevention goal of the rate limiter.
|
||||
for _, role := range auth.SafeRoleNames() {
|
||||
if role == rbac.RoleOwner() {
|
||||
// HACK: use a random key each time to
|
||||
// de facto disable rate limiting. The
|
||||
// httprate package has no support for
|
||||
// selectively changing the limit for
|
||||
// particular keys.
|
||||
// `httprate` package has no
|
||||
// support for selectively changing the limit
|
||||
// for particular keys.
|
||||
return cryptorand.String(16)
|
||||
}
|
||||
}
|
||||
|
||||
return userID, xerrors.Errorf(
|
||||
return apiKey.UserID.String(), xerrors.Errorf(
|
||||
"%q provided but user is not %v",
|
||||
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/idpsync"
|
||||
@@ -356,7 +357,7 @@ func TestGroupSyncTable(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
defOrg, err := db.GetDefaultOrganization(ctx)
|
||||
defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
|
||||
require.NoError(t, err)
|
||||
SetupOrganization(t, s, db, user, defOrg.ID, def)
|
||||
asserts = append(asserts, func(t *testing.T) {
|
||||
@@ -554,6 +555,7 @@ func TestApplyGroupDifference(t *testing.T) {
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
ctx = dbauthz.AsSystemRestricted(ctx)
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_, err := db.InsertAllUsersGroup(ctx, org.ID)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
@@ -272,7 +273,7 @@ func TestRoleSyncTable(t *testing.T) {
|
||||
}
|
||||
|
||||
// Also assert site wide roles
|
||||
allRoles, err := db.GetAuthorizationUserRoles(ctx, user.ID)
|
||||
allRoles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
allRoleIDs, err := allRoles.RoleNames()
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
@@ -29,6 +30,7 @@ func TestBufferedUpdates(t *testing.T) {
|
||||
|
||||
// setup
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -55,7 +57,6 @@ func TestBufferedUpdates(t *testing.T) {
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
// WHEN: notifications are enqueued which should succeed and fail
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
_, err = enq.Enqueue(ctx, user.ID, notifications.TemplateWorkspaceDeleted, map[string]string{"nice": "true", "i": "0"}, "") // Will succeed.
|
||||
require.NoError(t, err)
|
||||
_, err = enq.Enqueue(ctx, user.ID, notifications.TemplateWorkspaceDeleted, map[string]string{"nice": "true", "i": "1"}, "") // Will succeed.
|
||||
@@ -105,6 +106,7 @@ func TestBuildPayload(t *testing.T) {
|
||||
|
||||
// SETUP
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -144,7 +146,6 @@ func TestBuildPayload(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// WHEN: a notification is enqueued
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
_, err = enq.Enqueue(ctx, uuid.New(), notifications.TemplateWorkspaceDeleted, map[string]string{
|
||||
"name": "my-workspace",
|
||||
}, "test")
|
||||
@@ -162,6 +163,7 @@ func TestStopBeforeRun(t *testing.T) {
|
||||
|
||||
// SETUP
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -170,7 +172,6 @@ func TestStopBeforeRun(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// THEN: validate that the manager can be stopped safely without Run() having been called yet
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
require.Eventually(t, func() bool {
|
||||
assert.NoError(t, mgr.Stop(ctx))
|
||||
return true
|
||||
@@ -182,6 +183,7 @@ func TestRunStopRace(t *testing.T) {
|
||||
|
||||
// SETUP
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -192,7 +194,6 @@ func TestRunStopRace(t *testing.T) {
|
||||
// Start Run and Stop after each other (run does "go loop()").
|
||||
// This is to catch a (now fixed) race condition where the manager
|
||||
// would be accessed/stopped while it was being created/starting up.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
mgr.Run(ctx)
|
||||
err = mgr.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/dispatch"
|
||||
@@ -32,6 +33,7 @@ func TestMetrics(t *testing.T) {
|
||||
|
||||
// SETUP
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -55,7 +57,6 @@ func TestMetrics(t *testing.T) {
|
||||
|
||||
mgr, err := notifications.NewManager(cfg, store, pubsub, defaultHelpers(), metrics, logger.Named("manager"))
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, mgr.Stop(ctx))
|
||||
})
|
||||
@@ -220,6 +221,7 @@ func TestPendingUpdatesMetric(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// SETUP
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -245,7 +247,6 @@ func TestPendingUpdatesMetric(t *testing.T) {
|
||||
mgr, err := notifications.NewManager(cfg, interceptor, pubsub, defaultHelpers(), metrics, logger.Named("manager"),
|
||||
notifications.WithTestClock(mClock))
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, mgr.Stop(ctx))
|
||||
})
|
||||
@@ -313,6 +314,7 @@ func TestInflightDispatchesMetric(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// SETUP
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -331,7 +333,6 @@ func TestInflightDispatchesMetric(t *testing.T) {
|
||||
|
||||
mgr, err := notifications.NewManager(cfg, store, pubsub, defaultHelpers(), metrics, logger.Named("manager"))
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, mgr.Stop(ctx))
|
||||
})
|
||||
@@ -385,6 +386,7 @@ func TestInflightDispatchesMetric(t *testing.T) {
|
||||
|
||||
func TestCustomMethodMetricCollection(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong))
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -400,8 +402,6 @@ func TestCustomMethodMetricCollection(t *testing.T) {
|
||||
defaultMethod = database.NotificationMethodSmtp
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
|
||||
// GIVEN: a template whose notification method differs from the default.
|
||||
out, err := store.UpdateNotificationTemplateMethodByID(ctx, database.UpdateNotificationTemplateMethodByIDParams{
|
||||
ID: tmpl,
|
||||
|
||||
@@ -1472,12 +1472,12 @@ func TestNotificationTemplates_Golden(t *testing.T) {
|
||||
// as appearance changes are enterprise features and we do not want to mix those
|
||||
// can't use the api
|
||||
if tc.appName != "" {
|
||||
err = (*db).UpsertApplicationName(ctx, "Custom Application")
|
||||
err = (*db).UpsertApplicationName(dbauthz.AsSystemRestricted(ctx), "Custom Application")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tc.logoURL != "" {
|
||||
err = (*db).UpsertLogoURL(ctx, "https://custom.application/logo.png")
|
||||
err = (*db).UpsertLogoURL(dbauthz.AsSystemRestricted(ctx), "https://custom.application/logo.png")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -516,7 +516,7 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token.AccessToken)
|
||||
require.True(t, dbtime.Now().Before(token.Expiry))
|
||||
require.True(t, time.Now().Before(token.Expiry))
|
||||
|
||||
// Check that the token works.
|
||||
newClient := codersdk.New(userClient.URL)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/prometheusmetrics/insights"
|
||||
@@ -126,7 +127,7 @@ func TestCollectInsights(t *testing.T) {
|
||||
AppStatBatchSize: workspaceapps.DefaultStatsDBReporterBatchSize,
|
||||
})
|
||||
refTime := time.Now().Add(-3 * time.Minute).Truncate(time.Minute)
|
||||
err = reporter.ReportAppStats(context.Background(), []workspaceapps.StatsReport{
|
||||
err = reporter.ReportAppStats(dbauthz.AsSystemRestricted(context.Background()), []workspaceapps.StatsReport{
|
||||
{
|
||||
UserID: user.ID,
|
||||
WorkspaceID: workspace1.ID,
|
||||
|
||||
@@ -564,7 +564,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
// The check `s.OIDCConfig != nil` is not as strict, since it can be an interface
|
||||
// pointing to a typed nil.
|
||||
if !reflect.ValueOf(s.OIDCConfig).IsNil() {
|
||||
workspaceOwnerOIDCAccessToken, err = ObtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
|
||||
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
|
||||
}
|
||||
@@ -3075,15 +3075,15 @@ func deleteSessionTokenForUserAndWorkspace(ctx context.Context, db database.Stor
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
|
||||
func shouldRefreshOIDCToken(link database.UserLink) bool {
|
||||
if link.OAuthRefreshToken == "" {
|
||||
// We cannot refresh even if we wanted to
|
||||
return false, link.OAuthExpiry
|
||||
return false
|
||||
}
|
||||
|
||||
if link.OAuthExpiry.IsZero() {
|
||||
// 0 expire means the token never expires, so we shouldn't refresh
|
||||
return false, link.OAuthExpiry
|
||||
return false
|
||||
}
|
||||
|
||||
// This handles an edge case where the token is about to expire. A workspace
|
||||
@@ -3093,19 +3093,15 @@ func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
|
||||
//
|
||||
// If an OIDC provider issues short-lived tokens less than our defined period,
|
||||
// the token will always be refreshed on every workspace build.
|
||||
//
|
||||
// By setting the expiration backwards, we are effectively shortening the
|
||||
// time a token can be alive for by 10 minutes.
|
||||
// Note: This is how it is done in the oauth2 package's own token refreshing logic.
|
||||
expiresAt := link.OAuthExpiry.Add(-time.Minute * 10)
|
||||
assumeExpiredAt := dbtime.Now().Add(-1 * time.Minute * 10)
|
||||
|
||||
// Return if the token is assumed to be expired.
|
||||
return expiresAt.Before(dbtime.Now()), expiresAt
|
||||
return link.OAuthExpiry.Before(assumeExpiredAt)
|
||||
}
|
||||
|
||||
// ObtainOIDCAccessToken returns a valid OpenID Connect access token
|
||||
// obtainOIDCAccessToken returns a valid OpenID Connect access token
|
||||
// for the user if it's able to obtain one, otherwise it returns an empty string.
|
||||
func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
|
||||
func obtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
|
||||
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: userID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
@@ -3117,13 +3113,11 @@ func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.
|
||||
return "", xerrors.Errorf("get owner oidc link: %w", err)
|
||||
}
|
||||
|
||||
if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh {
|
||||
if shouldRefreshOIDCToken(link) {
|
||||
token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
// Use the expiresAt returned by shouldRefreshOIDCToken.
|
||||
// It will force a refresh with an expired time.
|
||||
Expiry: expiresAt,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
// If OIDC fails to refresh, we return an empty string and don't fail.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user