Compare commits
62 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f661c435c8 | |||
| e7ea649dc2 | |||
| 56960585af | |||
| f07e266904 | |||
| 9bc884d597 | |||
| f46692531f | |||
| 6e9e39a4e0 | |||
| 1a2eea5e76 | |||
| 9e7125f852 | |||
| e6983648aa | |||
| 47846c0ee4 | |||
| ff715c9f4c | |||
| f4ab854b06 | |||
| c6b68b2991 | |||
| 5dfd563e4b | |||
| 4957888270 | |||
| 26adc26a26 | |||
| b33b8e476b | |||
| 95bd099c77 | |||
| 3f939375fa | |||
| a072d542a5 | |||
| a96ec4c397 | |||
| 2eb3ab4cf5 | |||
| 51a627c107 | |||
| 49006685b0 | |||
| 715486465b | |||
| e205a3493d | |||
| 6b14a3eb7f | |||
| 0fea47d97c | |||
| 02b1951aac | |||
| dd34e3d3c2 | |||
| a48e4a43e2 | |||
| 5b7ba93cb2 | |||
| aba3832b15 | |||
| ca873060c6 | |||
| 896c43d5b7 | |||
| 2ad0e74e67 | |||
| 6509fb2574 | |||
| 667d501282 | |||
| 69a4a8825d | |||
| 4b3ed61210 | |||
| c01430e53b | |||
| 7d6fde35bd | |||
| 77ca772552 | |||
| 703629f5e9 | |||
| 4cf8d4414e | |||
| 3608064600 | |||
| 4e50ca6b6e | |||
| 4c83a7021f | |||
| b9c729457b | |||
| 9bd712013f | |||
| 8c52e150f6 | |||
| f404463317 | |||
| 338d30e4c4 | |||
| 4afdfc50a5 | |||
| b199ef1b69 | |||
| eecb7d0b66 | |||
| 2cd871e88f | |||
| b9b3c67c73 | |||
| 09aa7b1887 | |||
| 5712faaa2c | |||
| a104d608a3 |
@@ -189,8 +189,8 @@ func (q *sqlQuerier) UpdateUser(ctx context.Context, arg UpdateUserParams) (User
|
||||
### Common Debug Commands
|
||||
|
||||
```bash
|
||||
# Check database connection
|
||||
make test-postgres
|
||||
# Run tests (starts Postgres automatically if needed)
|
||||
make test
|
||||
|
||||
# Run specific database tests
|
||||
go test ./coderd/database/... -run TestSpecificFunction
|
||||
|
||||
@@ -67,7 +67,6 @@ coderd/
|
||||
| `make test` | Run all Go tests |
|
||||
| `make test RUN=TestFunctionName` | Run specific test |
|
||||
| `go test -v ./path/to/package -run TestFunctionName` | Run test with verbose output |
|
||||
| `make test-postgres` | Run tests with Postgres database |
|
||||
| `make test-race` | Run tests with Go race detector |
|
||||
| `make test-e2e` | Run end-to-end tests |
|
||||
|
||||
|
||||
@@ -109,7 +109,6 @@
|
||||
|
||||
- Run full test suite: `make test`
|
||||
- Run specific test: `make test RUN=TestFunctionName`
|
||||
- Run with Postgres: `make test-postgres`
|
||||
- Run with race detector: `make test-race`
|
||||
- Run end-to-end tests: `make test-e2e`
|
||||
|
||||
|
||||
@@ -70,11 +70,7 @@ runs:
|
||||
set -euo pipefail
|
||||
|
||||
if [[ ${RACE_DETECTION} == true ]]; then
|
||||
gotestsum --junitfile="gotests.xml" --packages="${TEST_PACKAGES}" -- \
|
||||
-tags=testsmallbatch \
|
||||
-race \
|
||||
-parallel "${TEST_NUM_PARALLEL_TESTS}" \
|
||||
-p "${TEST_NUM_PARALLEL_PACKAGES}"
|
||||
make test-race
|
||||
else
|
||||
make test
|
||||
fi
|
||||
|
||||
@@ -366,9 +366,9 @@ jobs:
|
||||
needs: changes
|
||||
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -475,11 +475,6 @@ jobs:
|
||||
mkdir -p /tmp/tmpfs
|
||||
sudo mount_tmpfs -o noowners -s 8g /tmp/tmpfs
|
||||
|
||||
# Install google-chrome for scaletests.
|
||||
# As another concern, should we really have this kind of external dependency
|
||||
# requirement on standard CI?
|
||||
brew install google-chrome
|
||||
|
||||
# macOS will output "The default interactive shell is now zsh" intermittently in CI.
|
||||
touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile
|
||||
|
||||
@@ -574,9 +569,9 @@ jobs:
|
||||
- changes
|
||||
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
@@ -986,6 +981,9 @@ jobs:
|
||||
run: |
|
||||
make build/coder_docs_"$(./scripts/version.sh)".tgz
|
||||
|
||||
- name: Check for unstaged files
|
||||
run: ./scripts/check_unstaged.sh
|
||||
|
||||
required:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
name: Linear Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
# This event reads the workflow from the default branch (main), not the
|
||||
# release branch. No cherry-pick needed.
|
||||
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#release
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
sync:
|
||||
name: Sync issues to Linear release
|
||||
if: github.event_name == 'push'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Sync issues
|
||||
id: sync
|
||||
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: sync
|
||||
|
||||
- name: Print release URL
|
||||
if: steps.sync.outputs.release-url
|
||||
run: echo "Synced to $RELEASE_URL"
|
||||
env:
|
||||
RELEASE_URL: ${{ steps.sync.outputs.release-url }}
|
||||
|
||||
complete:
|
||||
name: Complete Linear release
|
||||
if: github.event_name == 'release'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Complete release
|
||||
id: complete
|
||||
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: complete
|
||||
version: ${{ github.event.release.tag_name }}
|
||||
|
||||
- name: Print release URL
|
||||
if: steps.complete.outputs.release-url
|
||||
run: echo "Completed $RELEASE_URL"
|
||||
env:
|
||||
RELEASE_URL: ${{ steps.complete.outputs.release-url }}
|
||||
@@ -16,9 +16,9 @@ jobs:
|
||||
# when changing runner sizes
|
||||
runs-on: ${{ matrix.os == 'macos-latest' && github.repository_owner == 'coder' && 'depot-macos-latest' || matrix.os == 'windows-2022' && github.repository_owner == 'coder' && 'depot-windows-2022-16' || matrix.os }}
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -37,21 +37,20 @@ Only pause to ask for confirmation when:
|
||||
|
||||
## Essential Commands
|
||||
|
||||
| Task | Command | Notes |
|
||||
|-------------------|--------------------------|-------------------------------------|
|
||||
| **Development** | `./scripts/develop.sh` | ⚠️ Don't use manual build |
|
||||
| **Build** | `make build` | Fat binaries (includes server) |
|
||||
| **Build Slim** | `make build-slim` | Slim binaries |
|
||||
| **Test** | `make test` | Full test suite |
|
||||
| **Test Single** | `make test RUN=TestName` | Faster than full suite |
|
||||
| **Test Postgres** | `make test-postgres` | Run tests with Postgres database |
|
||||
| **Test Race** | `make test-race` | Run tests with Go race detector |
|
||||
| **Lint** | `make lint` | Always run after changes |
|
||||
| **Generate** | `make gen` | After database changes |
|
||||
| **Format** | `make fmt` | Auto-format code |
|
||||
| **Clean** | `make clean` | Clean build artifacts |
|
||||
| **Pre-commit** | `make pre-commit` | Fast CI checks (gen/fmt/lint/build) |
|
||||
| **Pre-push** | `make pre-push` | All CI checks including tests |
|
||||
| Task | Command | Notes |
|
||||
|-----------------|--------------------------|-------------------------------------|
|
||||
| **Development** | `./scripts/develop.sh` | ⚠️ Don't use manual build |
|
||||
| **Build** | `make build` | Fat binaries (includes server) |
|
||||
| **Build Slim** | `make build-slim` | Slim binaries |
|
||||
| **Test** | `make test` | Full test suite |
|
||||
| **Test Single** | `make test RUN=TestName` | Faster than full suite |
|
||||
| **Test Race** | `make test-race` | Run tests with Go race detector |
|
||||
| **Lint** | `make lint` | Always run after changes |
|
||||
| **Generate** | `make gen` | After database changes |
|
||||
| **Format** | `make fmt` | Auto-format code |
|
||||
| **Clean** | `make clean` | Clean build artifacts |
|
||||
| **Pre-commit** | `make pre-commit` | Fast CI checks (gen/fmt/lint/build) |
|
||||
| **Pre-push** | `make pre-push` | All CI checks including tests |
|
||||
|
||||
### Documentation Commands
|
||||
|
||||
@@ -105,22 +104,37 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
|
||||
### Full workflows available in imported WORKFLOWS.md
|
||||
|
||||
### Git Hooks (MANDATORY)
|
||||
### Git Hooks (MANDATORY - DO NOT SKIP)
|
||||
|
||||
Before your first commit, ensure the git hooks are installed.
|
||||
Two hooks run automatically:
|
||||
**You MUST install and use the git hooks. NEVER bypass them with
|
||||
`--no-verify`. Skipping hooks wastes CI cycles and is unacceptable.**
|
||||
|
||||
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
|
||||
Fast checks that catch most CI failures.
|
||||
- **pre-push**: `make pre-push` (full CI suite including tests).
|
||||
Runs before pushing to catch everything CI would.
|
||||
|
||||
Wait for them to complete, do not skip or bypass them.
|
||||
The first run will be slow as caches warm up. Consecutive runs are
|
||||
**significantly faster** (often 10x) thanks to Go build cache,
|
||||
generated file timestamps, and warm node_modules. This is NOT a
|
||||
reason to skip them. Wait for hooks to complete before proceeding,
|
||||
no matter how long they take.
|
||||
|
||||
```sh
|
||||
git config core.hooksPath scripts/githooks
|
||||
```
|
||||
|
||||
Two hooks run automatically:
|
||||
|
||||
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
|
||||
Fast checks that catch most CI failures. Allow at least 5 minutes.
|
||||
- **pre-push**: `make pre-push` (full CI suite including tests).
|
||||
Runs before pushing to catch everything CI would. Allow at least
|
||||
15 minutes (race tests are slow without cache).
|
||||
|
||||
`git commit` and `git push` will appear to hang while hooks run.
|
||||
This is normal. Do not interrupt, retry, or reduce the timeout.
|
||||
|
||||
NEVER run `git config core.hooksPath` to change or disable hooks.
|
||||
|
||||
If a hook fails, fix the issue and retry. Do not work around the
|
||||
failure by skipping the hook.
|
||||
|
||||
### Git Workflow
|
||||
|
||||
When working on existing PRs, check out the branch first:
|
||||
|
||||
@@ -19,6 +19,16 @@ SHELL := bash
|
||||
.SHELLFLAGS := -ceu
|
||||
.ONESHELL:
|
||||
|
||||
# When MAKE_TIMED=1, replace SHELL with a wrapper that prints
|
||||
# elapsed wall-clock time for each recipe. pre-commit and pre-push
|
||||
# set this on their sub-makes so every parallel job reports its
|
||||
# duration. Ad-hoc usage: make MAKE_TIMED=1 test
|
||||
ifdef MAKE_TIMED
|
||||
SHELL := $(CURDIR)/scripts/lib/timed-shell.sh
|
||||
.SHELLFLAGS = $@ -ceu
|
||||
export MAKE_TIMED
|
||||
endif
|
||||
|
||||
# This doesn't work on directories.
|
||||
# See https://stackoverflow.com/questions/25752543/make-delete-on-error-for-directory-targets
|
||||
.DELETE_ON_ERROR:
|
||||
@@ -103,6 +113,11 @@ VERSION := $(shell ./scripts/version.sh)
|
||||
POSTGRES_VERSION ?= 17
|
||||
POSTGRES_IMAGE ?= us-docker.pkg.dev/coder-v2-images-public/public/postgres:$(POSTGRES_VERSION)
|
||||
|
||||
# Limit parallel Make jobs in pre-commit/pre-push. Defaults to
|
||||
# nproc/4 (min 2) since test and lint targets have internal
|
||||
# parallelism. Override: make pre-push PARALLEL_JOBS=8
|
||||
PARALLEL_JOBS ?= $(shell n=$$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 8); echo $$(( n / 4 > 2 ? n / 4 : 2 )))
|
||||
|
||||
# Use the highest ZSTD compression level in CI.
|
||||
ifdef CI
|
||||
ZSTDFLAGS := -22 --ultra
|
||||
@@ -706,9 +721,11 @@ lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
# pre-push runs the full CI suite including tests. This is the git
|
||||
# pre-push hook default, catching everything CI would before pushing.
|
||||
#
|
||||
# Both use two-phase execution: gen+fmt first (writes files), then
|
||||
# lint+build (reads files). This avoids races where gen's `go run`
|
||||
# creates temporary .go files that lint's find-based checks pick up.
|
||||
# pre-push uses two-phase execution: gen+fmt+test-postgres-docker
|
||||
# first (writes files, starts Docker), then lint+build+test in
|
||||
# parallel. pre-commit uses two phases: gen+fmt first, then
|
||||
# lint+build. This avoids races where gen's `go run` creates
|
||||
# temporary .go files that lint's find-based checks pick up.
|
||||
# Within each phase, targets run in parallel via -j. Both fail if
|
||||
# any tracked files have unstaged changes afterward.
|
||||
#
|
||||
@@ -717,7 +734,7 @@ lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
#
|
||||
# pre-push only (need external services or are slow):
|
||||
# site/out/index.html (pnpm build)
|
||||
# test-postgres (needs Docker)
|
||||
# test-postgres-docker + test (needs Docker)
|
||||
# test-js, test-e2e (needs Playwright)
|
||||
# sqlc-vet (needs Docker)
|
||||
# offlinedocs/check
|
||||
@@ -744,30 +761,38 @@ define check-unstaged
|
||||
endef
|
||||
|
||||
pre-commit:
|
||||
$(MAKE) -j --output-sync=target gen fmt
|
||||
start=$$(date +%s)
|
||||
echo "=== Phase 1/2: gen + fmt ==="
|
||||
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt
|
||||
$(check-unstaged)
|
||||
$(MAKE) -j --output-sync=target \
|
||||
echo "=== Phase 2/2: lint + build ==="
|
||||
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
|
||||
lint \
|
||||
lint/typos \
|
||||
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
$(check-unstaged)
|
||||
echo "$(BOLD)$(GREEN)=== pre-commit passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
|
||||
.PHONY: pre-commit
|
||||
|
||||
pre-push:
|
||||
$(MAKE) -j --output-sync=target gen fmt
|
||||
start=$$(date +%s)
|
||||
echo "=== Phase 1/2: gen + fmt + postgres ==="
|
||||
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt test-postgres-docker
|
||||
$(check-unstaged)
|
||||
$(MAKE) -j --output-sync=target \
|
||||
echo "=== Phase 2/2: lint + build + test ==="
|
||||
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
|
||||
lint \
|
||||
lint/typos \
|
||||
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT) \
|
||||
site/out/index.html \
|
||||
test-postgres \
|
||||
test \
|
||||
test-js \
|
||||
test-e2e \
|
||||
test-race \
|
||||
sqlc-vet \
|
||||
offlinedocs/check
|
||||
$(check-unstaged)
|
||||
echo "$(BOLD)$(GREEN)=== pre-push passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
|
||||
.PHONY: pre-push
|
||||
|
||||
offlinedocs/check: offlinedocs/node_modules/.installed
|
||||
@@ -1230,10 +1255,22 @@ else
|
||||
GOTESTSUM_RETRY_FLAGS :=
|
||||
endif
|
||||
|
||||
# default to 8x8 parallelism to avoid overwhelming our workspaces. Hopefully we can remove these defaults
|
||||
# when we get our test suite's resource utilization under control.
|
||||
# Use testsmallbatch tag to reduce wireguard memory allocation in tests (from ~18GB to negligible).
|
||||
GOTEST_FLAGS := -tags=testsmallbatch -v -p $(or $(TEST_NUM_PARALLEL_PACKAGES),"8") -parallel=$(or $(TEST_NUM_PARALLEL_TESTS),"8")
|
||||
# Default to 8x8 parallelism to avoid overwhelming our workspaces.
|
||||
# Race detection defaults to 4x4 because the detector adds significant
|
||||
# CPU overhead. Override via TEST_NUM_PARALLEL_PACKAGES /
|
||||
# TEST_NUM_PARALLEL_TESTS.
|
||||
TEST_PARALLEL_PACKAGES := $(or $(TEST_NUM_PARALLEL_PACKAGES),8)
|
||||
TEST_PARALLEL_TESTS := $(or $(TEST_NUM_PARALLEL_TESTS),8)
|
||||
RACE_PARALLEL_PACKAGES := $(or $(TEST_NUM_PARALLEL_PACKAGES),4)
|
||||
RACE_PARALLEL_TESTS := $(or $(TEST_NUM_PARALLEL_TESTS),4)
|
||||
|
||||
# Use testsmallbatch tag to reduce wireguard memory allocation in tests
|
||||
# (from ~18GB to negligible). Recursively expanded so target-specific
|
||||
# overrides of TEST_PARALLEL_* take effect (e.g. test-race lowers
|
||||
# parallelism). CI job timeout is 25m (see test-go-pg in ci.yaml),
|
||||
# keep the Go timeout 5m shorter so tests produce goroutine dumps
|
||||
# instead of the CI runner killing the process with no output.
|
||||
GOTEST_FLAGS = -tags=testsmallbatch -v -timeout 20m -p $(TEST_PARALLEL_PACKAGES) -parallel=$(TEST_PARALLEL_TESTS)
|
||||
|
||||
# The most common use is to set TEST_COUNT=1 to avoid Go's test cache.
|
||||
ifdef TEST_COUNT
|
||||
@@ -1259,9 +1296,25 @@ endif
|
||||
TEST_PACKAGES ?= ./...
|
||||
|
||||
test:
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet $(GOTESTSUM_RETRY_FLAGS) --packages="$(TEST_PACKAGES)" -- $(GOTEST_FLAGS)
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="$(TEST_PACKAGES)" \
|
||||
-- \
|
||||
$(GOTEST_FLAGS)
|
||||
.PHONY: test
|
||||
|
||||
test-race: TEST_PARALLEL_PACKAGES := $(RACE_PARALLEL_PACKAGES)
|
||||
test-race: TEST_PARALLEL_TESTS := $(RACE_PARALLEL_TESTS)
|
||||
test-race:
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet \
|
||||
--junitfile="gotests.xml" \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="$(TEST_PACKAGES)" \
|
||||
-- \
|
||||
-race \
|
||||
$(GOTEST_FLAGS)
|
||||
.PHONY: test-race
|
||||
|
||||
test-cli:
|
||||
$(MAKE) test TEST_PACKAGES="./cli..."
|
||||
.PHONY: test-cli
|
||||
@@ -1282,37 +1335,22 @@ sqlc-cloud-is-setup:
|
||||
|
||||
sqlc-push: sqlc-cloud-is-setup test-postgres-docker
|
||||
echo "--- sqlc push"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
sqlc push -f coderd/database/sqlc.yaml && echo "Passed sqlc push"
|
||||
.PHONY: sqlc-push
|
||||
|
||||
sqlc-verify: sqlc-cloud-is-setup test-postgres-docker
|
||||
echo "--- sqlc verify"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
sqlc verify -f coderd/database/sqlc.yaml && echo "Passed sqlc verify"
|
||||
.PHONY: sqlc-verify
|
||||
|
||||
sqlc-vet: test-postgres-docker
|
||||
echo "--- sqlc vet"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
sqlc vet -f coderd/database/sqlc.yaml && echo "Passed sqlc vet"
|
||||
.PHONY: sqlc-vet
|
||||
|
||||
# When updating -timeout for this test, keep in sync with
|
||||
# test-go-postgres (.github/workflows/coder.yaml).
|
||||
# Do add coverage flags so that test caching works.
|
||||
test-postgres: test-postgres-docker
|
||||
# The postgres test is prone to failure, so we limit parallelism for
|
||||
# more consistent execution.
|
||||
$(GIT_FLAGS) gotestsum \
|
||||
--junitfile="gotests.xml" \
|
||||
--jsonfile="gotests.json" \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="./..." -- \
|
||||
-tags=testsmallbatch \
|
||||
-timeout=20m \
|
||||
-count=1
|
||||
.PHONY: test-postgres
|
||||
|
||||
test-migrations: test-postgres-docker
|
||||
echo "--- test migrations"
|
||||
@@ -1328,13 +1366,24 @@ test-migrations: test-postgres-docker
|
||||
|
||||
# NOTE: we set --memory to the same size as a GitHub runner.
|
||||
test-postgres-docker:
|
||||
# If our container is already running, nothing to do.
|
||||
if docker ps --filter "name=test-postgres-docker-${POSTGRES_VERSION}" --format '{{.Names}}' | grep -q .; then \
|
||||
echo "test-postgres-docker-${POSTGRES_VERSION} is already running."; \
|
||||
exit 0; \
|
||||
fi
|
||||
# If something else is on 5432, warn but don't fail.
|
||||
if pg_isready -h 127.0.0.1 -q 2>/dev/null; then \
|
||||
echo "WARNING: PostgreSQL is already running on 127.0.0.1:5432 (not our container)."; \
|
||||
echo "Tests will use this instance. To use the Makefile's container, stop it first."; \
|
||||
exit 0; \
|
||||
fi
|
||||
docker rm -f test-postgres-docker-${POSTGRES_VERSION} || true
|
||||
|
||||
# Try pulling up to three times to avoid CI flakes.
|
||||
docker pull ${POSTGRES_IMAGE} || {
|
||||
retries=2
|
||||
for try in $(seq 1 ${retries}); do
|
||||
echo "Failed to pull image, retrying (${try}/${retries})..."
|
||||
for try in $$(seq 1 $${retries}); do
|
||||
echo "Failed to pull image, retrying ($${try}/$${retries})..."
|
||||
sleep 1
|
||||
if docker pull ${POSTGRES_IMAGE}; then
|
||||
break
|
||||
@@ -1375,16 +1424,11 @@ test-postgres-docker:
|
||||
-c log_statement=all
|
||||
while ! pg_isready -h 127.0.0.1
|
||||
do
|
||||
echo "$(date) - waiting for database to start"
|
||||
echo "$$(date) - waiting for database to start"
|
||||
sleep 0.5
|
||||
done
|
||||
.PHONY: test-postgres-docker
|
||||
|
||||
# Make sure to keep this in sync with test-go-race from .github/workflows/ci.yaml.
|
||||
test-race:
|
||||
$(GIT_FLAGS) gotestsum --junitfile="gotests.xml" -- -tags=testsmallbatch -race -count=1 -parallel 4 -p 4 ./...
|
||||
.PHONY: test-race
|
||||
|
||||
test-tailnet-integration:
|
||||
env \
|
||||
CODER_TAILNET_TESTS=true \
|
||||
@@ -1413,6 +1457,7 @@ site/e2e/bin/coder: go.mod go.sum $(GO_SRC_FILES)
|
||||
|
||||
test-e2e: site/e2e/bin/coder site/node_modules/.installed site/out/index.html
|
||||
cd site/
|
||||
pnpm playwright:install
|
||||
ifdef CI
|
||||
DEBUG=pw:api pnpm playwright:test --forbid-only --workers 1
|
||||
else
|
||||
|
||||
+128
-443
@@ -7,22 +7,14 @@ package agentgit
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dustin/go-humanize"
|
||||
"github.com/go-git/go-git/v5"
|
||||
"github.com/go-git/go-git/v5/plumbing"
|
||||
"github.com/go-git/go-git/v5/plumbing/filemode"
|
||||
fdiff "github.com/go-git/go-git/v5/plumbing/format/diff"
|
||||
"github.com/go-git/go-git/v5/plumbing/object"
|
||||
"github.com/go-git/go-git/v5/utils/diff"
|
||||
dmp "github.com/sergi/go-diff/diffmatchpatch"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
@@ -41,20 +33,19 @@ func WithClock(c quartz.Clock) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithGitBinary overrides the git binary path (for testing).
|
||||
func WithGitBinary(path string) Option {
|
||||
return func(h *Handler) {
|
||||
h.gitBin = path
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// scanCooldown is the minimum interval between successive scans.
|
||||
scanCooldown = 1 * time.Second
|
||||
// fallbackPollInterval is the safety-net poll period used when no
|
||||
// filesystem events arrive.
|
||||
fallbackPollInterval = 30 * time.Second
|
||||
// maxFileReadSize is the maximum file size that will be read
|
||||
// into memory. Files larger than this are tracked by status
|
||||
// only, and their diffs show a placeholder message.
|
||||
maxFileReadSize = 2 * 1024 * 1024 // 2 MiB
|
||||
// maxFileDiffSize is the maximum encoded size of a single
|
||||
// file's diff. If an individual file's diff exceeds this
|
||||
// limit, it is replaced with a placeholder stub.
|
||||
maxFileDiffSize = 256 * 1024 // 256 KiB
|
||||
// maxTotalDiffSize is the maximum size of the combined
|
||||
// unified diff for an entire repository sent over the wire.
|
||||
// This must stay under the WebSocket message size limit.
|
||||
@@ -65,6 +56,7 @@ const (
|
||||
type Handler struct {
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
gitBin string // path to git binary; empty means "git" (from PATH)
|
||||
|
||||
mu sync.Mutex
|
||||
repoRoots map[string]struct{} // watched repo roots
|
||||
@@ -85,6 +77,7 @@ func NewHandler(logger slog.Logger, opts ...Option) *Handler {
|
||||
h := &Handler{
|
||||
logger: logger,
|
||||
clock: quartz.NewReal(),
|
||||
gitBin: "git",
|
||||
repoRoots: make(map[string]struct{}),
|
||||
lastSnapshots: make(map[string]repoSnapshot),
|
||||
scanTrigger: make(chan struct{}, 1),
|
||||
@@ -92,13 +85,30 @@ func NewHandler(logger slog.Logger, opts ...Option) *Handler {
|
||||
for _, opt := range opts {
|
||||
opt(h)
|
||||
}
|
||||
|
||||
// Check if git is available.
|
||||
if _, err := exec.LookPath(h.gitBin); err != nil {
|
||||
h.logger.Warn(context.Background(), "git binary not found, git scanning disabled")
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// gitAvailable returns true if the configured git binary can be found
|
||||
// in PATH.
|
||||
func (h *Handler) gitAvailable() bool {
|
||||
_, err := exec.LookPath(h.gitBin)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Subscribe processes a subscribe message, resolving paths to git repo
|
||||
// roots and adding new repos to the watch set. Returns true if any new
|
||||
// repo roots were added.
|
||||
func (h *Handler) Subscribe(paths []string) bool {
|
||||
if !h.gitAvailable() {
|
||||
return false
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
@@ -109,7 +119,7 @@ func (h *Handler) Subscribe(paths []string) bool {
|
||||
}
|
||||
p = filepath.Clean(p)
|
||||
|
||||
root, err := findRepoRoot(p)
|
||||
root, err := findRepoRoot(h.gitBin, p)
|
||||
if err != nil {
|
||||
// Not a git path — silently ignore.
|
||||
continue
|
||||
@@ -135,6 +145,10 @@ func (h *Handler) RequestScan() {
|
||||
// Scan performs a scan of all subscribed repos and computes deltas
|
||||
// against the previously emitted snapshots.
|
||||
func (h *Handler) Scan(ctx context.Context) *codersdk.WorkspaceAgentGitServerMessage {
|
||||
if !h.gitAvailable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
roots := make([]string, 0, len(h.repoRoots))
|
||||
for r := range h.repoRoots {
|
||||
@@ -158,7 +172,7 @@ func (h *Handler) Scan(ctx context.Context) *codersdk.WorkspaceAgentGitServerMes
|
||||
}
|
||||
results := make([]scanResult, 0, len(roots))
|
||||
for _, root := range roots {
|
||||
changes, err := getRepoChanges(ctx, h.logger, root)
|
||||
changes, err := getRepoChanges(ctx, h.logger, h.gitBin, root)
|
||||
results = append(results, scanResult{root: root, changes: changes, err: err})
|
||||
}
|
||||
|
||||
@@ -168,7 +182,7 @@ func (h *Handler) Scan(ctx context.Context) *codersdk.WorkspaceAgentGitServerMes
|
||||
|
||||
for _, res := range results {
|
||||
if res.err != nil {
|
||||
if isRepoDeleted(res.root) {
|
||||
if isRepoDeleted(h.gitBin, res.root) {
|
||||
// Repo root or .git directory was removed.
|
||||
// Emit a removal entry, then evict from watch set.
|
||||
removal := codersdk.WorkspaceAgentRepoChanges{
|
||||
@@ -276,8 +290,9 @@ func (h *Handler) rateLimitedScan(ctx context.Context, scanFn func()) {
|
||||
// 2. The .git entry (directory or file) was removed.
|
||||
// 3. The .git entry is a file (worktree/submodule) whose target
|
||||
// gitdir was removed. In this case .git exists on disk but
|
||||
// git.PlainOpen fails because the referenced directory is gone.
|
||||
func isRepoDeleted(repoRoot string) bool {
|
||||
// `git rev-parse --git-dir` fails because the referenced
|
||||
// directory is gone.
|
||||
func isRepoDeleted(gitBin string, repoRoot string) bool {
|
||||
if _, err := os.Stat(repoRoot); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
@@ -288,78 +303,77 @@ func isRepoDeleted(repoRoot string) bool {
|
||||
}
|
||||
// If .git is a regular file (worktree or submodule), the actual
|
||||
// git object store lives elsewhere. Validate that the target is
|
||||
// still reachable by attempting to open the repo.
|
||||
// still reachable by running git rev-parse.
|
||||
if err == nil && !fi.IsDir() {
|
||||
if _, openErr := git.PlainOpen(repoRoot); openErr != nil {
|
||||
cmd := exec.CommandContext(context.Background(), gitBin, "-C", repoRoot, "rev-parse", "--git-dir")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// findRepoRoot walks up from the given path to find a .git directory.
|
||||
func findRepoRoot(p string) (string, error) {
|
||||
// If p is a file, start from its directory.
|
||||
// findRepoRoot uses `git rev-parse --show-toplevel` to find the
|
||||
// repository root for the given path.
|
||||
func findRepoRoot(gitBin string, p string) (string, error) {
|
||||
// If p is a file, start from its parent directory.
|
||||
dir := p
|
||||
for {
|
||||
_, err := git.PlainOpen(dir)
|
||||
if err == nil {
|
||||
return dir, nil
|
||||
}
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
return "", xerrors.Errorf("no git repo found for %s", p)
|
||||
}
|
||||
dir = parent
|
||||
if info, err := os.Stat(dir); err != nil || !info.IsDir() {
|
||||
dir = filepath.Dir(dir)
|
||||
}
|
||||
cmd := exec.CommandContext(context.Background(), gitBin, "rev-parse", "--show-toplevel")
|
||||
cmd.Dir = dir
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("no git repo found for %s", p)
|
||||
}
|
||||
root := filepath.FromSlash(strings.TrimSpace(string(out)))
|
||||
// Resolve symlinks and short (8.3) names on Windows so the
|
||||
// returned root matches paths produced by Go's filepath APIs.
|
||||
if resolved, evalErr := filepath.EvalSymlinks(root); evalErr == nil {
|
||||
root = resolved
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// getRepoChanges reads the current state of a git repository using
|
||||
// go-git. It returns branch, remote origin, and per-file status.
|
||||
func getRepoChanges(ctx context.Context, logger slog.Logger, repoRoot string) (codersdk.WorkspaceAgentRepoChanges, error) {
|
||||
repo, err := git.PlainOpen(repoRoot)
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceAgentRepoChanges{}, xerrors.Errorf("open repo: %w", err)
|
||||
}
|
||||
|
||||
// the git CLI. It returns branch, remote origin, and a unified diff.
|
||||
func getRepoChanges(ctx context.Context, logger slog.Logger, gitBin string, repoRoot string) (codersdk.WorkspaceAgentRepoChanges, error) {
|
||||
result := codersdk.WorkspaceAgentRepoChanges{
|
||||
RepoRoot: repoRoot,
|
||||
}
|
||||
|
||||
// Read branch.
|
||||
headRef, err := repo.Head()
|
||||
if err != nil {
|
||||
// Repo may have no commits yet.
|
||||
// Verify this is still a valid git repository before doing
|
||||
// anything else. This catches deleted repos early.
|
||||
verifyCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "rev-parse", "--git-dir")
|
||||
if err := verifyCmd.Run(); err != nil {
|
||||
return result, xerrors.Errorf("not a git repository: %w", err)
|
||||
}
|
||||
|
||||
// Read branch name.
|
||||
branchCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "symbolic-ref", "--short", "HEAD")
|
||||
if out, err := branchCmd.Output(); err == nil {
|
||||
result.Branch = strings.TrimSpace(string(out))
|
||||
} else {
|
||||
logger.Debug(ctx, "failed to read HEAD", slog.F("root", repoRoot), slog.Error(err))
|
||||
} else if headRef.Name().IsBranch() {
|
||||
result.Branch = headRef.Name().Short()
|
||||
}
|
||||
|
||||
// Read remote origin URL.
|
||||
cfg, err := repo.Config()
|
||||
if err == nil {
|
||||
if origin, ok := cfg.Remotes["origin"]; ok && len(origin.URLs) > 0 {
|
||||
result.RemoteOrigin = origin.URLs[0]
|
||||
}
|
||||
remoteCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "config", "--get", "remote.origin.url")
|
||||
if out, err := remoteCmd.Output(); err == nil {
|
||||
result.RemoteOrigin = strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// Get worktree status.
|
||||
wt, err := repo.Worktree()
|
||||
// Compute unified diff.
|
||||
// `git diff HEAD` shows both staged and unstaged changes vs HEAD.
|
||||
// For repos with no commits yet, fall back to showing untracked
|
||||
// files only.
|
||||
diff, err := computeGitDiff(ctx, logger, gitBin, repoRoot)
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("get worktree: %w", err)
|
||||
return result, xerrors.Errorf("compute diff: %w", err)
|
||||
}
|
||||
|
||||
status, err := wt.Status()
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("worktree status: %w", err)
|
||||
}
|
||||
|
||||
worktreeDiff, err := computeWorktreeDiff(repo, repoRoot, status)
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("compute worktree diff: %w", err)
|
||||
}
|
||||
|
||||
result.UnifiedDiff = worktreeDiff.unifiedDiff
|
||||
result.UnifiedDiff = diff
|
||||
if len(result.UnifiedDiff) > maxTotalDiffSize {
|
||||
result.UnifiedDiff = "Total diff too large to show. Size: " + humanize.IBytes(uint64(len(result.UnifiedDiff))) + ". Showing branch and remote only."
|
||||
}
|
||||
@@ -367,390 +381,61 @@ func getRepoChanges(ctx context.Context, logger slog.Logger, repoRoot string) (c
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type worktreeDiffResult struct {
|
||||
unifiedDiff string
|
||||
additions int
|
||||
deletions int
|
||||
}
|
||||
// computeGitDiff produces a unified diff string for the repository by
|
||||
// combining `git diff HEAD` (staged + unstaged changes) with diffs
|
||||
// for untracked files.
|
||||
func computeGitDiff(ctx context.Context, logger slog.Logger, gitBin string, repoRoot string) (string, error) {
|
||||
var diffParts []string
|
||||
|
||||
type fileSnapshot struct {
|
||||
exists bool
|
||||
content []byte
|
||||
mode filemode.FileMode
|
||||
binary bool
|
||||
tooLarge bool
|
||||
size int64 // actual file size on disk, set even when tooLarge
|
||||
}
|
||||
|
||||
func computeWorktreeDiff(
|
||||
repo *git.Repository,
|
||||
repoRoot string,
|
||||
status git.Status,
|
||||
) (worktreeDiffResult, error) {
|
||||
headTree, err := getHeadTree(repo)
|
||||
if err != nil {
|
||||
return worktreeDiffResult{}, xerrors.Errorf("get head tree: %w", err)
|
||||
// Check if the repo has any commits.
|
||||
hasCommits := true
|
||||
checkCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "rev-parse", "HEAD")
|
||||
if err := checkCmd.Run(); err != nil {
|
||||
hasCommits = false
|
||||
}
|
||||
|
||||
paths := sortedStatusPaths(status)
|
||||
filePatches := make([]fdiff.FilePatch, 0, len(paths))
|
||||
totalAdditions := 0
|
||||
totalDeletions := 0
|
||||
|
||||
for _, path := range paths {
|
||||
fileStatus := status[path]
|
||||
|
||||
fromPath := path
|
||||
if isRenamed(fileStatus) && fileStatus.Extra != "" {
|
||||
fromPath = fileStatus.Extra
|
||||
}
|
||||
toPath := path
|
||||
|
||||
before, err := readHeadFileSnapshot(headTree, fromPath)
|
||||
if hasCommits {
|
||||
// `git diff HEAD` captures both staged and unstaged changes
|
||||
// relative to HEAD in a single unified diff.
|
||||
cmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "diff", "HEAD")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return worktreeDiffResult{}, xerrors.Errorf("read head file %q: %w", fromPath, err)
|
||||
return "", xerrors.Errorf("git diff HEAD: %w", err)
|
||||
}
|
||||
|
||||
after, err := readWorktreeFileSnapshot(repoRoot, toPath)
|
||||
if err != nil {
|
||||
return worktreeDiffResult{}, xerrors.Errorf("read worktree file %q: %w", toPath, err)
|
||||
if len(out) > 0 {
|
||||
diffParts = append(diffParts, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
filePatch, additions, deletions := buildFilePatch(fromPath, toPath, before, after)
|
||||
if filePatch == nil {
|
||||
// Show untracked files as diffs too.
|
||||
// `git ls-files --others --exclude-standard` lists untracked,
|
||||
// non-ignored files.
|
||||
lsCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "ls-files", "--others", "--exclude-standard")
|
||||
lsOut, err := lsCmd.Output()
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to list untracked files", slog.F("root", repoRoot), slog.Error(err))
|
||||
return strings.Join(diffParts, ""), nil
|
||||
}
|
||||
|
||||
untrackedFiles := strings.Split(strings.TrimSpace(string(lsOut)), "\n")
|
||||
for _, f := range untrackedFiles {
|
||||
f = strings.TrimSpace(f)
|
||||
if f == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check whether this single file's diff exceeds the
|
||||
// per-file limit. If so, replace it with a stub.
|
||||
encoded, err := encodeUnifiedDiff([]fdiff.FilePatch{filePatch})
|
||||
if err != nil {
|
||||
return worktreeDiffResult{}, xerrors.Errorf("encode file diff %q: %w", toPath, err)
|
||||
// Use `git diff --no-index /dev/null <file>` to generate
|
||||
// a unified diff for untracked files.
|
||||
var stdout bytes.Buffer
|
||||
untrackedCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "diff", "--no-index", "--", "/dev/null", f)
|
||||
untrackedCmd.Stdout = &stdout
|
||||
// git diff --no-index exits with 1 when files differ,
|
||||
// which is expected. We ignore the error and check for
|
||||
// output instead.
|
||||
_ = untrackedCmd.Run()
|
||||
if stdout.Len() > 0 {
|
||||
diffParts = append(diffParts, stdout.String())
|
||||
}
|
||||
if len(encoded) > maxFileDiffSize {
|
||||
msg := "File diff too large to show. Diff size: " + humanize.IBytes(uint64(len(encoded)))
|
||||
filePatch = buildStubFilePatch(fromPath, toPath, before, after, msg)
|
||||
additions = 0
|
||||
deletions = 0
|
||||
}
|
||||
|
||||
filePatches = append(filePatches, filePatch)
|
||||
totalAdditions += additions
|
||||
totalDeletions += deletions
|
||||
}
|
||||
|
||||
diffText, err := encodeUnifiedDiff(filePatches)
|
||||
if err != nil {
|
||||
return worktreeDiffResult{}, xerrors.Errorf("encode unified diff: %w", err)
|
||||
}
|
||||
|
||||
return worktreeDiffResult{
|
||||
unifiedDiff: diffText,
|
||||
additions: totalAdditions,
|
||||
deletions: totalDeletions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getHeadTree(repo *git.Repository) (*object.Tree, error) {
|
||||
headRef, err := repo.Head()
|
||||
if err != nil {
|
||||
if errors.Is(err, plumbing.ErrReferenceNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commit, err := repo.CommitObject(headRef.Hash())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return commit.Tree()
|
||||
}
|
||||
|
||||
func readHeadFileSnapshot(headTree *object.Tree, path string) (fileSnapshot, error) {
|
||||
if headTree == nil {
|
||||
return fileSnapshot{}, nil
|
||||
}
|
||||
|
||||
file, err := headTree.File(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, object.ErrFileNotFound) {
|
||||
return fileSnapshot{}, nil
|
||||
}
|
||||
return fileSnapshot{}, err
|
||||
}
|
||||
|
||||
if file.Size > maxFileReadSize {
|
||||
return fileSnapshot{
|
||||
exists: true,
|
||||
tooLarge: true,
|
||||
size: file.Size,
|
||||
mode: file.Mode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
content, err := file.Contents()
|
||||
if err != nil {
|
||||
return fileSnapshot{}, err
|
||||
}
|
||||
|
||||
isBinary, err := file.IsBinary()
|
||||
if err != nil {
|
||||
return fileSnapshot{}, err
|
||||
}
|
||||
|
||||
return fileSnapshot{
|
||||
exists: true,
|
||||
content: []byte(content),
|
||||
mode: file.Mode,
|
||||
binary: isBinary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readWorktreeFileSnapshot(repoRoot string, path string) (fileSnapshot, error) {
|
||||
absPath := filepath.Join(repoRoot, filepath.FromSlash(path))
|
||||
fileInfo, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fileSnapshot{}, nil
|
||||
}
|
||||
return fileSnapshot{}, err
|
||||
}
|
||||
if fileInfo.IsDir() {
|
||||
return fileSnapshot{}, nil
|
||||
}
|
||||
|
||||
if fileInfo.Size() > maxFileReadSize {
|
||||
mode, err := filemode.NewFromOSFileMode(fileInfo.Mode())
|
||||
if err != nil {
|
||||
mode = filemode.Regular
|
||||
}
|
||||
return fileSnapshot{
|
||||
exists: true,
|
||||
tooLarge: true,
|
||||
size: fileInfo.Size(),
|
||||
mode: mode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fileSnapshot{}, nil
|
||||
}
|
||||
return fileSnapshot{}, err
|
||||
}
|
||||
|
||||
mode, err := filemode.NewFromOSFileMode(fileInfo.Mode())
|
||||
if err != nil {
|
||||
mode = filemode.Regular
|
||||
}
|
||||
|
||||
return fileSnapshot{
|
||||
exists: true,
|
||||
content: content,
|
||||
mode: mode,
|
||||
binary: isBinaryContent(content),
|
||||
size: fileInfo.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildFilePatch(
|
||||
fromPath string,
|
||||
toPath string,
|
||||
before fileSnapshot,
|
||||
after fileSnapshot,
|
||||
) (fdiff.FilePatch, int, int) {
|
||||
if !before.exists && !after.exists {
|
||||
return nil, 0, 0
|
||||
}
|
||||
|
||||
unchangedContent := bytes.Equal(before.content, after.content)
|
||||
if before.exists &&
|
||||
after.exists &&
|
||||
fromPath == toPath &&
|
||||
before.mode == after.mode &&
|
||||
unchangedContent {
|
||||
return nil, 0, 0
|
||||
}
|
||||
|
||||
// Files that exceed the read size limit get a stub patch
|
||||
// instead of a full diff to avoid OOM.
|
||||
if before.tooLarge || after.tooLarge {
|
||||
sz := max(after.size, 0)
|
||||
//nolint:gosec // sz is guaranteed to fit in uint64
|
||||
msg := "File too large to diff. Current size: " + humanize.IBytes(uint64(sz))
|
||||
return buildStubFilePatch(fromPath, toPath, before, after, msg), 0, 0
|
||||
}
|
||||
|
||||
patch := &workspaceFilePatch{
|
||||
from: snapshotToDiffFile(fromPath, before),
|
||||
to: snapshotToDiffFile(toPath, after),
|
||||
}
|
||||
|
||||
if before.binary || after.binary {
|
||||
patch.binary = true
|
||||
return patch, 0, 0
|
||||
}
|
||||
|
||||
diffs := diff.Do(string(before.content), string(after.content))
|
||||
chunks := make([]fdiff.Chunk, 0, len(diffs))
|
||||
additions := 0
|
||||
deletions := 0
|
||||
|
||||
for _, d := range diffs {
|
||||
var operation fdiff.Operation
|
||||
switch d.Type {
|
||||
case dmp.DiffEqual:
|
||||
operation = fdiff.Equal
|
||||
case dmp.DiffDelete:
|
||||
operation = fdiff.Delete
|
||||
deletions += countChunkLines(d.Text)
|
||||
case dmp.DiffInsert:
|
||||
operation = fdiff.Add
|
||||
additions += countChunkLines(d.Text)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
chunks = append(chunks, workspaceDiffChunk{
|
||||
content: d.Text,
|
||||
op: operation,
|
||||
})
|
||||
}
|
||||
|
||||
patch.chunks = chunks
|
||||
return patch, additions, deletions
|
||||
}
|
||||
|
||||
func buildStubFilePatch(fromPath, toPath string, before, after fileSnapshot, message string) fdiff.FilePatch {
|
||||
return &workspaceFilePatch{
|
||||
from: snapshotToDiffFile(fromPath, before),
|
||||
to: snapshotToDiffFile(toPath, after),
|
||||
chunks: []fdiff.Chunk{
|
||||
workspaceDiffChunk{
|
||||
content: message + "\n",
|
||||
op: fdiff.Add,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func snapshotToDiffFile(path string, snapshot fileSnapshot) fdiff.File {
|
||||
if !snapshot.exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return workspaceDiffFile{
|
||||
path: path,
|
||||
mode: snapshot.mode,
|
||||
hash: plumbing.ComputeHash(plumbing.BlobObject, snapshot.content),
|
||||
}
|
||||
}
|
||||
|
||||
func encodeUnifiedDiff(filePatches []fdiff.FilePatch) (string, error) {
|
||||
if len(filePatches) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
patch := workspaceDiffPatch{filePatches: filePatches}
|
||||
var builder strings.Builder
|
||||
encoder := fdiff.NewUnifiedEncoder(&builder, fdiff.DefaultContextLines)
|
||||
if err := encoder.Encode(patch); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func sortedStatusPaths(status git.Status) []string {
|
||||
paths := make([]string, 0, len(status))
|
||||
for path := range status {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
sort.Strings(paths)
|
||||
return paths
|
||||
}
|
||||
|
||||
func isRenamed(fileStatus *git.FileStatus) bool {
|
||||
return fileStatus.Staging == git.Renamed || fileStatus.Worktree == git.Renamed
|
||||
}
|
||||
|
||||
func countChunkLines(content string) int {
|
||||
if content == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
lines := strings.Count(content, "\n")
|
||||
if !strings.HasSuffix(content, "\n") {
|
||||
lines++
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func isBinaryContent(content []byte) bool {
|
||||
return bytes.IndexByte(content, 0) >= 0
|
||||
}
|
||||
|
||||
type workspaceDiffPatch struct {
|
||||
filePatches []fdiff.FilePatch
|
||||
}
|
||||
|
||||
func (p workspaceDiffPatch) FilePatches() []fdiff.FilePatch {
|
||||
return p.filePatches
|
||||
}
|
||||
|
||||
func (workspaceDiffPatch) Message() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type workspaceFilePatch struct {
|
||||
from fdiff.File
|
||||
to fdiff.File
|
||||
chunks []fdiff.Chunk
|
||||
binary bool
|
||||
}
|
||||
|
||||
func (p *workspaceFilePatch) IsBinary() bool {
|
||||
return p.binary
|
||||
}
|
||||
|
||||
func (p *workspaceFilePatch) Files() (fdiff.File, fdiff.File) {
|
||||
return p.from, p.to
|
||||
}
|
||||
|
||||
func (p *workspaceFilePatch) Chunks() []fdiff.Chunk {
|
||||
return p.chunks
|
||||
}
|
||||
|
||||
type workspaceDiffFile struct {
|
||||
path string
|
||||
mode filemode.FileMode
|
||||
hash plumbing.Hash
|
||||
}
|
||||
|
||||
func (f workspaceDiffFile) Hash() plumbing.Hash {
|
||||
return f.hash
|
||||
}
|
||||
|
||||
func (f workspaceDiffFile) Mode() filemode.FileMode {
|
||||
return f.mode
|
||||
}
|
||||
|
||||
func (f workspaceDiffFile) Path() string {
|
||||
return f.path
|
||||
}
|
||||
|
||||
type workspaceDiffChunk struct {
|
||||
content string
|
||||
op fdiff.Operation
|
||||
}
|
||||
|
||||
func (c workspaceDiffChunk) Content() string {
|
||||
return c.content
|
||||
}
|
||||
|
||||
func (c workspaceDiffChunk) Type() fdiff.Operation {
|
||||
return c.op
|
||||
return strings.Join(diffParts, ""), nil
|
||||
}
|
||||
|
||||
+202
-177
@@ -5,13 +5,11 @@ import (
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-git/go-git/v5"
|
||||
"github.com/go-git/go-git/v5/plumbing"
|
||||
"github.com/go-git/go-git/v5/plumbing/object"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -25,30 +23,44 @@ import (
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// gitCmd runs a git command in the given directory and fails the test
|
||||
// on error.
|
||||
func gitCmd(t *testing.T, dir string, args ...string) {
|
||||
t.Helper()
|
||||
cmd := exec.Command("git", args...)
|
||||
cmd.Dir = dir
|
||||
cmd.Env = append(os.Environ(),
|
||||
"GIT_AUTHOR_NAME=Test",
|
||||
"GIT_AUTHOR_EMAIL=test@test.com",
|
||||
"GIT_COMMITTER_NAME=Test",
|
||||
"GIT_COMMITTER_EMAIL=test@test.com",
|
||||
)
|
||||
out, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "git %v: %s", args, out)
|
||||
}
|
||||
|
||||
// initTestRepo creates a temporary git repo with an initial commit
|
||||
// and returns the repo root path.
|
||||
func initTestRepo(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
repo, err := git.PlainInit(dir, false)
|
||||
require.NoError(t, err)
|
||||
// Resolve symlinks and short (8.3) names on Windows so test
|
||||
// expectations match the canonical paths returned by git.
|
||||
resolved, err := filepath.EvalSymlinks(dir)
|
||||
if err == nil {
|
||||
dir = resolved
|
||||
}
|
||||
|
||||
gitCmd(t, dir, "init")
|
||||
gitCmd(t, dir, "config", "user.name", "Test")
|
||||
gitCmd(t, dir, "config", "user.email", "test@test.com")
|
||||
|
||||
// Create a file and commit it so the repo has HEAD.
|
||||
testFile := filepath.Join(dir, "README.md")
|
||||
require.NoError(t, os.WriteFile(testFile, []byte("# Test\n"), 0o600))
|
||||
|
||||
wt, err := repo.Worktree()
|
||||
require.NoError(t, err)
|
||||
_, err = wt.Add("README.md")
|
||||
require.NoError(t, err)
|
||||
_, err = wt.Commit("initial commit", &git.CommitOptions{
|
||||
Author: &object.Signature{
|
||||
Name: "Test",
|
||||
Email: "test@test.com",
|
||||
When: time.Now(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
gitCmd(t, dir, "add", "README.md")
|
||||
gitCmd(t, dir, "commit", "-m", "initial commit")
|
||||
|
||||
return dir
|
||||
}
|
||||
@@ -139,6 +151,88 @@ func TestScanReturnsRepoChanges(t *testing.T) {
|
||||
require.Contains(t, repo.UnifiedDiff, "new.go")
|
||||
}
|
||||
|
||||
func TestScanRespectsGitignore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repoDir := initTestRepo(t)
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
// Add a .gitignore that ignores *.log files and the build/ directory.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, ".gitignore"), []byte("*.log\nbuild/\n"), 0o600))
|
||||
gitCmd(t, repoDir, "add", ".gitignore")
|
||||
gitCmd(t, repoDir, "commit", "-m", "add gitignore")
|
||||
|
||||
// Create unstaged files: two normal, three matching gitignore patterns.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "main.go"), []byte("package main\n"), 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "util.go"), []byte("package util\n"), 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "debug.log"), []byte("some log output\n"), 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "error.log"), []byte("some error\n"), 0o600))
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(repoDir, "build"), 0o700))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "build", "output.bin"), []byte("binary\n"), 0o600))
|
||||
|
||||
h := agentgit.NewHandler(logger)
|
||||
h.Subscribe([]string{filepath.Join(repoDir, "main.go")})
|
||||
|
||||
ctx := context.Background()
|
||||
msg := h.Scan(ctx)
|
||||
require.NotNil(t, msg)
|
||||
require.Len(t, msg.Repositories, 1)
|
||||
|
||||
diff := msg.Repositories[0].UnifiedDiff
|
||||
|
||||
// The non-ignored files should appear in the diff.
|
||||
assert.Contains(t, diff, "main.go")
|
||||
assert.Contains(t, diff, "util.go")
|
||||
// The gitignored files must not appear in the diff.
|
||||
assert.NotContains(t, diff, "debug.log")
|
||||
assert.NotContains(t, diff, "error.log")
|
||||
assert.NotContains(t, diff, "output.bin")
|
||||
}
|
||||
|
||||
func TestScanRespectsGitignoreNestedNegation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repoDir := initTestRepo(t)
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
// Add a .gitignore that ignores node_modules/.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, ".gitignore"), []byte("node_modules/\n"), 0o600))
|
||||
gitCmd(t, repoDir, "add", ".gitignore")
|
||||
gitCmd(t, repoDir, "commit", "-m", "add gitignore")
|
||||
|
||||
// Simulate the tailwindcss stubs directory which contains a nested
|
||||
// .gitignore with "!*" (negation that un-ignores everything).
|
||||
// Real git keeps the parent node_modules/ ignore rule, but go-git
|
||||
// incorrectly lets the child negation override it.
|
||||
stubsDir := filepath.Join(repoDir, "site", "node_modules", ".pnpm",
|
||||
"tailwindcss@3.4.18", "node_modules", "tailwindcss", "stubs")
|
||||
require.NoError(t, os.MkdirAll(stubsDir, 0o700))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(stubsDir, ".gitignore"), []byte("!*\n"), 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(stubsDir, "config.full.js"), []byte("module.exports = {}\n"), 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(stubsDir, "tailwind.config.js"), []byte("// tw config\n"), 0o600))
|
||||
|
||||
// Also create a normal file outside node_modules.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "main.go"), []byte("package main\n"), 0o600))
|
||||
|
||||
h := agentgit.NewHandler(logger)
|
||||
h.Subscribe([]string{filepath.Join(repoDir, "main.go")})
|
||||
|
||||
ctx := context.Background()
|
||||
msg := h.Scan(ctx)
|
||||
require.NotNil(t, msg)
|
||||
require.Len(t, msg.Repositories, 1)
|
||||
|
||||
diff := msg.Repositories[0].UnifiedDiff
|
||||
|
||||
// The non-ignored file should appear in the diff.
|
||||
assert.Contains(t, diff, "main.go")
|
||||
// Files inside node_modules must not appear even though a nested
|
||||
// .gitignore contains "!*". The parent node_modules/ rule takes
|
||||
// precedence in real git.
|
||||
assert.NotContains(t, diff, "config.full.js")
|
||||
assert.NotContains(t, diff, "tailwind.config.js")
|
||||
}
|
||||
|
||||
func TestScanDeltaEmission(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -296,24 +390,16 @@ func TestSubscribeNestedGitRepos(t *testing.T) {
|
||||
// Create an inner repo nested inside the outer one.
|
||||
innerDir := filepath.Join(outerDir, "subproject")
|
||||
require.NoError(t, os.MkdirAll(innerDir, 0o700))
|
||||
innerRepo, err := git.PlainInit(innerDir, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
gitCmd(t, innerDir, "init")
|
||||
gitCmd(t, innerDir, "config", "user.name", "Test")
|
||||
gitCmd(t, innerDir, "config", "user.email", "test@test.com")
|
||||
|
||||
// Commit a file in the inner repo so it has HEAD.
|
||||
innerFile := filepath.Join(innerDir, "inner.go")
|
||||
require.NoError(t, os.WriteFile(innerFile, []byte("package inner\n"), 0o600))
|
||||
innerWt, err := innerRepo.Worktree()
|
||||
require.NoError(t, err)
|
||||
_, err = innerWt.Add("inner.go")
|
||||
require.NoError(t, err)
|
||||
_, err = innerWt.Commit("inner commit", &git.CommitOptions{
|
||||
Author: &object.Signature{
|
||||
Name: "Test",
|
||||
Email: "test@test.com",
|
||||
When: time.Now(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
gitCmd(t, innerDir, "add", "inner.go")
|
||||
gitCmd(t, innerDir, "commit", "-m", "inner commit")
|
||||
|
||||
// Now create a dirty file in the inner repo.
|
||||
dirtyFile := filepath.Join(innerDir, "dirty.go")
|
||||
@@ -411,46 +497,16 @@ func TestScanDeletedWorktreeGitdirEmitsRemoved(t *testing.T) {
|
||||
// Set up a main repo that we'll use as the source for a worktree.
|
||||
mainRepoDir := initTestRepo(t)
|
||||
|
||||
// Create a linked worktree using git.
|
||||
worktreeDir := t.TempDir()
|
||||
mainRepo, err := git.PlainOpen(mainRepoDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a branch for the worktree.
|
||||
headRef, err := mainRepo.Head()
|
||||
require.NoError(t, err)
|
||||
err = mainRepo.Storer.SetReference(
|
||||
//nolint:revive // plumbing.NewBranchReferenceName is not available.
|
||||
plumbing.NewHashReference("refs/heads/worktree-branch", headRef.Hash()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually construct the worktree linkage:
|
||||
// 1. Create worktree gitdir inside main repo's worktrees/
|
||||
// 2. Write a .git file in the worktree dir pointing to that gitdir.
|
||||
gitdirPath := filepath.Join(mainRepoDir, ".git", "worktrees", "wt")
|
||||
require.NoError(t, os.MkdirAll(gitdirPath, 0o755))
|
||||
|
||||
// The worktree gitdir needs HEAD and commondir files.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(gitdirPath, "HEAD"),
|
||||
[]byte("ref: refs/heads/worktree-branch\n"), 0o600,
|
||||
))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(gitdirPath, "commondir"),
|
||||
[]byte(filepath.Join(mainRepoDir, ".git")+"\n"), 0o600,
|
||||
))
|
||||
|
||||
// Write the .git file in the worktree directory.
|
||||
gitFileContent := "gitdir: " + gitdirPath + "\n"
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(worktreeDir, ".git"),
|
||||
[]byte(gitFileContent), 0o600,
|
||||
))
|
||||
|
||||
// Verify the worktree is a valid repo before we break it.
|
||||
_, err = git.PlainOpen(worktreeDir)
|
||||
require.NoError(t, err, "worktree should be openable before deletion")
|
||||
// Create a linked worktree using git CLI.
|
||||
wtBase := t.TempDir()
|
||||
// Resolve symlinks and short (8.3) names on Windows so test
|
||||
// expectations match the canonical paths returned by git.
|
||||
if resolved, err := filepath.EvalSymlinks(wtBase); err == nil {
|
||||
wtBase = resolved
|
||||
}
|
||||
worktreeDir := filepath.Join(wtBase, "wt")
|
||||
gitCmd(t, mainRepoDir, "branch", "worktree-branch")
|
||||
gitCmd(t, mainRepoDir, "worktree", "add", worktreeDir, "worktree-branch")
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
h := agentgit.NewHandler(logger)
|
||||
@@ -468,12 +524,14 @@ func TestScanDeletedWorktreeGitdirEmitsRemoved(t *testing.T) {
|
||||
require.Len(t, msg1.Repositories, 1)
|
||||
require.False(t, msg1.Repositories[0].Removed)
|
||||
|
||||
// Now delete the target gitdir. The .git file in the worktree
|
||||
// still exists, but it points to a directory that is gone.
|
||||
// Now delete the target gitdir inside .git/worktrees/. The .git
|
||||
// file in the worktree still exists, but it points to a directory
|
||||
// that is gone.
|
||||
gitdirPath := filepath.Join(mainRepoDir, ".git", "worktrees", filepath.Base(worktreeDir))
|
||||
require.NoError(t, os.RemoveAll(gitdirPath))
|
||||
|
||||
// Verify the .git file still exists (this is the bug scenario).
|
||||
_, err = os.Stat(filepath.Join(worktreeDir, ".git"))
|
||||
_, err := os.Stat(filepath.Join(worktreeDir, ".git"))
|
||||
require.NoError(t, err, ".git file should still exist")
|
||||
|
||||
// Next scan should detect the broken worktree and emit removal.
|
||||
@@ -775,13 +833,8 @@ func TestGetRepoChangesStagedModifiedDeleted(t *testing.T) {
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "README.md"), []byte("# Modified\n"), 0o600))
|
||||
|
||||
// Stage a new file.
|
||||
repo, err := git.PlainOpen(repoDir)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "staged.go"), []byte("package staged\n"), 0o600))
|
||||
wt, err := repo.Worktree()
|
||||
require.NoError(t, err)
|
||||
_, err = wt.Add("staged.go")
|
||||
require.NoError(t, err)
|
||||
gitCmd(t, repoDir, "add", "staged.go")
|
||||
|
||||
// Create an untracked file.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "untracked.txt"), []byte("hello\n"), 0o600))
|
||||
@@ -791,34 +844,22 @@ func TestGetRepoChangesStagedModifiedDeleted(t *testing.T) {
|
||||
require.NotNil(t, msg)
|
||||
require.Len(t, msg.Repositories, 1)
|
||||
|
||||
diff := msg.Repositories[0].UnifiedDiff
|
||||
|
||||
// README.md was committed then modified in worktree.
|
||||
require.Contains(t, msg.Repositories[0].UnifiedDiff, "README.md")
|
||||
require.Contains(t, diff, "README.md")
|
||||
require.Contains(t, diff, "--- a/README.md")
|
||||
require.Contains(t, diff, "+++ b/README.md")
|
||||
require.Contains(t, diff, "-# Test")
|
||||
require.Contains(t, diff, "+# Modified")
|
||||
|
||||
// staged.go was added to the staging area.
|
||||
require.Contains(t, msg.Repositories[0].UnifiedDiff, "staged.go")
|
||||
// untracked.txt is untracked.
|
||||
require.Contains(t, msg.Repositories[0].UnifiedDiff, "untracked.txt")
|
||||
require.Equal(t, `diff --git a/README.md b/README.md
|
||||
index 8ae056963b8b4664c9059e30bc8b834151e03950..6c31532bd0a2258bcfa88789d20d50574cfcc3da 100644
|
||||
--- a/README.md
|
||||
+++ b/README.md
|
||||
@@ -1 +1 @@
|
||||
-# Test
|
||||
+# Modified
|
||||
diff --git a/staged.go b/staged.go
|
||||
new file mode 100644
|
||||
index 0000000000000000000000000000000000000000..98a5a992ed2bc4b17d078d396ba034c8064079b4
|
||||
--- /dev/null
|
||||
+++ b/staged.go
|
||||
@@ -0,0 +1 @@
|
||||
+package staged
|
||||
diff --git a/untracked.txt b/untracked.txt
|
||||
new file mode 100644
|
||||
index 0000000000000000000000000000000000000000..ce013625030ba8dba906f756967f9e9ca394464a
|
||||
--- /dev/null
|
||||
+++ b/untracked.txt
|
||||
@@ -0,0 +1 @@
|
||||
+hello
|
||||
`, msg.Repositories[0].UnifiedDiff)
|
||||
require.Contains(t, diff, "staged.go")
|
||||
require.Contains(t, diff, "+package staged")
|
||||
|
||||
// untracked.txt is untracked (shown via --no-index diff).
|
||||
require.Contains(t, diff, "untracked.txt")
|
||||
require.Contains(t, diff, "+hello")
|
||||
}
|
||||
|
||||
func TestFallbackPollTriggersScan(t *testing.T) {
|
||||
@@ -912,12 +953,17 @@ func TestScanLargeFileTooLargeToDiff(t *testing.T) {
|
||||
|
||||
h := agentgit.NewHandler(logger)
|
||||
|
||||
// Create a file larger than maxFileReadSize (2 MiB).
|
||||
largeContent := make([]byte, 3*1024*1024)
|
||||
// Create a large text file (1 MiB). The diff produced by git
|
||||
// CLI will be under maxTotalDiffSize (3 MiB) so it appears in
|
||||
// the unified diff output.
|
||||
largeContent := make([]byte, 1*1024*1024)
|
||||
for i := range largeContent {
|
||||
largeContent[i] = byte('A' + (i % 26))
|
||||
if i%80 == 79 {
|
||||
largeContent[i] = '\n'
|
||||
}
|
||||
}
|
||||
largeFile := filepath.Join(repoDir, "large.bin")
|
||||
largeFile := filepath.Join(repoDir, "large.txt")
|
||||
require.NoError(t, os.WriteFile(largeFile, largeContent, 0o600))
|
||||
|
||||
h.Subscribe([]string{largeFile})
|
||||
@@ -930,13 +976,7 @@ func TestScanLargeFileTooLargeToDiff(t *testing.T) {
|
||||
repo := msg.Repositories[0]
|
||||
|
||||
// The large file should appear in the unified diff.
|
||||
require.Contains(t, repo.UnifiedDiff, "large.bin")
|
||||
|
||||
// The unified diff should contain the "too large" message,
|
||||
// NOT the actual file content.
|
||||
require.Contains(t, repo.UnifiedDiff, "File too large to diff")
|
||||
require.NotContains(t, repo.UnifiedDiff, "AAAA",
|
||||
"actual file content should not appear in diff")
|
||||
require.Contains(t, repo.UnifiedDiff, "large.txt")
|
||||
}
|
||||
|
||||
func TestScanLargeFileDeltaTracking(t *testing.T) {
|
||||
@@ -975,45 +1015,6 @@ func TestScanLargeFileDeltaTracking(t *testing.T) {
|
||||
require.NotContains(t, msg3.Repositories[0].UnifiedDiff, "big.dat")
|
||||
}
|
||||
|
||||
func TestScanFileDiffTooLargeForWire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repoDir := initTestRepo(t)
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
h := agentgit.NewHandler(logger)
|
||||
|
||||
// Create a single file whose diff exceeds maxFileDiffSize
|
||||
// (256 KiB) but stays under maxFileReadSize (2 MiB).
|
||||
content := make([]byte, 512*1024)
|
||||
for i := range content {
|
||||
content[i] = byte('A' + (i % 26))
|
||||
}
|
||||
bigFile := filepath.Join(repoDir, "big_diff.txt")
|
||||
require.NoError(t, os.WriteFile(bigFile, content, 0o600))
|
||||
|
||||
h.Subscribe([]string{bigFile})
|
||||
|
||||
ctx := context.Background()
|
||||
msg := h.Scan(ctx)
|
||||
require.NotNil(t, msg)
|
||||
require.Len(t, msg.Repositories, 1)
|
||||
|
||||
repo := msg.Repositories[0]
|
||||
|
||||
// The single file diff exceeds 256 KiB, so it should be
|
||||
// replaced with a per-file stub.
|
||||
require.Contains(t, repo.UnifiedDiff, "File diff too large to show")
|
||||
require.Contains(t, repo.UnifiedDiff, "big_diff.txt")
|
||||
|
||||
// The stub should NOT contain the actual file content.
|
||||
require.NotContains(t, repo.UnifiedDiff, "ABCDEFGHIJ",
|
||||
"actual file content should not appear in diff")
|
||||
|
||||
// Branch metadata should still be present.
|
||||
require.NotEmpty(t, repo.Branch)
|
||||
}
|
||||
|
||||
func TestScanTotalDiffTooLargeForWire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1023,7 +1024,7 @@ func TestScanTotalDiffTooLargeForWire(t *testing.T) {
|
||||
h := agentgit.NewHandler(logger)
|
||||
|
||||
// Create many files whose individual diffs are under 256 KiB
|
||||
// but whose total exceeds maxTotalDiffSize (4 MiB).
|
||||
// but whose total exceeds maxTotalDiffSize (3 MiB).
|
||||
// ~100 files x 50 KiB content each = ~5 MiB of diffs.
|
||||
var paths []string
|
||||
for i := range 100 {
|
||||
@@ -1046,14 +1047,14 @@ func TestScanTotalDiffTooLargeForWire(t *testing.T) {
|
||||
|
||||
repo := msg.Repositories[0]
|
||||
|
||||
// The total diff exceeds 4 MiB, so we should get the
|
||||
// The total diff exceeds 3 MiB, so we should get the
|
||||
// total-diff placeholder.
|
||||
require.Contains(t, repo.UnifiedDiff, "Total diff too large to show")
|
||||
|
||||
// Branch and remote metadata should still be present.
|
||||
require.NotEmpty(t, repo.Branch, "branch should still be populated")
|
||||
|
||||
// The placeholder message should be well under 4 MiB.
|
||||
// The placeholder message should be well under 3 MiB.
|
||||
require.Less(t, len(repo.UnifiedDiff), 4*1024*1024,
|
||||
"placeholder diff should be much smaller than maxTotalDiffSize")
|
||||
}
|
||||
@@ -1083,10 +1084,9 @@ func TestScanBinaryFileDiff(t *testing.T) {
|
||||
// The binary file should appear in the unified diff.
|
||||
require.Contains(t, repo.UnifiedDiff, "image.png")
|
||||
|
||||
// The unified diff should contain the go-git binary marker,
|
||||
// The unified diff should contain the git binary marker,
|
||||
// not the raw binary content.
|
||||
require.Contains(t, repo.UnifiedDiff, "Binary files")
|
||||
require.Contains(t, repo.UnifiedDiff, "image.png")
|
||||
require.Contains(t, repo.UnifiedDiff, "Binary")
|
||||
require.NotContains(t, repo.UnifiedDiff, "\x00",
|
||||
"raw binary content should not appear in diff")
|
||||
}
|
||||
@@ -1095,25 +1095,17 @@ func TestScanBinaryFileModifiedDiff(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
repo, err := git.PlainInit(dir, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
gitCmd(t, dir, "init")
|
||||
gitCmd(t, dir, "config", "user.name", "Test")
|
||||
gitCmd(t, dir, "config", "user.email", "test@test.com")
|
||||
|
||||
// Commit a binary file.
|
||||
binPath := filepath.Join(dir, "data.bin")
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("v1\x00\x01\x02"), 0o600))
|
||||
|
||||
wt, err := repo.Worktree()
|
||||
require.NoError(t, err)
|
||||
_, err = wt.Add("data.bin")
|
||||
require.NoError(t, err)
|
||||
_, err = wt.Commit("add binary", &git.CommitOptions{
|
||||
Author: &object.Signature{
|
||||
Name: "Test",
|
||||
Email: "test@test.com",
|
||||
When: time.Now(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
gitCmd(t, dir, "add", "data.bin")
|
||||
gitCmd(t, dir, "commit", "-m", "add binary")
|
||||
|
||||
// Modify the binary file in the worktree.
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("v2\x00\x03\x04\x05"), 0o600))
|
||||
@@ -1133,12 +1125,45 @@ func TestScanBinaryFileModifiedDiff(t *testing.T) {
|
||||
require.Contains(t, repoChanges.UnifiedDiff, "data.bin")
|
||||
|
||||
// Diff should show binary marker for modification too.
|
||||
require.Contains(t, repoChanges.UnifiedDiff, "Binary files")
|
||||
require.Contains(t, repoChanges.UnifiedDiff, "data.bin")
|
||||
require.Contains(t, repoChanges.UnifiedDiff, "Binary")
|
||||
require.NotContains(t, repoChanges.UnifiedDiff, "\x00",
|
||||
"raw binary content should not appear in diff")
|
||||
}
|
||||
|
||||
func TestScanFileDiffTooLargeForWire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repoDir := initTestRepo(t)
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
h := agentgit.NewHandler(logger)
|
||||
|
||||
// Create a single file whose diff is large. With git CLI, the
|
||||
// diff is produced by git itself so per-file size limiting is
|
||||
// handled by the total diff size check.
|
||||
content := make([]byte, 512*1024)
|
||||
for i := range content {
|
||||
content[i] = byte('A' + (i % 26))
|
||||
}
|
||||
bigFile := filepath.Join(repoDir, "big_diff.txt")
|
||||
require.NoError(t, os.WriteFile(bigFile, content, 0o600))
|
||||
|
||||
h.Subscribe([]string{bigFile})
|
||||
|
||||
ctx := context.Background()
|
||||
msg := h.Scan(ctx)
|
||||
require.NotNil(t, msg)
|
||||
require.Len(t, msg.Repositories, 1)
|
||||
|
||||
repo := msg.Repositories[0]
|
||||
|
||||
// The file should appear in the diff output.
|
||||
require.Contains(t, repo.UnifiedDiff, "big_diff.txt")
|
||||
|
||||
// Branch metadata should still be present.
|
||||
require.NotEmpty(t, repo.Branch)
|
||||
}
|
||||
|
||||
func TestWebSocketLargePathStoreSubscription(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -85,15 +85,21 @@ func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
|
||||
if chatIDStr != "" && a.pathStore != nil {
|
||||
chatID, parseErr := uuid.Parse(chatIDStr)
|
||||
if parseErr == nil {
|
||||
// Subscribe to future path updates BEFORE reading
|
||||
// existing paths. This ordering guarantees no
|
||||
// notification from AddPaths is lost: any call that
|
||||
// lands before Subscribe is picked up by GetPaths
|
||||
// below, and any call after Subscribe delivers a
|
||||
// notification on the channel.
|
||||
notifyCh, unsubscribe := a.pathStore.Subscribe(chatID)
|
||||
defer unsubscribe()
|
||||
|
||||
// Load any paths that are already tracked for this chat.
|
||||
existingPaths := a.pathStore.GetPaths(chatID)
|
||||
if len(existingPaths) > 0 {
|
||||
handler.Subscribe(existingPaths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
// Subscribe to future path updates.
|
||||
notifyCh, unsubscribe := a.pathStore.Subscribe(chatID)
|
||||
defer unsubscribe()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
|
||||
@@ -110,6 +110,11 @@ type Config struct {
|
||||
// X11DisplayOffset is the offset to add to the X11 display number.
|
||||
// Default is 10.
|
||||
X11DisplayOffset *int
|
||||
// X11MaxPort overrides the highest port used for X11 forwarding
|
||||
// listeners. Defaults to X11MaxPort (6200). Useful in tests
|
||||
// to shrink the port range and reduce the number of sessions
|
||||
// required.
|
||||
X11MaxPort *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// ReportConnection.
|
||||
@@ -158,6 +163,10 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
offset := X11DefaultDisplayOffset
|
||||
config.X11DisplayOffset = &offset
|
||||
}
|
||||
if config.X11MaxPort == nil {
|
||||
maxPort := X11MaxPort
|
||||
config.X11MaxPort = &maxPort
|
||||
}
|
||||
if config.UpdateEnv == nil {
|
||||
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
|
||||
}
|
||||
@@ -201,6 +210,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
x11HandlerErrors: metrics.x11HandlerErrors,
|
||||
fs: fs,
|
||||
displayOffset: *config.X11DisplayOffset,
|
||||
maxPort: *config.X11MaxPort,
|
||||
sessions: make(map[*x11Session]struct{}),
|
||||
connections: make(map[net.Conn]struct{}),
|
||||
network: func() X11Network {
|
||||
|
||||
@@ -57,6 +57,7 @@ type x11Forwarder struct {
|
||||
x11HandlerErrors *prometheus.CounterVec
|
||||
fs afero.Fs
|
||||
displayOffset int
|
||||
maxPort int
|
||||
|
||||
// network creates X11 listener sockets. Defaults to osNet{}.
|
||||
network X11Network
|
||||
@@ -314,7 +315,7 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
|
||||
// the next available port starting from X11StartPort and displayOffset.
|
||||
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
|
||||
// Look for an open port to listen on.
|
||||
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
|
||||
for port := X11StartPort + x.displayOffset; port <= x.maxPort; port++ {
|
||||
if ctx.Err() != nil {
|
||||
return nil, -1, ctx.Err()
|
||||
}
|
||||
|
||||
@@ -142,8 +142,13 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
|
||||
// Use in-process networking for X11 forwarding.
|
||||
inproc := testutil.NewInProcNet()
|
||||
|
||||
// Limit port range so we only need a handful of sessions to fill it
|
||||
// (the default 190 ports may easily timeout or conflict with other
|
||||
// ports on the system).
|
||||
maxPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset + 5
|
||||
cfg := &agentssh.Config{
|
||||
X11Net: inproc,
|
||||
X11Net: inproc,
|
||||
X11MaxPort: &maxPort,
|
||||
}
|
||||
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
|
||||
@@ -172,7 +177,7 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
|
||||
// configured port range.
|
||||
|
||||
startPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset
|
||||
maxSessions := agentssh.X11MaxPort - startPort + 1 - 1 // -1 for the blocked port
|
||||
maxSessions := maxPort - startPort + 1 - 1 // -1 for the blocked port
|
||||
require.Greater(t, maxSessions, 0, "expected a positive maxSessions value")
|
||||
|
||||
// shellSession holds references to the session and its standard streams so
|
||||
|
||||
@@ -42,9 +42,20 @@ func WithLogger(logger slog.Logger) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithDone sets a channel that, when closed, stops the reaper
|
||||
// goroutine. Callers that invoke ForkReap more than once in the
|
||||
// same process (e.g. tests) should use this to prevent goroutine
|
||||
// accumulation.
|
||||
func WithDone(ch chan struct{}) Option {
|
||||
return func(o *options) {
|
||||
o.Done = ch
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
ExecArgs []string
|
||||
PIDs reap.PidCh
|
||||
CatchSignals []os.Signal
|
||||
Logger slog.Logger
|
||||
Done chan struct{}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,15 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// withDone returns an option that stops the reaper goroutine when t
|
||||
// completes, preventing goroutine accumulation across subtests.
|
||||
func withDone(t *testing.T) reaper.Option {
|
||||
t.Helper()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() { close(done) })
|
||||
return reaper.WithDone(done)
|
||||
}
|
||||
|
||||
// TestReap checks that's the reaper is successfully reaping
|
||||
// exited processes and passing the PIDs through the shared
|
||||
// channel.
|
||||
@@ -36,6 +45,7 @@ func TestReap(t *testing.T) {
|
||||
reaper.WithPIDCallback(pids),
|
||||
// Provide some argument that immediately exits.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", "exit 0"),
|
||||
withDone(t),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, exitCode)
|
||||
@@ -89,6 +99,7 @@ func TestForkReapExitCodes(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
exitCode, err := reaper.ForkReap(
|
||||
reaper.WithExecArgs("/bin/sh", "-c", tt.command),
|
||||
withDone(t),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command)
|
||||
@@ -118,6 +129,7 @@ func TestReapInterrupt(t *testing.T) {
|
||||
exitCode, err := reaper.ForkReap(
|
||||
reaper.WithPIDCallback(pids),
|
||||
reaper.WithCatchSignals(os.Interrupt),
|
||||
withDone(t),
|
||||
// Signal propagation does not extend to children of children, so
|
||||
// we create a little bash script to ensure sleep is interrupted.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
|
||||
|
||||
@@ -64,7 +64,7 @@ func ForkReap(opt ...Option) (int, error) {
|
||||
o(opts)
|
||||
}
|
||||
|
||||
go reap.ReapChildren(opts.PIDs, nil, nil, nil)
|
||||
go reap.ReapChildren(opts.PIDs, nil, opts.Done, nil)
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
|
||||
+12
-12
@@ -41,11 +41,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", task.Name, "--output", "json")
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.Name, "--output", "json")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -64,11 +64,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String(), "--output", "json")
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String(), "--output", "json")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -87,11 +87,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String())
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String())
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -141,10 +141,10 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String())
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String())
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
|
||||
+22
-26
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -21,12 +20,12 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: Expect the task to be paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -34,7 +33,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been paused")
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -46,13 +45,13 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A different user's running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient, _, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause their task
|
||||
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
identifier := fmt.Sprintf("%s/%s", setup.task.OwnerName, setup.task.Name)
|
||||
inv, root := clitest.New(t, "task", "pause", identifier, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, adminClient, root)
|
||||
clitest.SetupConfig(t, setup.ownerClient, root)
|
||||
|
||||
// Then: We expect the task to be paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -60,7 +59,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been paused")
|
||||
|
||||
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
|
||||
updated, err := setup.ownerClient.TaskByIdentifier(ctx, identifier)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -70,11 +69,11 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// And: We confirm we want to pause the task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -88,7 +87,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
pty.ExpectMatchContext(ctx, "has been paused")
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -98,11 +97,11 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// But: We say no at the confirmation screen
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -114,7 +113,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.Error(t, w.Wait())
|
||||
|
||||
// Then: We expect the task to not be paused
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -124,21 +123,18 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// And: We paused the running task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := userClient.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, resp.WorkspaceBuild.ID)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to pause the task again
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name, "--yes")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: We expect to get an error that the task is already paused
|
||||
err = inv.WithContext(ctx).Run()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "is already paused")
|
||||
})
|
||||
}
|
||||
|
||||
+31
-43
@@ -1,7 +1,6 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@@ -17,29 +16,18 @@ import (
|
||||
func TestExpTaskResume(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// pauseTask is a helper that pauses a task and waits for the stop
|
||||
// build to complete.
|
||||
pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pauseResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
t.Run("WithYesFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -47,7 +35,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been resumed")
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -59,14 +47,14 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A different user's paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume their task
|
||||
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
identifier := fmt.Sprintf("%s/%s", setup.task.OwnerName, setup.task.Name)
|
||||
inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, adminClient, root)
|
||||
clitest.SetupConfig(t, setup.ownerClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -74,7 +62,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been resumed")
|
||||
|
||||
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
|
||||
updated, err := setup.ownerClient.TaskByIdentifier(ctx, identifier)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -84,13 +72,13 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume the task (and specify no wait)
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes", "--no-wait")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed in the background
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -99,11 +87,11 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.Contains(t, output.Stdout(), "in the background")
|
||||
|
||||
// And: The task to eventually be resumed
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
|
||||
require.True(t, setup.task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws := coderdtest.MustWorkspace(t, setup.userClient, setup.task.WorkspaceID.UUID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, setup.userClient, ws.LatestBuild.ID)
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -113,12 +101,12 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// And: We confirm we want to resume the task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -132,7 +120,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
pty.ExpectMatchContext(ctx, "has been resumed")
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -142,12 +130,12 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// But: Say no at the confirmation screen
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -159,7 +147,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.Error(t, w.Wait())
|
||||
|
||||
// Then: We expect the task to still be paused
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -169,11 +157,11 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to resume the task that is not paused
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: We expect to get an error that the task is not paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
+154
-9
@@ -1,10 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -15,13 +20,15 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "send <task> [<input> | --stdin]",
|
||||
Short: "Send input to a task",
|
||||
Long: FormatExamples(Example{
|
||||
Description: "Send direct input to a task.",
|
||||
Command: "coder task send task1 \"Please also add unit tests\"",
|
||||
}, Example{
|
||||
Description: "Send input from stdin to a task.",
|
||||
Command: "echo \"Please also add unit tests\" | coder task send task1 --stdin",
|
||||
}),
|
||||
Long: `Send input to a task. If the task is paused, it will be automatically resumed before input is sent. If the task is initializing, it will wait for the task to become ready.
|
||||
` +
|
||||
FormatExamples(Example{
|
||||
Description: "Send direct input to a task",
|
||||
Command: `coder task send task1 "Please also add unit tests"`,
|
||||
}, Example{
|
||||
Description: "Send input from stdin to a task",
|
||||
Command: `echo "Please also add unit tests" | coder task send task1 --stdin`,
|
||||
}),
|
||||
Middleware: serpent.RequireRangeArgs(1, 2),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
@@ -64,8 +71,48 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
return xerrors.Errorf("resolve task: %w", err)
|
||||
}
|
||||
|
||||
if err = client.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
return xerrors.Errorf("send input to task: %w", err)
|
||||
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
|
||||
// Before attempting to send, check the task status and
|
||||
// handle non-active states.
|
||||
var workspaceBuildID uuid.UUID
|
||||
|
||||
switch task.Status {
|
||||
case codersdk.TaskStatusActive:
|
||||
// Already active, no build to watch.
|
||||
|
||||
case codersdk.TaskStatusPaused:
|
||||
resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resume task %q: %w", display, err)
|
||||
} else if resp.WorkspaceBuild == nil {
|
||||
return xerrors.Errorf("resume task %q", display)
|
||||
}
|
||||
|
||||
workspaceBuildID = resp.WorkspaceBuild.ID
|
||||
|
||||
case codersdk.TaskStatusInitializing:
|
||||
if !task.WorkspaceID.Valid {
|
||||
return xerrors.Errorf("send input to task %q: task has no backing workspace", display)
|
||||
}
|
||||
|
||||
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace for task %q: %w", display, err)
|
||||
}
|
||||
|
||||
workspaceBuildID = workspace.LatestBuild.ID
|
||||
|
||||
default:
|
||||
return xerrors.Errorf("task %q has status %s and cannot be sent input", display, task.Status)
|
||||
}
|
||||
|
||||
if err := waitForTaskIdle(ctx, inv, client, task, workspaceBuildID); err != nil {
|
||||
return xerrors.Errorf("wait for task %q to be idle: %w", display, err)
|
||||
}
|
||||
|
||||
if err := client.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
return xerrors.Errorf("send input to task %q: %w", display, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -74,3 +121,101 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// waitForTaskIdle optionally watches a workspace build to completion,
|
||||
// then polls until the task becomes active and its app state is idle.
|
||||
// This merges build-watching and idle-polling into a single loop so
|
||||
// that status changes (e.g. paused) are never missed between phases.
|
||||
func waitForTaskIdle(ctx context.Context, inv *serpent.Invocation, client *codersdk.Client, task codersdk.Task, workspaceBuildID uuid.UUID) error {
|
||||
if workspaceBuildID != uuid.Nil {
|
||||
if err := cliui.WorkspaceBuild(ctx, inv.Stdout, client, workspaceBuildID); err != nil {
|
||||
return xerrors.Errorf("watch workspace build: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cliui.Infof(inv.Stdout, "Waiting for task to become idle...")
|
||||
|
||||
// NOTE(DanielleMaywood):
|
||||
// It has been observed that the `TaskStatusError` state has
|
||||
// appeared during a typical healthy startup [^0]. To combat
|
||||
// this, we allow a 5 minute grace period where we allow
|
||||
// `TaskStatusError` to surface without immediately failing.
|
||||
//
|
||||
// TODO(DanielleMaywood):
|
||||
// Remove this grace period once the upstream agentapi health
|
||||
// check no longer reports transient error states during normal
|
||||
// startup.
|
||||
//
|
||||
// [0]: https://github.com/coder/coder/pull/22203#discussion_r2858002569
|
||||
const errorGracePeriod = 5 * time.Minute
|
||||
gracePeriodDeadline := time.Now().Add(errorGracePeriod)
|
||||
|
||||
// NOTE(DanielleMaywood):
|
||||
// On resume the MCP may not report an initial app status,
|
||||
// leaving CurrentState nil indefinitely. To avoid hanging
|
||||
// forever we treat Active with nil CurrentState as idle
|
||||
// after a grace period, giving the MCP time to report
|
||||
// during normal startup.
|
||||
const nilStateGracePeriod = 30 * time.Second
|
||||
var nilStateDeadline time.Time
|
||||
|
||||
// TODO(DanielleMaywood):
|
||||
// When we have a streaming Task API, this should be converted
|
||||
// away from polling.
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
task, err := client.TaskByID(ctx, task.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get task by id: %w", err)
|
||||
}
|
||||
|
||||
switch task.Status {
|
||||
case codersdk.TaskStatusInitializing,
|
||||
codersdk.TaskStatusPending:
|
||||
// Not yet active, keep polling.
|
||||
continue
|
||||
case codersdk.TaskStatusActive:
|
||||
// Task is active; check app state.
|
||||
if task.CurrentState == nil {
|
||||
// The MCP may not have reported state yet.
|
||||
// Start a grace period on first observation
|
||||
// and treat as idle once it expires.
|
||||
if nilStateDeadline.IsZero() {
|
||||
nilStateDeadline = time.Now().Add(nilStateGracePeriod)
|
||||
}
|
||||
if time.Now().After(nilStateDeadline) {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Reset nil-state deadline since we got a real
|
||||
// state report.
|
||||
nilStateDeadline = time.Time{}
|
||||
switch task.CurrentState.State {
|
||||
case codersdk.TaskStateIdle,
|
||||
codersdk.TaskStateComplete,
|
||||
codersdk.TaskStateFailed:
|
||||
return nil
|
||||
default:
|
||||
// Still working, keep polling.
|
||||
continue
|
||||
}
|
||||
case codersdk.TaskStatusError:
|
||||
if time.Now().After(gracePeriodDeadline) {
|
||||
return xerrors.Errorf("task entered %s state while waiting for it to become idle", task.Status)
|
||||
}
|
||||
case codersdk.TaskStatusPaused:
|
||||
return xerrors.Errorf("task was paused while waiting for it to become idle")
|
||||
case codersdk.TaskStatusUnknown:
|
||||
return xerrors.Errorf("task entered %s state while waiting for it to become idle", task.Status)
|
||||
default:
|
||||
return xerrors.Errorf("task entered unexpected state (%s) while waiting for it to become idle", task.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+224
-13
@@ -12,9 +12,14 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentapisdk "github.com/coder/agentapi-sdk-go"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -25,12 +30,12 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "carry on with the task")
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "carry on with the task")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -41,12 +46,12 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", task.ID.String(), "carry on with the task")
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.ID.String(), "carry on with the task")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -57,13 +62,13 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "--stdin")
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "--stdin")
|
||||
inv.Stdout = &stdout
|
||||
inv.Stdin = strings.NewReader("carry on with the task")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -110,17 +115,223 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(assert.AnError))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "some task input")
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, assert.AnError.Error())
|
||||
})
|
||||
|
||||
t.Run("WaitsForInitializingTask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Close the first agent, pause, then resume the task so the
|
||||
// workspace is started but no agent is connected.
|
||||
// This puts the task in "initializing" state.
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
resumeTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the initializing task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
// Use a pty so we can wait for the command to produce build
|
||||
// output, confirming it has entered the initializing code
|
||||
// path before we connect the agent.
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to observe the initializing state and
|
||||
// start watching the workspace build. This ensures the command
|
||||
// has entered the waiting code path.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Connect a new agent so the task can transition to active.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
setup.agent = agenttest.New(t, setup.userClient.URL, setup.agentToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, setup.userClient, setup.task.WorkspaceID.UUID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Report the task app as idle so waitForTaskIdle can proceed.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusActive, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("ResumesPausedTask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Close the first agent before pausing so it does not conflict
|
||||
// with the agent we reconnect after the workspace is resumed.
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the paused task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
// Use a pty so we can wait for the command to produce build
|
||||
// output, confirming it has entered the paused code path and
|
||||
// triggered a resume before we connect the agent.
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to observe the paused state, trigger
|
||||
// a resume, and start watching the workspace build.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Connect a new agent so the task can transition to active.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
setup.agent = agenttest.New(t, setup.userClient.URL, setup.agentToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, setup.userClient, setup.task.WorkspaceID.UUID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Report the task app as idle so waitForTaskIdle can proceed.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusActive, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("PausedDuringWaitForReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: An initializing task (workspace running, no agent
|
||||
// connected).
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
resumeTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the initializing task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to enter the build-watching phase
|
||||
// of waitForTaskReady.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Pause the task while waitForTaskReady is polling. Since
|
||||
// no agent is connected, the task stays initializing until
|
||||
// we pause it, at which point the status becomes paused.
|
||||
pauseTask(ctx, t, setup.userClient, setup.task)
|
||||
|
||||
// Then: The command should fail because the task was paused.
|
||||
err := w.Wait()
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "was paused while waiting for it to become idle")
|
||||
})
|
||||
|
||||
t.Run("WaitsForWorkingAppState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: An active task whose app is in "working" state.
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Move the app into "working" state before running the command.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
require.NoError(t, agentClient.PatchAppStatus(setupCtx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateWorking,
|
||||
Message: "busy",
|
||||
}))
|
||||
|
||||
// When: We send input while the app is working.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Transition the app back to idle so waitForTaskIdle proceeds.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
})
|
||||
|
||||
t.Run("SendToNonIdleAppState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, appState := range []codersdk.WorkspaceAppStatusState{
|
||||
codersdk.WorkspaceAppStatusStateComplete,
|
||||
codersdk.WorkspaceAppStatusStateFailure,
|
||||
} {
|
||||
t.Run(string(appState), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some input", "some response"))
|
||||
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
require.NoError(t, agentClient.PatchAppStatus(setupCtx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: appState,
|
||||
Message: "done",
|
||||
}))
|
||||
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) map[string]http.HandlerFunc {
|
||||
@@ -151,7 +362,7 @@ func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) m
|
||||
}
|
||||
}
|
||||
|
||||
func fakeAgentAPITaskSendErr(t *testing.T, returnErr error) map[string]http.HandlerFunc {
|
||||
func fakeAgentAPITaskSendErr(returnErr error) map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/status": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
+56
-5
@@ -88,6 +88,13 @@ func Test_Tasks(t *testing.T) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, userClient, tasks[0].WorkspaceID.UUID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
|
||||
// Report the task app as idle so that waitForTaskIdle
|
||||
// can proceed during the "send task message" step.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -272,10 +279,19 @@ func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Mes
|
||||
// setupCLITaskTest creates a test workspace with an AI task template and agent,
|
||||
// with a fake agent API configured with the provided set of handlers.
|
||||
// Returns the user client and workspace.
|
||||
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (ownerClient *codersdk.Client, memberClient *codersdk.Client, task codersdk.Task) {
|
||||
// setupCLITaskTestResult holds the return values from setupCLITaskTest.
|
||||
type setupCLITaskTestResult struct {
|
||||
ownerClient *codersdk.Client
|
||||
userClient *codersdk.Client
|
||||
task codersdk.Task
|
||||
agentToken string
|
||||
agent agent.Agent
|
||||
}
|
||||
|
||||
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) setupCLITaskTestResult {
|
||||
t.Helper()
|
||||
|
||||
ownerClient = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
userClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)
|
||||
|
||||
@@ -292,21 +308,56 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the task's underlying workspace to be built
|
||||
// Wait for the task's underlying workspace to be built.
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, workspace.LatestBuild.ID)
|
||||
|
||||
agentClient := agentsdk.New(userClient.URL, agentsdk.WithFixedToken(authToken))
|
||||
_ = agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
|
||||
agt := agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, userClient, workspace.ID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
return ownerClient, userClient, task
|
||||
// Report the task app as idle so that waitForTaskIdle can proceed.
|
||||
err = agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return setupCLITaskTestResult{
|
||||
ownerClient: ownerClient,
|
||||
userClient: userClient,
|
||||
task: task,
|
||||
agentToken: authToken,
|
||||
agent: agt,
|
||||
}
|
||||
}
|
||||
|
||||
// pauseTask pauses the task and waits for the stop build to complete.
|
||||
func pauseTask(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pauseResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
// resumeTask resumes the task waits for the start build to complete. The task
|
||||
// will be in "initializing" state after this returns because no agent is connected.
|
||||
func resumeTask(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
resumeResp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resumeResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, resumeResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
// setupCLITaskTestWithSnapshot creates a task in the specified status with a log snapshot.
|
||||
|
||||
+5
-2
@@ -5,11 +5,14 @@ USAGE:
|
||||
|
||||
Send input to a task
|
||||
|
||||
- Send direct input to a task.:
|
||||
Send input to a task. If the task is paused, it will be automatically resumed
|
||||
before input is sent. If the task is initializing, it will wait for the task
|
||||
to become ready.
|
||||
- Send direct input to a task:
|
||||
|
||||
$ coder task send task1 "Please also add unit tests"
|
||||
|
||||
- Send input from stdin to a task.:
|
||||
- Send input from stdin to a task:
|
||||
|
||||
$ echo "Please also add unit tests" | coder task send task1 --stdin
|
||||
|
||||
|
||||
@@ -387,9 +387,9 @@ func (b *Batcher) flush(ctx context.Context, reason string) {
|
||||
b.Metrics.BatchSize.Observe(float64(count))
|
||||
b.Metrics.MetadataTotal.Add(float64(count))
|
||||
b.Metrics.BatchesTotal.WithLabelValues(reason).Inc()
|
||||
b.Metrics.FlushDuration.WithLabelValues(reason).Observe(time.Since(start).Seconds())
|
||||
elapsed = b.clock.Since(start)
|
||||
b.Metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
|
||||
|
||||
elapsed = time.Since(start)
|
||||
b.log.Debug(ctx, "flush complete",
|
||||
slog.F("count", count),
|
||||
slog.F("elapsed", elapsed),
|
||||
|
||||
+13
-1
@@ -315,6 +315,18 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
|
||||
}
|
||||
}
|
||||
|
||||
// appStatusStateToTaskState converts a WorkspaceAppStatusState to a
|
||||
// TaskState. The two enums mostly share values but "failure" in the
|
||||
// app status maps to "failed" in the public task API.
|
||||
func appStatusStateToTaskState(s codersdk.WorkspaceAppStatusState) codersdk.TaskState {
|
||||
switch s {
|
||||
case codersdk.WorkspaceAppStatusStateFailure:
|
||||
return codersdk.TaskStateFailed
|
||||
default:
|
||||
return codersdk.TaskState(s)
|
||||
}
|
||||
}
|
||||
|
||||
// deriveTaskCurrentState determines the current state of a task based on the
|
||||
// workspace's latest app status and initialization phase.
|
||||
// Returns nil if no valid state can be determined.
|
||||
@@ -334,7 +346,7 @@ func deriveTaskCurrentState(
|
||||
if ws.LatestBuild.Transition != codersdk.WorkspaceTransitionStart || ws.LatestAppStatus.CreatedAt.After(ws.LatestBuild.CreatedAt) {
|
||||
currentState = &codersdk.TaskStateEntry{
|
||||
Timestamp: ws.LatestAppStatus.CreatedAt,
|
||||
State: codersdk.TaskState(ws.LatestAppStatus.State),
|
||||
State: appStatusStateToTaskState(ws.LatestAppStatus.State),
|
||||
Message: ws.LatestAppStatus.Message,
|
||||
URI: ws.LatestAppStatus.URI,
|
||||
}
|
||||
|
||||
Generated
+131
@@ -481,6 +481,128 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/files": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/octet-stream"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Upload a chat file",
|
||||
"operationId": "upload-chat-file",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)",
|
||||
"name": "Content-Type",
|
||||
"in": "header",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "query",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UploadChatFileResponse"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "Unauthorized",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"413": {
|
||||
"description": "Request Entity Too Large",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/files/{file}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get a chat file",
|
||||
"operationId": "get-chat-file",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "File ID",
|
||||
"name": "file",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "Unauthorized",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/archive": {
|
||||
"post": {
|
||||
"tags": [
|
||||
@@ -20334,6 +20456,15 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UploadChatFileResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UploadResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
Generated
+123
@@ -410,6 +410,120 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/files": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": ["application/octet-stream"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Upload a chat file",
|
||||
"operationId": "upload-chat-file",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)",
|
||||
"name": "Content-Type",
|
||||
"in": "header",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "query",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UploadChatFileResponse"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "Unauthorized",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"413": {
|
||||
"description": "Request Entity Too Large",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/files/{file}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Get a chat file",
|
||||
"operationId": "get-chat-file",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "File ID",
|
||||
"name": "file",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "Unauthorized",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/archive": {
|
||||
"post": {
|
||||
"tags": ["Chats"],
|
||||
@@ -18650,6 +18764,15 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UploadChatFileResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UploadResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
+517
-257
File diff suppressed because it is too large
Load Diff
+34
-356
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -17,8 +16,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
@@ -32,8 +29,6 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
proto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -78,10 +73,11 @@ func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
||||
return false
|
||||
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
|
||||
return event.Status.Status == codersdk.ChatStatusWaiting
|
||||
}
|
||||
return event.Status.Status == codersdk.ChatStatusWaiting
|
||||
t.Logf("skipping unexpected event: type=%s", event.Type)
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -371,7 +367,7 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -403,26 +399,31 @@ func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Queued)
|
||||
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
|
||||
require.False(t, result.Chat.WorkerID.Valid)
|
||||
|
||||
// The message should be queued, not inserted directly.
|
||||
require.True(t, result.Queued)
|
||||
require.NotNil(t, result.QueuedMessage)
|
||||
|
||||
// The chat should transition to waiting (interrupt signal),
|
||||
// not pending.
|
||||
require.Equal(t, database.ChatStatusWaiting, result.Chat.Status)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
||||
|
||||
// The message should be in the queue, not in chat_messages.
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 0)
|
||||
require.Len(t, queued, 1)
|
||||
|
||||
// Only the initial user message should be in chat_messages.
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 2)
|
||||
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
@@ -870,15 +871,15 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
||||
// events — the snapshot already contained everything. Before
|
||||
// the fix, localSnapshot was replayed into the channel,
|
||||
// causing duplicates.
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
if ok {
|
||||
t.Fatalf("unexpected event from channel (would be a duplicate): type=%s", event.Type)
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-events:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
// Channel closed without events is fine.
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// No events — correct behavior.
|
||||
}
|
||||
}, 200*time.Millisecond, testutil.IntervalFast,
|
||||
"expected no duplicate events after snapshot")
|
||||
}
|
||||
|
||||
func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
@@ -1533,13 +1534,16 @@ func TestSuccessfulChatSendsWebPushWithNavigationData(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for a web push notification to be dispatched. The dispatch
|
||||
// happens asynchronously after the DB status is updated, so we need
|
||||
// to poll rather than assert immediately.
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
return mockPush.dispatchCount.Load() >= 1
|
||||
// Wait for the chat to complete and return to waiting status.
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid && mockPush.dispatchCount.Load() == 1
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Verify a web push notification was dispatched exactly once.
|
||||
require.Equal(t, int32(1), mockPush.dispatchCount.Load(),
|
||||
"expected exactly one web push dispatch for a completed chat")
|
||||
|
||||
@@ -1558,75 +1562,6 @@ func TestSuccessfulChatSendsWebPushWithNavigationData(t *testing.T) {
|
||||
"web push Data should contain the chat navigation URL")
|
||||
}
|
||||
|
||||
func TestSuccessfulChatSendsWebPushWithTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Set up a mock OpenAI that returns a simple streaming response.
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...)
|
||||
})
|
||||
|
||||
// Mock webpush dispatcher that captures calls.
|
||||
mockPush := &mockWebpushDispatcher{}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
WebpushDispatcher: mockPush,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "push-tag-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the web push notification to be dispatched.
|
||||
// We poll dispatchCount rather than DB status because the
|
||||
// push fires after the status update, creating a small race
|
||||
// window.
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
return mockPush.dispatchCount.Load() >= 1
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
require.Equal(t, int32(1), mockPush.dispatchCount.Load(),
|
||||
"expected exactly one web push dispatch for a completed chat")
|
||||
|
||||
// Verify the push notification tag is set to the chat ID for dedup.
|
||||
mockPush.mu.Lock()
|
||||
capturedMsg := mockPush.lastMessage
|
||||
capturedUser := mockPush.lastUserID
|
||||
mockPush.mu.Unlock()
|
||||
|
||||
require.Equal(t, chat.ID.String(), capturedMsg.Tag,
|
||||
"push notification tag should equal the chat ID for deduplication")
|
||||
require.Equal(t, user.ID, capturedUser,
|
||||
"push notification should be dispatched to the chat owner")
|
||||
require.Equal(t, "push-tag-test", capturedMsg.Title,
|
||||
"push notification title should match the chat title")
|
||||
require.Equal(t, "Agent has finished running.", capturedMsg.Body,
|
||||
"push notification body should indicate the agent finished")
|
||||
}
|
||||
|
||||
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1733,260 +1668,3 @@ func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T)
|
||||
!fromDB.LastError.Valid
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestHeaderInjection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// seedWorkspaceAgent creates the DB entities needed so that
|
||||
// GetWorkspaceAgentsInLatestBuildByWorkspaceID returns an
|
||||
// agent for the given workspace.
|
||||
seedWorkspaceAgent := func(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ps dbpubsub.Pubsub,
|
||||
ownerID uuid.UUID,
|
||||
orgID uuid.UUID,
|
||||
) (workspaceID uuid.UUID, agentID uuid.UUID) {
|
||||
t.Helper()
|
||||
|
||||
// TemplateVersion needs its own provisioner job.
|
||||
versionJob := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
||||
OrganizationID: orgID,
|
||||
InitiatorID: ownerID,
|
||||
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
||||
})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: orgID,
|
||||
CreatedBy: ownerID,
|
||||
JobID: versionJob.ID,
|
||||
})
|
||||
templ := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: orgID,
|
||||
CreatedBy: ownerID,
|
||||
ActiveVersionID: tv.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
TemplateID: templ.ID,
|
||||
})
|
||||
buildJob := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
||||
OrganizationID: orgID,
|
||||
InitiatorID: ownerID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
})
|
||||
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: ws.ID,
|
||||
JobID: buildJob.ID,
|
||||
BuildNumber: 1,
|
||||
InitiatorID: ownerID,
|
||||
TemplateVersionID: tv.ID,
|
||||
})
|
||||
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
JobID: build.JobID,
|
||||
})
|
||||
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: resource.ID,
|
||||
})
|
||||
return ws.ID, agent.ID
|
||||
}
|
||||
|
||||
t.Run("WithParentChat", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
workspaceID, expectedAgentID := seedWorkspaceAgent(t, db, ps, user.ID, org.ID)
|
||||
|
||||
// Set up the mock OpenAI to return a simple text response
|
||||
// so the chat finishes cleanly.
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("done")...,
|
||||
)
|
||||
})
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
// Wire up the mock agent connection so we can capture
|
||||
// the headers passed to SetExtraHeaders.
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var capturedHeaders http.Header
|
||||
headersCaptured := make(chan struct{})
|
||||
|
||||
// SetExtraHeaders is called once when the connection
|
||||
// is first established.
|
||||
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).Do(func(h http.Header) {
|
||||
capturedHeaders = h
|
||||
close(headersCaptured)
|
||||
})
|
||||
// resolveInstructions calls LS to look for instruction
|
||||
// files; return an error so it skips gracefully.
|
||||
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{}, xerrors.New("not found"),
|
||||
).AnyTimes()
|
||||
// The connection is closed when the chat finishes.
|
||||
mockConn.EXPECT().Close().Return(nil).AnyTimes()
|
||||
|
||||
agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
require.Equal(t, expectedAgentID, agentID)
|
||||
return mockConn, func() {}, nil
|
||||
}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
AgentConn: agentConnFn,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
// Create a real parent chat so the FK constraint is
|
||||
// satisfied.
|
||||
parentChat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-chat",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{
|
||||
fantasy.TextContent{Text: "parent"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
||||
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
Title: "header-injection-parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{
|
||||
fantasy.TextContent{Text: "hello"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the chat to be processed and headers to be
|
||||
// captured.
|
||||
select {
|
||||
case <-headersCaptured:
|
||||
case <-ctx.Done():
|
||||
require.FailNow(t, "timed out waiting for SetExtraHeaders")
|
||||
}
|
||||
|
||||
require.Equal(t,
|
||||
chat.ID.String(),
|
||||
capturedHeaders.Get(workspacesdk.CoderChatIDHeader),
|
||||
)
|
||||
|
||||
ancestorJSON := capturedHeaders.Get(workspacesdk.CoderAncestorChatIDsHeader)
|
||||
var ancestorIDs []string
|
||||
err = json.Unmarshal([]byte(ancestorJSON), &ancestorIDs)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{parentChat.ID.String()}, ancestorIDs)
|
||||
})
|
||||
|
||||
t.Run("WithoutParentChat", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
workspaceID, expectedAgentID := seedWorkspaceAgent(t, db, ps, user.ID, org.ID)
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("done")...,
|
||||
)
|
||||
})
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var capturedHeaders http.Header
|
||||
headersCaptured := make(chan struct{})
|
||||
|
||||
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).Do(func(h http.Header) {
|
||||
capturedHeaders = h
|
||||
close(headersCaptured)
|
||||
})
|
||||
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{}, xerrors.New("not found"),
|
||||
).AnyTimes()
|
||||
mockConn.EXPECT().Close().Return(nil).AnyTimes()
|
||||
|
||||
agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
require.Equal(t, expectedAgentID, agentID)
|
||||
return mockConn, func() {}, nil
|
||||
}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
AgentConn: agentConnFn,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
// Create a chat without a parent — the ancestor header
|
||||
// should contain an empty JSON array.
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
||||
Title: "header-injection-no-parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{
|
||||
fantasy.TextContent{Text: "hello"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-headersCaptured:
|
||||
case <-ctx.Done():
|
||||
require.FailNow(t, "timed out waiting for SetExtraHeaders")
|
||||
}
|
||||
|
||||
require.Equal(t,
|
||||
chat.ID.String(),
|
||||
capturedHeaders.Get(workspacesdk.CoderChatIDHeader),
|
||||
)
|
||||
|
||||
// When there is no parent, the code declares
|
||||
// var ancestorIDs []string and never appends to it,
|
||||
// so json.Marshal produces "null".
|
||||
ancestorJSON := capturedHeaders.Get(workspacesdk.CoderAncestorChatIDsHeader)
|
||||
var ancestorIDs []string
|
||||
err = json.Unmarshal([]byte(ancestorJSON), &ancestorIDs)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, ancestorIDs)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -73,7 +73,9 @@ type RunOptions struct {
|
||||
// OnRetry is called before each retry attempt when the LLM
|
||||
// stream fails with a retryable error. It provides the attempt
|
||||
// number, error, and backoff delay so callers can publish status
|
||||
// events to connected clients.
|
||||
// events to connected clients. Callers should also clear any
|
||||
// buffered stream state from the failed attempt in this callback
|
||||
// to avoid sending duplicated content.
|
||||
OnRetry chatretry.OnRetryFn
|
||||
|
||||
OnInterruptedPersistError func(error)
|
||||
@@ -209,6 +211,10 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
var lastUsage fantasy.Usage
|
||||
var lastProviderMetadata fantasy.ProviderMetadata
|
||||
|
||||
totalSteps := 0
|
||||
// When totalSteps reaches MaxSteps the inner loop exits immediately
|
||||
// (its condition is false), stoppedByModel stays false, and the
|
||||
// post-loop guard breaks the outer compaction loop.
|
||||
for compactionAttempt := 0; ; compactionAttempt++ {
|
||||
alreadyCompacted := false
|
||||
// stoppedByModel is true when the inner step loop
|
||||
@@ -222,7 +228,8 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
// agent never had a chance to use the compacted context.
|
||||
compactedOnFinalStep := false
|
||||
|
||||
for step := 0; step < opts.MaxSteps; step++ {
|
||||
for step := 0; totalSteps < opts.MaxSteps; step++ {
|
||||
totalSteps++
|
||||
// Copy messages so that provider-specific caching
|
||||
// mutations don't leak back to the caller's slice.
|
||||
// copy copies Message structs by value, so field
|
||||
@@ -321,6 +328,12 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
lastUsage = result.usage
|
||||
lastProviderMetadata = result.providerMetadata
|
||||
|
||||
// Append the step's response messages so that both
|
||||
// inline and post-loop compaction see the full
|
||||
// conversation including the latest assistant reply.
|
||||
stepMessages := result.toResponseMessages()
|
||||
messages = append(messages, stepMessages...)
|
||||
|
||||
// Inline compaction.
|
||||
if opts.Compaction != nil && opts.ReloadMessages != nil {
|
||||
did, compactErr := tryCompact(
|
||||
@@ -354,17 +367,11 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
// The agent is continuing with tool calls, so any
|
||||
// prior compaction has already been consumed.
|
||||
compactedOnFinalStep = false
|
||||
|
||||
// Build messages from the step for the next iteration.
|
||||
// toResponseMessages produces assistant-role content
|
||||
// (text, reasoning, tool calls) and tool-result content.
|
||||
stepMessages := result.toResponseMessages()
|
||||
messages = append(messages, stepMessages...)
|
||||
}
|
||||
|
||||
// Post-run compaction safety net: if we never compacted
|
||||
// during the loop, try once at the end.
|
||||
if !alreadyCompacted && opts.Compaction != nil {
|
||||
if !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil {
|
||||
did, err := tryCompact(
|
||||
ctx,
|
||||
opts.Model,
|
||||
@@ -383,7 +390,6 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
compactedOnFinalStep = true
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enter the step loop when compaction fired on the
|
||||
// model's final step. This lets the agent continue
|
||||
// working with fresh summarized context instead of
|
||||
@@ -514,7 +520,6 @@ func processStepStream(
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
case fantasy.StreamPartTypeToolInputStart:
|
||||
activeToolCalls[part.ID] = &fantasy.ToolCallContent{
|
||||
ToolCallID: part.ID,
|
||||
|
||||
@@ -123,7 +123,8 @@ func tryCompact(
|
||||
config.SystemSummaryPrefix + "\n\n" + summary,
|
||||
)
|
||||
|
||||
err = config.Persist(ctx, CompactionResult{
|
||||
persistCtx := context.WithoutCancel(ctx)
|
||||
err = config.Persist(persistCtx, CompactionResult{
|
||||
SystemSummary: systemSummary,
|
||||
SummaryReport: summary,
|
||||
ThresholdPercent: config.ThresholdPercent,
|
||||
|
||||
@@ -76,9 +76,20 @@ func TestRun_Compaction(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, persistCompactionCalls)
|
||||
// Compaction fires twice: once inline when the threshold is
|
||||
// reached on step 0 (the only step, since MaxSteps=1), and
|
||||
// once from the post-run safety net during the re-entry
|
||||
// iteration (where totalSteps already equals MaxSteps so the
|
||||
// inner loop doesn't execute, but lastUsage still exceeds
|
||||
// the threshold).
|
||||
require.Equal(t, 2, persistCompactionCalls)
|
||||
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
|
||||
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
|
||||
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
|
||||
@@ -151,13 +162,25 @@ func TestRun_Compaction(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Compaction fires twice (see PersistsWhenThresholdReached
|
||||
// for the full explanation). Each cycle follows the order:
|
||||
// publish_tool_call → generate → persist → publish_tool_result.
|
||||
require.Equal(t, []string{
|
||||
"publish_tool_call",
|
||||
"generate",
|
||||
"persist",
|
||||
"publish_tool_result",
|
||||
"publish_tool_call",
|
||||
"generate",
|
||||
"persist",
|
||||
"publish_tool_result",
|
||||
}, callOrder)
|
||||
})
|
||||
|
||||
@@ -457,6 +480,11 @@ func TestRun_Compaction(t *testing.T) {
|
||||
compactionErr = err
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Error(t, compactionErr)
|
||||
@@ -572,4 +600,117 @@ func TestRun_Compaction(t *testing.T) {
|
||||
// Two stream calls: one before compaction, one after re-entry.
|
||||
require.Equal(t, 2, streamCallCount)
|
||||
})
|
||||
|
||||
t.Run("PostRunCompactionReEntryIncludesUserSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// After compaction the summary is stored as a user-role
|
||||
// message. When the loop re-enters, the reloaded prompt
|
||||
// must contain this user message so the LLM provider
|
||||
// receives a valid prompt (providers like Anthropic
|
||||
// require at least one non-system message).
|
||||
|
||||
var mu sync.Mutex
|
||||
var streamCallCount int
|
||||
var reEntryPrompt []fantasy.Message
|
||||
persistCompactionCalls := 0
|
||||
|
||||
const summaryText = "post-run compacted summary"
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCallCount
|
||||
streamCallCount++
|
||||
mu.Unlock()
|
||||
|
||||
switch step {
|
||||
case 0:
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "initial response"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 80,
|
||||
TotalTokens: 85,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
default:
|
||||
mu.Lock()
|
||||
reEntryPrompt = append([]fantasy.Message(nil), call.Prompt...)
|
||||
mu.Unlock()
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-2"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-2", Delta: "continued"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-2"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 20,
|
||||
TotalTokens: 25,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
},
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate real post-compaction DB state: the summary is
|
||||
// a user-role message (the only non-system content).
|
||||
compactedMessages := []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleSystem, "system prompt"),
|
||||
textMessage(fantasy.MessageRoleUser, "Summary of earlier chat context:\n\ncompacted summary"),
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 5,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
ContextLimitFallback: 100,
|
||||
Compaction: &CompactionOptions{
|
||||
ThresholdPercent: 70,
|
||||
SummaryPrompt: "summarize now",
|
||||
Persist: func(_ context.Context, _ CompactionResult) error {
|
||||
persistCompactionCalls++
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return compactedMessages, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, persistCompactionCalls, 1)
|
||||
// Re-entry happened: stream was called at least twice.
|
||||
require.Equal(t, 2, streamCallCount)
|
||||
// The re-entry prompt must contain the user summary.
|
||||
require.NotEmpty(t, reEntryPrompt)
|
||||
hasUser := false
|
||||
for _, msg := range reEntryPrompt {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
hasUser = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package chatprompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -16,12 +18,156 @@ import (
|
||||
|
||||
var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
|
||||
|
||||
// FileData holds resolved file content for LLM prompt building.
|
||||
type FileData struct {
|
||||
Data []byte
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// FileResolver fetches file content by ID for LLM prompt building.
|
||||
type FileResolver func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]FileData, error)
|
||||
|
||||
// ExtractFileID parses the file_id from a serialized file content
|
||||
// block envelope. Returns uuid.Nil and an error when the block is
|
||||
// not a file-type block or has no file_id.
|
||||
func ExtractFileID(raw json.RawMessage) (uuid.UUID, error) {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
FileID string `json:"file_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &envelope); err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("unmarshal content block: %w", err)
|
||||
}
|
||||
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeFile)) {
|
||||
return uuid.Nil, xerrors.Errorf("not a file content block: %s", envelope.Type)
|
||||
}
|
||||
if envelope.Data.FileID == "" {
|
||||
return uuid.Nil, xerrors.New("no file_id")
|
||||
}
|
||||
return uuid.Parse(envelope.Data.FileID)
|
||||
}
|
||||
|
||||
// extractFileIDs scans raw message content for file_id references.
|
||||
// Returns a map of block index to file ID. Returns nil for
|
||||
// non-array content or content with no file references.
|
||||
func extractFileIDs(raw pqtype.NullRawMessage) map[int]uuid.UUID {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil
|
||||
}
|
||||
var rawBlocks []json.RawMessage
|
||||
if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
|
||||
return nil
|
||||
}
|
||||
var result map[int]uuid.UUID
|
||||
for i, block := range rawBlocks {
|
||||
fid, err := ExtractFileID(block)
|
||||
if err == nil {
|
||||
if result == nil {
|
||||
result = make(map[int]uuid.UUID)
|
||||
}
|
||||
result[i] = fid
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// patchFileContent fills in empty Data on FileContent blocks from
|
||||
// resolved file data. Blocks that already have inline data (backward
|
||||
// compat) or have no resolved data are left unchanged.
|
||||
func patchFileContent(
|
||||
content []fantasy.Content,
|
||||
fileIDs map[int]uuid.UUID,
|
||||
resolved map[uuid.UUID]FileData,
|
||||
) {
|
||||
for blockIdx, fid := range fileIDs {
|
||||
if blockIdx >= len(content) {
|
||||
continue
|
||||
}
|
||||
switch fc := content[blockIdx].(type) {
|
||||
case fantasy.FileContent:
|
||||
if len(fc.Data) > 0 {
|
||||
continue
|
||||
}
|
||||
if data, found := resolved[fid]; found {
|
||||
fc.Data = data.Data
|
||||
content[blockIdx] = fc
|
||||
}
|
||||
case *fantasy.FileContent:
|
||||
if len(fc.Data) > 0 {
|
||||
continue
|
||||
}
|
||||
if data, found := resolved[fid]; found {
|
||||
fc.Data = data.Data
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertMessages converts persisted chat messages into LLM prompt
|
||||
// messages without resolving file references from storage. Inline
|
||||
// file data is preserved when present (backward compat).
|
||||
func ConvertMessages(
|
||||
messages []database.ChatMessage,
|
||||
) ([]fantasy.Message, error) {
|
||||
return ConvertMessagesWithFiles(context.Background(), messages, nil)
|
||||
}
|
||||
|
||||
// ConvertMessagesWithFiles converts persisted chat messages into LLM
|
||||
// prompt messages, resolving file references via the provided
|
||||
// resolver. When resolver is nil, file blocks without inline data
|
||||
// are passed through as-is (same behavior as ConvertMessages).
|
||||
func ConvertMessagesWithFiles(
|
||||
ctx context.Context,
|
||||
messages []database.ChatMessage,
|
||||
resolver FileResolver,
|
||||
) ([]fantasy.Message, error) {
|
||||
// Phase 1: Pre-scan user messages for file_id references.
|
||||
var allFileIDs []uuid.UUID
|
||||
seenFileIDs := make(map[uuid.UUID]struct{})
|
||||
fileIDsByMsg := make(map[int]map[int]uuid.UUID)
|
||||
|
||||
if resolver != nil {
|
||||
for i, msg := range messages {
|
||||
visibility := msg.Visibility
|
||||
if visibility == "" {
|
||||
visibility = database.ChatMessageVisibilityBoth
|
||||
}
|
||||
if visibility != database.ChatMessageVisibilityModel &&
|
||||
visibility != database.ChatMessageVisibilityBoth {
|
||||
continue
|
||||
}
|
||||
if msg.Role != string(fantasy.MessageRoleUser) {
|
||||
continue
|
||||
}
|
||||
fids := extractFileIDs(msg.Content)
|
||||
if len(fids) > 0 {
|
||||
fileIDsByMsg[i] = fids
|
||||
for _, fid := range fids {
|
||||
if _, seen := seenFileIDs[fid]; !seen {
|
||||
seenFileIDs[fid] = struct{}{}
|
||||
allFileIDs = append(allFileIDs, fid)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Batch resolve file data.
|
||||
var resolved map[uuid.UUID]FileData
|
||||
if len(allFileIDs) > 0 {
|
||||
var err error
|
||||
resolved, err = resolver(ctx, allFileIDs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("resolve chat files: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Convert messages, patching file content as needed.
|
||||
prompt := make([]fantasy.Message, 0, len(messages))
|
||||
toolNameByCallID := make(map[string]string)
|
||||
for _, message := range messages {
|
||||
for i, message := range messages {
|
||||
visibility := message.Visibility
|
||||
if visibility == "" {
|
||||
visibility = database.ChatMessageVisibilityBoth
|
||||
@@ -51,6 +197,9 @@ func ConvertMessages(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fids, ok := fileIDsByMsg[i]; ok {
|
||||
patchFileContent(content, fids, resolved)
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: ToMessageParts(content),
|
||||
@@ -400,7 +549,10 @@ func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
|
||||
}
|
||||
|
||||
// MarshalContent encodes message content blocks for persistence.
|
||||
func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
|
||||
// fileIDs optionally maps block indices to chat_files IDs, which
|
||||
// are injected into the JSON envelope for file-type blocks so
|
||||
// the reference survives round-trips through storage.
|
||||
func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype.NullRawMessage, error) {
|
||||
if len(blocks) == 0 {
|
||||
return pqtype.NullRawMessage{}, nil
|
||||
}
|
||||
@@ -415,6 +567,16 @@ func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
|
||||
err,
|
||||
)
|
||||
}
|
||||
if fid, ok := fileIDs[i]; ok {
|
||||
encoded, err = injectFileID(encoded, fid)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf(
|
||||
"inject file_id into content block %d: %w",
|
||||
i,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
encodedBlocks = append(encodedBlocks, encoded)
|
||||
}
|
||||
|
||||
@@ -425,6 +587,27 @@ func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
|
||||
}
|
||||
|
||||
// injectFileID adds a file_id field into the data sub-object of a
|
||||
// serialized content block envelope. This follows the same pattern
|
||||
// as the reasoning title injection in marshalContentBlock.
|
||||
func injectFileID(encoded json.RawMessage, fileID uuid.UUID) (json.RawMessage, error) {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
MediaType string `json:"media_type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
FileID string `json:"file_id,omitempty"`
|
||||
ProviderMetadata *json.RawMessage `json:"provider_metadata,omitempty"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(encoded, &envelope); err != nil {
|
||||
return encoded, err
|
||||
}
|
||||
envelope.Data.FileID = fileID.String()
|
||||
envelope.Data.Data = nil // Strip inline data; resolved at LLM dispatch time.
|
||||
return json.Marshal(envelope)
|
||||
}
|
||||
|
||||
// MarshalToolResult encodes a single tool result for persistence as
|
||||
// an opaque JSON blob. The stored shape is
|
||||
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package chatprompt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
@@ -52,7 +55,7 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
|
||||
ToolName: "execute",
|
||||
Input: tc.input,
|
||||
},
|
||||
})
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolContent, err := chatprompt.MarshalToolResult(
|
||||
@@ -89,3 +92,139 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessagesWithFiles_ResolvesFileData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fileID := uuid.New()
|
||||
fileData := []byte("fake-image-bytes")
|
||||
|
||||
// Build a user message with file_id but no inline data, as
|
||||
// would be stored after injectFileID strips the data.
|
||||
rawContent := mustJSON(t, []json.RawMessage{
|
||||
mustJSON(t, map[string]any{
|
||||
"type": "file",
|
||||
"data": map[string]any{
|
||||
"media_type": "image/png",
|
||||
"file_id": fileID.String(),
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
resolver := func(_ context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) {
|
||||
result := make(map[uuid.UUID]chatprompt.FileData)
|
||||
for _, id := range ids {
|
||||
if id == fileID {
|
||||
result[id] = chatprompt.FileData{
|
||||
Data: fileData,
|
||||
MediaType: "image/png",
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
context.Background(),
|
||||
[]database.ChatMessage{
|
||||
{
|
||||
Role: string(fantasy.MessageRoleUser),
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true},
|
||||
},
|
||||
},
|
||||
resolver,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompt, 1)
|
||||
require.Equal(t, fantasy.MessageRoleUser, prompt[0].Role)
|
||||
require.Len(t, prompt[0].Content, 1)
|
||||
|
||||
filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0])
|
||||
require.True(t, ok, "expected FilePart")
|
||||
require.Equal(t, fileData, filePart.Data)
|
||||
require.Equal(t, "image/png", filePart.MediaType)
|
||||
}
|
||||
|
||||
func TestConvertMessagesWithFiles_BackwardCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// A message with inline data and a file_id should use the
|
||||
// inline data even when the resolver returns nothing.
|
||||
fileID := uuid.New()
|
||||
inlineData := []byte("inline-image-data")
|
||||
|
||||
rawContent := mustJSON(t, []json.RawMessage{
|
||||
mustJSON(t, map[string]any{
|
||||
"type": "file",
|
||||
"data": map[string]any{
|
||||
"media_type": "image/png",
|
||||
"data": inlineData,
|
||||
"file_id": fileID.String(),
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
context.Background(),
|
||||
[]database.ChatMessage{
|
||||
{
|
||||
Role: string(fantasy.MessageRoleUser),
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true},
|
||||
},
|
||||
},
|
||||
nil, // No resolver.
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompt, 1)
|
||||
require.Len(t, prompt[0].Content, 1)
|
||||
|
||||
filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0])
|
||||
require.True(t, ok, "expected FilePart")
|
||||
require.Equal(t, inlineData, filePart.Data)
|
||||
}
|
||||
|
||||
func TestInjectFileID_StripsInlineData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fileID := uuid.New()
|
||||
imageData := []byte("raw-image-bytes")
|
||||
|
||||
// Marshal a file content block with inline data, then inject
|
||||
// a file_id. The result should have file_id but no data.
|
||||
content, err := chatprompt.MarshalContent([]fantasy.Content{
|
||||
fantasy.FileContent{
|
||||
MediaType: "image/png",
|
||||
Data: imageData,
|
||||
},
|
||||
}, map[int]uuid.UUID{0: fileID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the stored content to verify shape.
|
||||
var blocks []json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(content.RawMessage, &blocks))
|
||||
require.Len(t, blocks, 1)
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
MediaType string `json:"media_type"`
|
||||
Data *json.RawMessage `json:"data,omitempty"`
|
||||
FileID string `json:"file_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(blocks[0], &envelope))
|
||||
require.Equal(t, "file", envelope.Type)
|
||||
require.Equal(t, "image/png", envelope.Data.MediaType)
|
||||
require.Equal(t, fileID.String(), envelope.Data.FileID)
|
||||
// Data should be nil (omitted) since injectFileID strips it.
|
||||
require.Nil(t, envelope.Data.Data, "inline data should be stripped")
|
||||
}
|
||||
|
||||
func mustJSON(t *testing.T, v any) json.RawMessage {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,6 +20,12 @@ const (
|
||||
// MaxDelay is the upper bound for the exponential backoff
|
||||
// duration. Matches the cap used in coder/mux.
|
||||
MaxDelay = 60 * time.Second
|
||||
|
||||
// MaxAttempts is the upper bound on retry attempts before
|
||||
// giving up. With a 60s max backoff this allows roughly
|
||||
// 25 minutes of retries, which is reasonable for transient
|
||||
// LLM provider issues.
|
||||
MaxAttempts = 25
|
||||
)
|
||||
|
||||
// nonRetryablePatterns are substrings that indicate a permanent error
|
||||
@@ -131,9 +139,8 @@ type RetryFn func(ctx context.Context) error
|
||||
type OnRetryFn func(attempt int, err error, delay time.Duration)
|
||||
|
||||
// Retry calls fn repeatedly until it succeeds, returns a
|
||||
// non-retryable error, or ctx is canceled. There is no max attempt
|
||||
// limit — retries continue indefinitely with exponential backoff
|
||||
// (capped at 60s), matching the behavior of coder/mux.
|
||||
// non-retryable error, ctx is canceled, or MaxAttempts is reached.
|
||||
// Retries use exponential backoff capped at MaxDelay.
|
||||
//
|
||||
// The onRetry callback (if non-nil) is called before each retry
|
||||
// attempt, giving the caller a chance to reset state, log, or
|
||||
@@ -156,10 +163,15 @@ func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
delay := Delay(attempt)
|
||||
attempt++
|
||||
if attempt >= MaxAttempts {
|
||||
return xerrors.Errorf("max retry attempts (%d) exceeded: %w", MaxAttempts, err)
|
||||
}
|
||||
|
||||
delay := Delay(attempt - 1)
|
||||
|
||||
if onRetry != nil {
|
||||
onRetry(attempt+1, err, delay)
|
||||
onRetry(attempt, err, delay)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(delay)
|
||||
@@ -169,7 +181,5 @@ func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
|
||||
+344
-38
@@ -1,11 +1,15 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -247,7 +251,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, titleSource, inputError := createChatInputFromRequest(req)
|
||||
contentBlocks, contentFileIDs, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req)
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError)
|
||||
return
|
||||
@@ -282,6 +286,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ModelConfigID: modelConfigID,
|
||||
SystemPrompt: defaultChatSystemPrompt(),
|
||||
InitialUserContent: contentBlocks,
|
||||
ContentFileIDs: contentFileIDs,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsForeignKeyViolation(
|
||||
@@ -588,7 +593,15 @@ func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify
|
||||
// active subscribers. Fall back to direct DB for the
|
||||
// simple archive flag — no streaming state is involved.
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to archive chat.",
|
||||
@@ -616,7 +629,15 @@ func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify
|
||||
// active subscribers. Fall back to direct DB for the
|
||||
// simple unarchive flag — no streaming state is involved.
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat.ID)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to unarchive chat.",
|
||||
@@ -647,7 +668,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content")
|
||||
contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: inputError.Message,
|
||||
@@ -659,10 +680,11 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
sendResult, sendErr := api.chatDaemon.SendMessage(
|
||||
ctx,
|
||||
chatd.SendMessageOptions{
|
||||
ChatID: chatID,
|
||||
Content: contentBlocks,
|
||||
ModelConfigID: req.ModelConfigID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
ChatID: chatID,
|
||||
Content: contentBlocks,
|
||||
ContentFileIDs: contentFileIDs,
|
||||
ModelConfigID: req.ModelConfigID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
},
|
||||
)
|
||||
if sendErr != nil {
|
||||
@@ -721,7 +743,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content")
|
||||
contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: inputError.Message,
|
||||
@@ -734,6 +756,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: messageID,
|
||||
Content: contentBlocks,
|
||||
ContentFileIDs: contentFileIDs,
|
||||
})
|
||||
if editErr != nil {
|
||||
switch {
|
||||
@@ -848,18 +871,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
var afterMessageID int64
|
||||
if v := r.URL.Query().Get("after_id"); v != "" {
|
||||
var err error
|
||||
@@ -873,14 +884,31 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
Message: "Failed to open chat stream.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
_ = sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
})
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
defer cancel()
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
@@ -973,9 +1001,13 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
if updateErr != nil {
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
} else {
|
||||
chat = updatedChat
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
chat = updatedChat
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertChat(chat, nil))
|
||||
@@ -2196,45 +2228,317 @@ func normalizeChatCompressionThreshold(
|
||||
return threshold, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// maxChatFileSize is the maximum size of a chat file upload (10 MB).
|
||||
maxChatFileSize = 10 << 20
|
||||
// maxChatFileName is the maximum length of an uploaded file name.
|
||||
maxChatFileName = 255
|
||||
)
|
||||
|
||||
// allowedChatFileMIMETypes lists the content types accepted for chat
|
||||
// file uploads. SVG is explicitly excluded because it can contain scripts.
|
||||
var allowedChatFileMIMETypes = map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"image/svg+xml": false, // SVG can contain scripts.
|
||||
}
|
||||
|
||||
var (
|
||||
webpMagicRIFF = []byte("RIFF")
|
||||
webpMagicWEBP = []byte("WEBP")
|
||||
)
|
||||
|
||||
// detectChatFileType detects the MIME type of the given data.
|
||||
// It extends http.DetectContentType with support for WebP, which
|
||||
// Go's standard sniffer does not recognize.
|
||||
func detectChatFileType(data []byte) string {
|
||||
if len(data) >= 12 &&
|
||||
bytes.Equal(data[0:4], webpMagicRIFF) &&
|
||||
bytes.Equal(data[8:12], webpMagicWEBP) {
|
||||
return "image/webp"
|
||||
}
|
||||
return http.DetectContentType(data)
|
||||
}
|
||||
|
||||
func defaultChatSystemPrompt() string {
|
||||
return chatd.DefaultSystemPrompt
|
||||
}
|
||||
|
||||
func createChatInputFromRequest(req codersdk.CreateChatRequest) (
|
||||
// @Summary Upload a chat file
|
||||
// @ID upload-chat-file
|
||||
// @Security CoderSessionToken
|
||||
// @Accept application/octet-stream
|
||||
// @Produce json
|
||||
// @Tags Chats
|
||||
// @Param Content-Type header string true "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)"
|
||||
// @Param organization query string true "Organization ID" format(uuid)
|
||||
// @Success 201 {object} codersdk.UploadChatFileResponse
|
||||
// @Failure 400 {object} codersdk.Response
|
||||
// @Failure 401 {object} codersdk.Response
|
||||
// @Failure 413 {object} codersdk.Response
|
||||
// @Failure 500 {object} codersdk.Response
|
||||
// @Router /chats/files [post]
|
||||
func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String())) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
orgIDStr := r.URL.Query().Get("organization")
|
||||
if orgIDStr == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing organization query parameter.",
|
||||
})
|
||||
return
|
||||
}
|
||||
orgID, err := uuid.Parse(orgIDStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid organization ID.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
// Strip parameters (e.g. "image/png; charset=utf-8" → "image/png")
|
||||
// so the allowlist check matches the base media type.
|
||||
if mediaType, _, err := mime.ParseMediaType(contentType); err == nil {
|
||||
contentType = mediaType
|
||||
}
|
||||
|
||||
if allowed, ok := allowedChatFileMIMETypes[contentType]; !ok || !allowed {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unsupported file type.",
|
||||
Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, maxChatFileSize)
|
||||
br := bufio.NewReader(r.Body)
|
||||
|
||||
// Peek at the leading bytes to sniff the real content type
|
||||
// before reading the entire body.
|
||||
peek, peekErr := br.Peek(512)
|
||||
if peekErr != nil && !errors.Is(peekErr, io.EOF) && !errors.Is(peekErr, bufio.ErrBufferFull) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read file from request.",
|
||||
Detail: peekErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the actual content matches a safe image type so that
|
||||
// a client cannot spoof Content-Type to serve active content.
|
||||
detected := detectChatFileType(peek)
|
||||
if allowed, ok := allowedChatFileMIMETypes[detected]; !ok || !allowed {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unsupported file type.",
|
||||
Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Read the full body now that we know the type is valid.
|
||||
data, err := io.ReadAll(br)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "File too large.",
|
||||
Detail: fmt.Sprintf("Maximum file size is %d bytes.", maxChatFileSize),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read file from request.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract filename from Content-Disposition header if provided.
|
||||
var filename string
|
||||
if cd := r.Header.Get("Content-Disposition"); cd != "" {
|
||||
if _, params, err := mime.ParseMediaType(cd); err == nil {
|
||||
filename = params["filename"]
|
||||
if len(filename) > maxChatFileName {
|
||||
// Truncate at rune boundary to avoid splitting
|
||||
// multi-byte UTF-8 characters.
|
||||
var truncated []byte
|
||||
for _, r := range filename {
|
||||
encoded := []byte(string(r))
|
||||
if len(truncated)+len(encoded) > maxChatFileName {
|
||||
break
|
||||
}
|
||||
truncated = append(truncated, encoded...)
|
||||
}
|
||||
filename = string(truncated)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chatFile, err := api.Database.InsertChatFile(ctx, database.InsertChatFileParams{
|
||||
OwnerID: apiKey.UserID,
|
||||
OrganizationID: orgID,
|
||||
Name: filename,
|
||||
Mimetype: detected,
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to save chat file.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.UploadChatFileResponse{
|
||||
ID: chatFile.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get a chat file
|
||||
// @ID get-chat-file
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Param file path string true "File ID" format(uuid)
|
||||
// @Success 200
|
||||
// @Failure 400 {object} codersdk.Response
|
||||
// @Failure 401 {object} codersdk.Response
|
||||
// @Failure 404 {object} codersdk.Response
|
||||
// @Failure 500 {object} codersdk.Response
|
||||
// @Router /chats/files/{file} [get]
|
||||
func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
fileIDStr := chi.URLParam(r, "file")
|
||||
fileID, err := uuid.Parse(fileIDStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file ID.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatFile, err := api.Database.GetChatFileByID(ctx, fileID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat file.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", chatFile.Mimetype)
|
||||
if chatFile.Name != "" {
|
||||
rw.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": chatFile.Name}))
|
||||
} else {
|
||||
rw.Header().Set("Content-Disposition", "inline")
|
||||
}
|
||||
rw.Header().Set("Cache-Control", "private, max-age=31536000, immutable")
|
||||
rw.Header().Set("Content-Length", strconv.Itoa(len(chatFile.Data)))
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write(chatFile.Data)
|
||||
}
|
||||
|
||||
func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) (
|
||||
[]fantasy.Content,
|
||||
map[int]uuid.UUID,
|
||||
string,
|
||||
*codersdk.Response,
|
||||
) {
|
||||
return createChatInputFromParts(req.Content, "content")
|
||||
return createChatInputFromParts(ctx, db, req.Content, "content")
|
||||
}
|
||||
|
||||
func createChatInputFromParts(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
parts []codersdk.ChatInputPart,
|
||||
fieldName string,
|
||||
) ([]fantasy.Content, string, *codersdk.Response) {
|
||||
) ([]fantasy.Content, map[int]uuid.UUID, string, *codersdk.Response) {
|
||||
if len(parts) == 0 {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Content is required.",
|
||||
Detail: "Content cannot be empty.",
|
||||
}
|
||||
}
|
||||
|
||||
content := make([]fantasy.Content, 0, len(parts))
|
||||
fileIDs := make(map[int]uuid.UUID)
|
||||
textParts := make([]string, 0, len(parts))
|
||||
for i, part := range parts {
|
||||
switch strings.ToLower(strings.TrimSpace(string(part.Type))) {
|
||||
case string(codersdk.ChatInputPartTypeText):
|
||||
text := strings.TrimSpace(part.Text)
|
||||
if text == "" {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i),
|
||||
}
|
||||
}
|
||||
content = append(content, fantasy.TextContent{Text: text})
|
||||
textParts = append(textParts, text)
|
||||
case string(codersdk.ChatInputPartTypeFile):
|
||||
if part.FileID == uuid.Nil {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i),
|
||||
}
|
||||
}
|
||||
// Validate that the file exists and get its media type.
|
||||
// File data is not loaded here; it's resolved at LLM
|
||||
// dispatch time via chatFileResolver.
|
||||
chatFile, err := db.GetChatFileByID(ctx, part.FileID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i),
|
||||
}
|
||||
}
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Internal error.",
|
||||
Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i),
|
||||
}
|
||||
}
|
||||
content = append(content, fantasy.FileContent{
|
||||
MediaType: chatFile.Mimetype,
|
||||
})
|
||||
fileIDs[len(content)-1] = part.FileID
|
||||
case string(codersdk.ChatInputPartTypeFileReference):
|
||||
if part.FileName == "" {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_name cannot be empty for file-reference.", fieldName, i),
|
||||
}
|
||||
}
|
||||
lineRange := fmt.Sprintf("%d", part.StartLine)
|
||||
if part.StartLine != part.EndLine {
|
||||
lineRange = fmt.Sprintf("%d-%d", part.StartLine, part.EndLine)
|
||||
}
|
||||
var sb strings.Builder
|
||||
_, _ = fmt.Fprintf(&sb, "[file-reference] %s:%s", part.FileName, lineRange)
|
||||
if strings.TrimSpace(part.Content) != "" {
|
||||
_, _ = fmt.Fprintf(&sb, "\n```%s\n%s\n```", part.FileName, strings.TrimSpace(part.Content))
|
||||
}
|
||||
text := sb.String()
|
||||
content = append(content, fantasy.TextContent{Text: text})
|
||||
textParts = append(textParts, text)
|
||||
default:
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf(
|
||||
"%s[%d].type %q is not supported.",
|
||||
@@ -2246,14 +2550,16 @@ func createChatInputFromParts(
|
||||
}
|
||||
}
|
||||
|
||||
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
||||
if titleSource == "" {
|
||||
return nil, "", &codersdk.Response{
|
||||
// Allow file-only messages. The titleSource may be empty
|
||||
// when only file parts are provided, callers handle this.
|
||||
if len(content) == 0 {
|
||||
return nil, nil, "", &codersdk.Response{
|
||||
Message: "Content is required.",
|
||||
Detail: "Content must include at least one text part.",
|
||||
Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName),
|
||||
}
|
||||
}
|
||||
return content, titleSource, nil
|
||||
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
||||
return content, fileIDs, titleSource, nil
|
||||
}
|
||||
|
||||
func chatTitleFromMessage(message string) string {
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1525,6 +1527,541 @@ func TestPostChatMessages(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatMessageWithFileReferences(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// createChat is a helper that creates a chat so we can post messages to it.
|
||||
createChatForTest := func(t *testing.T, client *codersdk.Client) codersdk.Chat {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "initial message",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
t.Run("FileReferenceOnly", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
chat := createChatForTest(t, client)
|
||||
|
||||
created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeFileReference,
|
||||
FileName: "main.go",
|
||||
StartLine: 10,
|
||||
EndLine: 15,
|
||||
Content: "func broken() {}",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The file-reference is stored as a formatted text block.
|
||||
wantText := "[file-reference] main.go:10-15\n" +
|
||||
"```main.go\nfunc broken() {}\n```"
|
||||
|
||||
var found bool
|
||||
require.Eventually(t, func() bool {
|
||||
chatWithMessages, getErr := client.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, message := range chatWithMessages.Messages {
|
||||
if message.Role != "user" {
|
||||
continue
|
||||
}
|
||||
for _, part := range message.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText &&
|
||||
part.Text == wantText {
|
||||
found = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
// The message may have been queued.
|
||||
if created.Queued && created.QueuedMessage != nil {
|
||||
for _, queued := range chatWithMessages.QueuedMessages {
|
||||
for _, part := range queued.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText &&
|
||||
part.Text == wantText {
|
||||
found = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
require.True(t, found, "expected to find file-reference text in stored message")
|
||||
})
|
||||
|
||||
t.Run("FileReferenceSingleLine", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
chat := createChatForTest(t, client)
|
||||
|
||||
created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeFileReference,
|
||||
FileName: "lib/utils.ts",
|
||||
StartLine: 42,
|
||||
EndLine: 42,
|
||||
Content: "const x = 1;",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Single-line range should use "42" not "42-42".
|
||||
wantText := "[file-reference] lib/utils.ts:42\n" +
|
||||
"```lib/utils.ts\nconst x = 1;\n```"
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
chatWithMessages, getErr := client.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range chatWithMessages.Messages {
|
||||
for _, part := range msg.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == wantText {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if created.Queued && created.QueuedMessage != nil {
|
||||
for _, queued := range chatWithMessages.QueuedMessages {
|
||||
for _, part := range queued.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == wantText {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("FileReferenceWithoutContent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
chat := createChatForTest(t, client)
|
||||
|
||||
created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeFileReference,
|
||||
FileName: "README.md",
|
||||
StartLine: 1,
|
||||
EndLine: 1,
|
||||
// No code content — just a file reference.
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// No fenced code block when content is empty.
|
||||
wantText := "[file-reference] README.md:1"
|
||||
require.Eventually(t, func() bool {
|
||||
chatWithMessages, getErr := client.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range chatWithMessages.Messages {
|
||||
for _, part := range msg.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == wantText {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if created.Queued && created.QueuedMessage != nil {
|
||||
for _, queued := range chatWithMessages.QueuedMessages {
|
||||
for _, part := range queued.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == wantText {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("FileReferenceWithCode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
chat := createChatForTest(t, client)
|
||||
|
||||
created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeFileReference,
|
||||
FileName: "server.go",
|
||||
StartLine: 5,
|
||||
EndLine: 8,
|
||||
Content: "func main() {\n\tfmt.Println()\n}",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
wantText := "[file-reference] server.go:5-8\n" +
|
||||
"```server.go\nfunc main() {\n\tfmt.Println()\n}\n```"
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
chatWithMessages, getErr := client.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range chatWithMessages.Messages {
|
||||
for _, part := range msg.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == wantText {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if created.Queued && created.QueuedMessage != nil {
|
||||
for _, queued := range chatWithMessages.QueuedMessages {
|
||||
for _, part := range queued.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == wantText {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("InterleavedTextAndFileReferences", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
chat := createChatForTest(t, client)
|
||||
|
||||
created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "Please review these two issues:",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeFileReference,
|
||||
FileName: "a.go",
|
||||
StartLine: 1,
|
||||
EndLine: 3,
|
||||
Content: "line1\nline2\nline3",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "first issue",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "and also:",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeFileReference,
|
||||
FileName: "b.go",
|
||||
StartLine: 10,
|
||||
EndLine: 10,
|
||||
Content: "return nil",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "second issue",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that all six parts are stored in order.
|
||||
wantTexts := []string{
|
||||
"Please review these two issues:",
|
||||
"[file-reference] a.go:1-3\n```a.go\nline1\nline2\nline3\n```",
|
||||
"first issue",
|
||||
"and also:",
|
||||
"[file-reference] b.go:10\n```b.go\nreturn nil\n```",
|
||||
"second issue",
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
chatWithMessages, getErr := client.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check messages and queued messages for the
|
||||
// interleaved parts in order.
|
||||
checkParts := func(parts []codersdk.ChatMessagePart) bool {
|
||||
textParts := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText {
|
||||
textParts = append(textParts, part.Text)
|
||||
}
|
||||
}
|
||||
if len(textParts) != len(wantTexts) {
|
||||
return false
|
||||
}
|
||||
for i, want := range wantTexts {
|
||||
if textParts[i] != want {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
for _, msg := range chatWithMessages.Messages {
|
||||
if msg.Role == "user" && checkParts(msg.Content) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if created.Queued && created.QueuedMessage != nil {
|
||||
for _, queued := range chatWithMessages.QueuedMessages {
|
||||
if checkParts(queued.Content) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("EmptyFileName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
chat := createChatForTest(t, client)
|
||||
|
||||
_, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeFileReference,
|
||||
FileName: "",
|
||||
StartLine: 1,
|
||||
EndLine: 1,
|
||||
}},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid input part.", sdkErr.Message)
|
||||
require.Equal(t, "content[0].file_name cannot be empty for file-reference.", sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("CreateChatWithFileReference", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// File references should also work in the initial CreateChat call.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeFileReference,
|
||||
FileName: "bug.py",
|
||||
StartLine: 7,
|
||||
EndLine: 7,
|
||||
Content: "x = None",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, chat.ID)
|
||||
|
||||
// Title is derived from the text parts. For file-references
|
||||
// the formatted text becomes the title source.
|
||||
require.NotEmpty(t, chat.Title)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatMessageWithFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("FileOnly", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Upload a file.
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a chat with text first.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "initial message",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send a file-only message (no text).
|
||||
resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeFile,
|
||||
FileID: uploadResp.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the message was accepted.
|
||||
if resp.Queued {
|
||||
require.NotNil(t, resp.QueuedMessage)
|
||||
} else {
|
||||
require.NotNil(t, resp.Message)
|
||||
require.Equal(t, "user", resp.Message.Role)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TextAndFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Upload a file.
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a chat with text first.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "initial message",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send a message with both text and file.
|
||||
resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "here is an image",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeFile,
|
||||
FileID: uploadResp.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if resp.Queued {
|
||||
require.NotNil(t, resp.QueuedMessage)
|
||||
} else {
|
||||
require.NotNil(t, resp.Message)
|
||||
require.Equal(t, "user", resp.Message.Role)
|
||||
}
|
||||
|
||||
// Verify file parts omit inline data in the API response.
|
||||
chatWithMessages, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
for _, msg := range chatWithMessages.Messages {
|
||||
for _, part := range msg.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeFile {
|
||||
require.True(t, part.FileID.Valid, "file part should have a valid file_id")
|
||||
require.Equal(t, uploadResp.ID, part.FileID.UUID)
|
||||
require.Nil(t, part.Data, "file data should not be sent when file_id is present")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FileOnlyOnCreate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Upload a file.
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new chat with only a file part.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeFile,
|
||||
FileID: uploadResp.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// With no text, chatTitleFromMessage("") returns "New Chat".
|
||||
require.Equal(t, "New Chat", chat.Title)
|
||||
})
|
||||
|
||||
t.Run("InvalidFileID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create a chat with text first.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "initial message",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send a message with a non-existent file ID.
|
||||
_, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeFile,
|
||||
FileID: uuid.New(),
|
||||
},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid input part.", sdkErr.Message)
|
||||
require.Contains(t, sdkErr.Detail, "does not exist")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPatchChatMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1602,6 +2139,100 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
require.False(t, foundOriginalInChat)
|
||||
})
|
||||
|
||||
t.Run("PreservesFileID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Upload a file.
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a chat with a text + file part.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "before edit with file",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeFile,
|
||||
FileID: uploadResp.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find the user message ID.
|
||||
chatWithMessages, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var userMessageID int64
|
||||
for _, message := range chatWithMessages.Messages {
|
||||
if message.Role == "user" {
|
||||
userMessageID = message.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotZero(t, userMessageID)
|
||||
|
||||
// Edit the message: new text, same file_id.
|
||||
edited, err := client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "after edit with file",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeFile,
|
||||
FileID: uploadResp.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userMessageID, edited.ID)
|
||||
|
||||
// Assert the edit response preserves the file_id.
|
||||
var foundText, foundFile bool
|
||||
for _, part := range edited.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" {
|
||||
foundText = true
|
||||
}
|
||||
if part.Type == codersdk.ChatMessagePartTypeFile && part.FileID.Valid && part.FileID.UUID == uploadResp.ID {
|
||||
foundFile = true
|
||||
require.Nil(t, part.Data, "file data should not be sent when file_id is present")
|
||||
}
|
||||
}
|
||||
require.True(t, foundText, "edited message should contain updated text")
|
||||
require.True(t, foundFile, "edited message should preserve file_id")
|
||||
|
||||
// GET the chat and verify the file_id persists.
|
||||
updatedChat, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundTextInChat, foundFileInChat bool
|
||||
for _, message := range updatedChat.Messages {
|
||||
if message.Role != "user" {
|
||||
continue
|
||||
}
|
||||
for _, part := range message.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" {
|
||||
foundTextInChat = true
|
||||
}
|
||||
if part.Type == codersdk.ChatMessagePartTypeFile && part.FileID.Valid && part.FileID.UUID == uploadResp.ID {
|
||||
foundFileInChat = true
|
||||
require.Nil(t, part.Data, "file data should not be sent when file_id is present")
|
||||
}
|
||||
}
|
||||
}
|
||||
require.True(t, foundTextInChat, "chat should contain edited text")
|
||||
require.True(t, foundFileInChat, "chat should preserve file_id after edit")
|
||||
})
|
||||
|
||||
t.Run("MessageNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2212,6 +2843,259 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostChatFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success/PNG", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Valid PNG header + padding.
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, resp.ID)
|
||||
})
|
||||
|
||||
t.Run("Success/JPEG", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
data := append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, make([]byte, 64)...)
|
||||
resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/jpeg", "test.jpg", bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, resp.ID)
|
||||
})
|
||||
|
||||
t.Run("Success/WebP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// WebP: RIFF + 4-byte size + WEBP + padding.
|
||||
data := append([]byte("RIFF"), make([]byte, 4)...)
|
||||
data = append(data, []byte("WEBP")...)
|
||||
data = append(data, make([]byte, 64)...)
|
||||
resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/webp", "test.webp", bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, resp.ID)
|
||||
})
|
||||
|
||||
t.Run("UnsupportedContentType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "text/plain", "test.txt", bytes.NewReader([]byte("hello")))
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("SVGBlocked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/svg+xml", "test.svg", bytes.NewReader([]byte("<svg></svg>")))
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("ContentSniffingRejects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Header says PNG but body is plain text.
|
||||
_, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader([]byte("hello world")))
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("TooLarge", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// 10 MB + 1 byte, with valid PNG header to pass MIME check.
|
||||
data := make([]byte, 10<<20+1)
|
||||
copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
|
||||
_, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data))
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("MissingOrganization", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
res, err := client.Request(ctx, http.MethodPost, "/api/experimental/chats/files", bytes.NewReader(data), func(r *http.Request) {
|
||||
r.Header.Set("Content-Type", "image/png")
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Contains(t, sdkErr.Message, "Missing organization")
|
||||
})
|
||||
|
||||
t.Run("InvalidOrganization", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
res, err := client.Request(ctx, http.MethodPost, "/api/experimental/chats/files?organization=not-a-uuid", bytes.NewReader(data), func(r *http.Request) {
|
||||
r.Header.Set("Content-Type", "image/png")
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Contains(t, sdkErr.Message, "Invalid organization ID")
|
||||
})
|
||||
|
||||
t.Run("WrongOrganization", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
_, err := client.UploadChatFile(ctx, uuid.New(), "image/png", "test.png", bytes.NewReader(data))
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
// dbauthz returns 404 or 500 depending on how the org lookup
|
||||
// fails; 403 is also possible. Any non-success code is valid.
|
||||
require.GreaterOrEqual(t, sdkErr.StatusCode(), http.StatusBadRequest,
|
||||
"expected error status, got %d", sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("Unauthenticated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
unauthed := codersdk.New(client.URL)
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
_, err := unauthed.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data))
|
||||
requireSDKError(t, err, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChatFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
|
||||
got, contentType, err := client.GetChatFile(ctx, uploaded.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "image/png", contentType)
|
||||
require.Equal(t, data, got)
|
||||
})
|
||||
|
||||
t.Run("CacheHeaders", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.Request(ctx, http.MethodGet,
|
||||
fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
require.Equal(t, "private, max-age=31536000, immutable", res.Header.Get("Cache-Control"))
|
||||
require.Contains(t, res.Header.Get("Content-Disposition"), "inline")
|
||||
require.Contains(t, res.Header.Get("Content-Disposition"), "test.png")
|
||||
})
|
||||
|
||||
t.Run("LongFilename", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
longName := strings.Repeat("a", 300) + ".png"
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", longName, bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.Request(ctx, http.MethodGet,
|
||||
fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
// Filename should be truncated to maxChatFileName (255) bytes.
|
||||
cd := res.Header.Get("Content-Disposition")
|
||||
require.Contains(t, cd, "inline")
|
||||
require.Contains(t, cd, strings.Repeat("a", 255))
|
||||
require.NotContains(t, cd, strings.Repeat("a", 256))
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, _, err := client.GetChatFile(ctx, uuid.New())
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("InvalidUUID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
res, err := client.Request(ctx, http.MethodGet,
|
||||
"/api/experimental/chats/files/not-a-uuid", nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("OtherUserForbidden", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
|
||||
otherClient, _ := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
_, _, err = otherClient.GetChatFile(ctx, uploaded.ID)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
|
||||
+24
-4
@@ -662,6 +662,7 @@ func New(options *Options) *API {
|
||||
api.SiteHandler, err = site.New(&site.Options{
|
||||
CacheDir: siteCacheDir,
|
||||
Database: options.Database,
|
||||
Authorizer: options.Authorizer,
|
||||
SiteFS: site.FS(),
|
||||
OAuth2Configs: oauthConfigs,
|
||||
DocsURL: options.DeploymentValues.DocsURL.String(),
|
||||
@@ -926,6 +927,16 @@ func New(options *Options) *API {
|
||||
loggermw.Logger(api.Logger),
|
||||
singleSlashMW,
|
||||
rolestore.CustomRoleMW,
|
||||
// Validate API key on every request (if present) and store
|
||||
// the result in context. The rate limiter reads this to key
|
||||
// by user ID, and downstream ExtractAPIKeyMW reuses it to
|
||||
// avoid redundant DB lookups. Never rejects requests.
|
||||
httpmw.PrecheckAPIKey(httpmw.ValidateAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.Sessions.DisableExpiryRefresh.Value(),
|
||||
Logger: options.Logger,
|
||||
}),
|
||||
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
|
||||
prometheusMW,
|
||||
// Build-Version is helpful for debugging.
|
||||
@@ -1074,8 +1085,6 @@ func New(options *Options) *API {
|
||||
|
||||
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
|
||||
r.Use(
|
||||
// Specific routes can specify different limits, but every rate
|
||||
// limit must be configurable by the admin.
|
||||
apiRateLimiter,
|
||||
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
|
||||
)
|
||||
@@ -1113,6 +1122,11 @@ func New(options *Options) *API {
|
||||
r.Post("/", api.postChats)
|
||||
r.Get("/models", api.listChatModels)
|
||||
r.Get("/watch", api.watchChats)
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
|
||||
r.Post("/", api.postChatFile)
|
||||
r.Get("/{file}", api.chatFileByID)
|
||||
})
|
||||
r.Route("/providers", func(r chi.Router) {
|
||||
r.Get("/", api.listChatProviders)
|
||||
r.Post("/", api.createChatProvider)
|
||||
@@ -1163,8 +1177,6 @@ func New(options *Options) *API {
|
||||
|
||||
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
|
||||
r.Use(
|
||||
// Specific routes can specify different limits, but every rate
|
||||
// limit must be configurable by the admin.
|
||||
apiRateLimiter,
|
||||
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
|
||||
)
|
||||
@@ -1842,6 +1854,14 @@ func New(options *Options) *API {
|
||||
"parsing additional CSP headers", slog.Error(cspParseErrors))
|
||||
}
|
||||
|
||||
// Add blob: to img-src for chat file attachment previews when
|
||||
// the agents experiment is enabled.
|
||||
if api.Experiments.Enabled(codersdk.ExperimentAgents) {
|
||||
additionalCSPHeaders[httpmw.CSPDirectiveImgSrc] = append(
|
||||
additionalCSPHeaders[httpmw.CSPDirectiveImgSrc], "blob:",
|
||||
)
|
||||
}
|
||||
|
||||
// Add CSP headers to all static assets and pages. CSP headers only affect
|
||||
// browsers, so these don't make sense on api routes.
|
||||
cspMW := httpmw.CSPHeaders(
|
||||
|
||||
@@ -416,3 +416,91 @@ func TestDERPMetrics(t *testing.T) {
|
||||
assert.Contains(t, names, "coder_derp_server_packets_dropped_reason_total",
|
||||
"expected coder_derp_server_packets_dropped_reason_total to be registered")
|
||||
}
|
||||
|
||||
// TestRateLimitByUser verifies that rate limiting keys by user ID when
|
||||
// an authenticated session is present, rather than falling back to IP.
|
||||
// This is a regression test for https://github.com/coder/coder/issues/20857
|
||||
func TestRateLimitByUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rateLimit = 5
|
||||
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{
|
||||
APIRateLimit: rateLimit,
|
||||
})
|
||||
firstUser := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
|
||||
t.Run("HitsLimit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Make rateLimit requests — they should all succeed.
|
||||
for i := 0; i < rateLimit; i++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode,
|
||||
"request %d should succeed", i+1)
|
||||
}
|
||||
|
||||
// The next request should be rate-limited.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode,
|
||||
"request should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("BypassOwner", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Owner with bypass header should not be rate-limited.
|
||||
for i := 0; i < rateLimit+5; i++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode,
|
||||
"owner bypass request %d should succeed", i+1)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MemberCannotBypass", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// A member requesting the bypass header should be rejected
|
||||
// with 428 Precondition Required — only owners may bypass.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
memberClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, memberClient.SessionToken())
|
||||
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||
|
||||
resp, err := memberClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode,
|
||||
"member should not be able to bypass rate limit")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1156,9 +1156,7 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe
|
||||
}
|
||||
|
||||
var rawBlocks []json.RawMessage
|
||||
if role == string(fantasy.MessageRoleAssistant) {
|
||||
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
|
||||
}
|
||||
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
|
||||
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(content))
|
||||
for i, block := range content {
|
||||
@@ -1166,10 +1164,20 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe
|
||||
if part.Type == "" {
|
||||
continue
|
||||
}
|
||||
if part.Type == codersdk.ChatMessagePartTypeReasoning {
|
||||
part.Title = ""
|
||||
if i < len(rawBlocks) {
|
||||
if i < len(rawBlocks) {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeReasoning:
|
||||
part.Title = reasoningStoredTitle(rawBlocks[i])
|
||||
case codersdk.ChatMessagePartTypeFile:
|
||||
if fid, err := chatprompt.ExtractFileID(rawBlocks[i]); err == nil {
|
||||
part.FileID = uuid.NullUUID{UUID: fid, Valid: true}
|
||||
}
|
||||
// When a file_id is present, omit inline data
|
||||
// from the response. Clients fetch content via
|
||||
// the GET /chats/files/{id} endpoint instead.
|
||||
if part.FileID.Valid {
|
||||
part.Data = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
@@ -2457,6 +2457,30 @@ func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uu
|
||||
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
|
||||
file, err := q.db.GetChatFileByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatFile{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, file); err != nil {
|
||||
return database.ChatFile{}, err
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
files, err := q.db.GetChatFilesByIDs(ctx, ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, f := range files {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
// ChatMessages are authorized through their parent Chat.
|
||||
// We need to fetch the message first to get its chat_id.
|
||||
@@ -4491,6 +4515,11 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams)
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
// Authorize create on chat resource scoped to the owner and org.
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
// Authorize create on the parent chat (using update permission).
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
|
||||
@@ -463,6 +463,16 @@ func (s *MethodTestSuite) TestChats() {
|
||||
Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).
|
||||
Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB})
|
||||
}))
|
||||
s.Run("GetChatFileByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
file := testutil.Fake(s.T(), faker, database.ChatFile{})
|
||||
dbm.EXPECT().GetChatFileByID(gomock.Any(), file.ID).Return(file, nil).AnyTimes()
|
||||
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(file)
|
||||
}))
|
||||
s.Run("GetChatFilesByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
file := testutil.Fake(s.T(), faker, database.ChatFile{})
|
||||
dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes()
|
||||
check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file})
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -579,6 +589,12 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
|
||||
}))
|
||||
s.Run("InsertChatFile", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatFileParams{})
|
||||
file := testutil.Fake(s.T(), faker, database.InsertChatFileRow{OwnerID: arg.OwnerID, OrganizationID: arg.OrganizationID})
|
||||
dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file)
|
||||
}))
|
||||
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
|
||||
|
||||
@@ -1007,6 +1007,22 @@ func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, cha
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFileByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatFileByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFilesByIDs(ctx, ids)
|
||||
m.queryLatencies.WithLabelValues("GetChatFilesByIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFilesByIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageByID(ctx, id)
|
||||
@@ -2943,6 +2959,14 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatFile(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatFile").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatFile").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatMessage(ctx, arg)
|
||||
|
||||
@@ -1837,6 +1837,36 @@ func (mr *MockStoreMockRecorder) GetChatDiffStatusesByChatIDs(ctx, chatIds any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusesByChatIDs", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusesByChatIDs), ctx, chatIds)
|
||||
}
|
||||
|
||||
// GetChatFileByID mocks base method.
|
||||
func (m *MockStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatFileByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatFile)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatFileByID indicates an expected call of GetChatFileByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatFilesByIDs mocks base method.
|
||||
func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatFilesByIDs", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.ChatFile)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatFilesByIDs indicates an expected call of GetChatFilesByIDs.
|
||||
func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetChatMessageByID mocks base method.
|
||||
func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5511,6 +5541,21 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatFile mocks base method.
|
||||
func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatFile", ctx, arg)
|
||||
ret0, _ := ret[0].(database.InsertChatFileRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatFile indicates an expected call of InsertChatFile.
|
||||
func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatMessage mocks base method.
|
||||
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+23
@@ -1190,6 +1190,16 @@ CREATE TABLE chat_diff_statuses (
|
||||
git_remote_origin text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_files (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
organization_id uuid NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
name text DEFAULT ''::text NOT NULL,
|
||||
mimetype text NOT NULL,
|
||||
data bytea NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
id bigint NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
@@ -3140,6 +3150,9 @@ ALTER TABLE ONLY boundary_usage_stats
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3495,6 +3508,10 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id);
|
||||
|
||||
CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at);
|
||||
@@ -3774,6 +3791,12 @@ ALTER TABLE ONLY api_keys
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ const (
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
|
||||
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
DROP INDEX IF EXISTS idx_chat_files_org;
|
||||
DROP TABLE IF EXISTS chat_files;
|
||||
@@ -0,0 +1,12 @@
|
||||
CREATE TABLE chat_files (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
name TEXT NOT NULL DEFAULT '',
|
||||
mimetype TEXT NOT NULL,
|
||||
data BYTEA NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_files_owner ON chat_files(owner_id);
|
||||
CREATE INDEX idx_chat_files_org ON chat_files(organization_id);
|
||||
@@ -0,0 +1,13 @@
|
||||
INSERT INTO chat_files (id, owner_id, organization_id, created_at, name, mimetype, data)
|
||||
SELECT
|
||||
'00000000-0000-0000-0000-000000000099',
|
||||
u.id,
|
||||
om.organization_id,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'test.png',
|
||||
'image/png',
|
||||
E'\\x89504E47'
|
||||
FROM users u
|
||||
JOIN organization_members om ON om.user_id = u.id
|
||||
ORDER BY u.created_at, u.id
|
||||
LIMIT 1;
|
||||
@@ -178,6 +178,10 @@ func (c Chat) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String())
|
||||
}
|
||||
|
||||
func (c ChatFile) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
|
||||
switch s {
|
||||
case ApiKeyScopeCoderAll:
|
||||
|
||||
@@ -3926,6 +3926,16 @@ type ChatDiffStatus struct {
|
||||
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
|
||||
}
|
||||
|
||||
type ChatFile struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
Data []byte `db:"data" json:"data"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
|
||||
@@ -218,6 +218,8 @@ type sqlcQuerier interface {
|
||||
GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
@@ -601,6 +603,7 @@ type sqlcQuerier interface {
|
||||
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
|
||||
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
|
||||
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
|
||||
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
|
||||
InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error)
|
||||
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
|
||||
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
|
||||
|
||||
@@ -3867,6 +3867,37 @@ func TestGetProvisionerJobsByIDsWithQueuePosition(t *testing.T) {
|
||||
queueSizes: nil, // TODO(yevhenii): should it be empty array instead?
|
||||
queuePositions: nil,
|
||||
},
|
||||
// Many daemons with identical tags should produce same results as one.
|
||||
{
|
||||
name: "duplicate-daemons-same-tags",
|
||||
jobTags: []database.StringMap{
|
||||
{"a": "1"},
|
||||
{"a": "1", "b": "2"},
|
||||
},
|
||||
daemonTags: []database.StringMap{
|
||||
{"a": "1", "b": "2"},
|
||||
{"a": "1", "b": "2"},
|
||||
{"a": "1", "b": "2"},
|
||||
},
|
||||
queueSizes: []int64{2, 2},
|
||||
queuePositions: []int64{1, 2},
|
||||
},
|
||||
// Jobs that don't match any queried job's daemon should still
|
||||
// have correct queue positions.
|
||||
{
|
||||
name: "irrelevant-daemons-filtered",
|
||||
jobTags: []database.StringMap{
|
||||
{"a": "1"},
|
||||
{"x": "9"},
|
||||
},
|
||||
daemonTags: []database.StringMap{
|
||||
{"a": "1"},
|
||||
{"x": "9"},
|
||||
},
|
||||
queueSizes: []int64{1},
|
||||
queuePositions: []int64{1},
|
||||
skipJobIDs: map[int]struct{}{1: {}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -4192,6 +4223,51 @@ func TestGetProvisionerJobsByIDsWithQueuePosition_OrderValidation(t *testing.T)
|
||||
assert.EqualValues(t, []int64{1, 2, 3, 4, 5, 6}, queuePositions, "expected queue positions to be set correctly")
|
||||
}
|
||||
|
||||
func TestGetProvisionerJobsByIDsWithQueuePosition_DuplicateDaemons(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
now := dbtime.Now()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Create 3 pending jobs with the same tags.
|
||||
jobs := make([]database.ProvisionerJob, 3)
|
||||
for i := range jobs {
|
||||
jobs[i] = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
CreatedAt: now.Add(-time.Duration(3-i) * time.Minute),
|
||||
Tags: database.StringMap{"scope": "organization", "owner": ""},
|
||||
})
|
||||
}
|
||||
|
||||
// Create 50 daemons with identical tags (simulates scale).
|
||||
for i := range 50 {
|
||||
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
||||
Name: fmt.Sprintf("daemon_%d", i),
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
Tags: database.StringMap{"scope": "organization", "owner": ""},
|
||||
})
|
||||
}
|
||||
|
||||
jobIDs := make([]uuid.UUID, len(jobs))
|
||||
for i, j := range jobs {
|
||||
jobIDs[i] = j.ID
|
||||
}
|
||||
|
||||
results, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx,
|
||||
database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
||||
IDs: jobIDs,
|
||||
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 3)
|
||||
|
||||
// All daemons have identical tags, so queue should be same as
|
||||
// if there were just one daemon.
|
||||
for i, r := range results {
|
||||
assert.Equal(t, int64(3), r.QueueSize, "job %d queue size", i)
|
||||
assert.Equal(t, int64(i+1), r.QueuePosition, "job %d queue position", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupRemovalTrigger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -8841,3 +8917,202 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// This test exercises a complex CTE query for prompt
|
||||
// reconstruction after compaction. It requires Postgres.
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Helper: create a chat model config (required FK for chats).
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
// A chat_providers row is required as a FK for model configs.
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newChat := func(t *testing.T) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-chat-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
insertMsg := func(
|
||||
t *testing.T,
|
||||
chatID uuid.UUID,
|
||||
role string,
|
||||
vis database.ChatMessageVisibility,
|
||||
compressed bool,
|
||||
content string,
|
||||
) database.ChatMessage {
|
||||
t.Helper()
|
||||
msg, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: role,
|
||||
Visibility: vis,
|
||||
Compressed: sql.NullBool{Bool: compressed, Valid: true},
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`"` + content + `"`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return msg
|
||||
}
|
||||
|
||||
msgIDs := func(msgs []database.ChatMessage) []int64 {
|
||||
ids := make([]int64, len(msgs))
|
||||
for i, m := range msgs {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
t.Run("NoCompaction", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
sys := insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
usr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "hello")
|
||||
ast := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "hi there")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{sys.ID, usr.ID, ast.ID}, msgIDs(got))
|
||||
})
|
||||
|
||||
t.Run("UserOnlyVisibilityExcluded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
// Messages with visibility=user should NOT appear in the
|
||||
// prompt (they are only for the UI).
|
||||
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityUser, false, "user-only msg")
|
||||
usr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "hello")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
for _, m := range got {
|
||||
require.NotEqual(t, database.ChatMessageVisibilityUser, m.Visibility,
|
||||
"visibility=user messages should not appear in the prompt")
|
||||
}
|
||||
require.Contains(t, msgIDs(got), usr.ID)
|
||||
})
|
||||
|
||||
t.Run("AfterCompaction", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
// Pre-compaction conversation.
|
||||
sys := insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
preUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "old question")
|
||||
preAsst := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "old answer")
|
||||
|
||||
// Compaction messages:
|
||||
// 1. Summary (role=user, visibility=model, compressed=true).
|
||||
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "compaction summary")
|
||||
// 2. Compressed assistant tool-call (visibility=user).
|
||||
insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityUser, true, "tool call")
|
||||
// 3. Compressed tool result (visibility=both).
|
||||
insertMsg(t, chat.ID, "tool", database.ChatMessageVisibilityBoth, true, "tool result")
|
||||
|
||||
// Post-compaction messages.
|
||||
postUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "new question")
|
||||
postAsst := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "new answer")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotIDs := msgIDs(got)
|
||||
|
||||
// Must include: system prompt, summary, post-compaction.
|
||||
require.Contains(t, gotIDs, sys.ID, "system prompt must be included")
|
||||
require.Contains(t, gotIDs, summary.ID, "compaction summary must be included")
|
||||
require.Contains(t, gotIDs, postUser.ID, "post-compaction user msg must be included")
|
||||
require.Contains(t, gotIDs, postAsst.ID, "post-compaction assistant msg must be included")
|
||||
|
||||
// Must exclude: pre-compaction non-system messages.
|
||||
require.NotContains(t, gotIDs, preUser.ID, "pre-compaction user msg must be excluded")
|
||||
require.NotContains(t, gotIDs, preAsst.ID, "pre-compaction assistant msg must be excluded")
|
||||
|
||||
// Verify ordering.
|
||||
require.Equal(t, []int64{sys.ID, summary.ID, postUser.ID, postAsst.ID}, gotIDs)
|
||||
})
|
||||
|
||||
t.Run("AfterCompactionSummaryIsUserRole", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
// After compaction the summary must appear as role=user so
|
||||
// that LLM APIs (e.g. Anthropic) see at least one
|
||||
// non-system message in the prompt.
|
||||
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "summary text")
|
||||
newUsr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "new question")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
hasNonSystem := false
|
||||
for _, m := range got {
|
||||
if m.Role != "system" {
|
||||
hasNonSystem = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasNonSystem,
|
||||
"prompt must contain at least one non-system message after compaction")
|
||||
require.Contains(t, msgIDs(got), summary.ID)
|
||||
require.Contains(t, msgIDs(got), newUsr.ID)
|
||||
})
|
||||
|
||||
t.Run("CompressedToolResultNotPickedAsSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := newChat(t)
|
||||
|
||||
// The CTE uses visibility='model' (exact match). If it
|
||||
// used IN ('model','both'), the compressed tool result
|
||||
// (visibility=both) would be picked as the "summary"
|
||||
// instead of the actual summary.
|
||||
insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt")
|
||||
summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "real summary")
|
||||
compressedTool := insertMsg(t, chat.ID, "tool", database.ChatMessageVisibilityBoth, true, "tool result")
|
||||
postUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "follow-up")
|
||||
|
||||
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotIDs := msgIDs(got)
|
||||
require.Contains(t, gotIDs, summary.ID, "real summary must be included")
|
||||
require.NotContains(t, gotIDs, compressedTool.ID,
|
||||
"compressed tool result must not be included")
|
||||
require.Contains(t, gotIDs, postUser.ID)
|
||||
})
|
||||
}
|
||||
|
||||
+117
-10
@@ -2214,6 +2214,103 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou
|
||||
return new_period, err
|
||||
}
|
||||
|
||||
const getChatFileByID = `-- name: GetChatFileByID :one
|
||||
SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatFileByID, id)
|
||||
var i ChatFile
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.CreatedAt,
|
||||
&i.Name,
|
||||
&i.Mimetype,
|
||||
&i.Data,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatFilesByIDs = `-- name: GetChatFilesByIDs :many
|
||||
SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = ANY($1::uuid[])
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatFilesByIDs, pq.Array(ids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ChatFile
|
||||
for rows.Next() {
|
||||
var i ChatFile
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.CreatedAt,
|
||||
&i.Name,
|
||||
&i.Mimetype,
|
||||
&i.Data,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertChatFile = `-- name: InsertChatFile :one
|
||||
INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data)
|
||||
VALUES ($1::uuid, $2::uuid, $3::text, $4::text, $5::bytea)
|
||||
RETURNING id, owner_id, organization_id, created_at, name, mimetype
|
||||
`
|
||||
|
||||
type InsertChatFileParams struct {
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
Data []byte `db:"data" json:"data"`
|
||||
}
|
||||
|
||||
type InsertChatFileRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertChatFile,
|
||||
arg.OwnerID,
|
||||
arg.OrganizationID,
|
||||
arg.Name,
|
||||
arg.Mimetype,
|
||||
arg.Data,
|
||||
)
|
||||
var i InsertChatFileRow
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.CreatedAt,
|
||||
&i.Name,
|
||||
&i.Mimetype,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteChatModelConfigByID = `-- name: DeleteChatModelConfigByID :exec
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
@@ -3224,9 +3321,8 @@ WITH latest_compressed_summary AS (
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = $1::uuid
|
||||
AND role = 'system'
|
||||
AND visibility IN ('model', 'both')
|
||||
AND compressed = TRUE
|
||||
AND visibility = 'model'
|
||||
ORDER BY
|
||||
created_at DESC,
|
||||
id DESC
|
||||
@@ -12558,7 +12654,7 @@ const getProvisionerJobsByIDsWithQueuePosition = `-- name: GetProvisionerJobsByI
|
||||
WITH filtered_provisioner_jobs AS (
|
||||
-- Step 1: Filter provisioner_jobs
|
||||
SELECT
|
||||
id, created_at
|
||||
id, created_at, tags
|
||||
FROM
|
||||
provisioner_jobs
|
||||
WHERE
|
||||
@@ -12573,21 +12669,32 @@ pending_jobs AS (
|
||||
WHERE
|
||||
job_status = 'pending'
|
||||
),
|
||||
online_provisioner_daemons AS (
|
||||
SELECT id, tags FROM provisioner_daemons pd
|
||||
WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - ($2::bigint || ' ms')::interval)
|
||||
unique_daemon_tags AS (
|
||||
SELECT DISTINCT tags FROM provisioner_daemons pd
|
||||
WHERE pd.last_seen_at IS NOT NULL
|
||||
AND pd.last_seen_at >= (NOW() - ($2::bigint || ' ms')::interval)
|
||||
),
|
||||
relevant_daemon_tags AS (
|
||||
SELECT udt.tags
|
||||
FROM unique_daemon_tags udt
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM filtered_provisioner_jobs fpj
|
||||
WHERE provisioner_tagset_contains(udt.tags, fpj.tags)
|
||||
)
|
||||
),
|
||||
ranked_jobs AS (
|
||||
-- Step 3: Rank only pending jobs based on provisioner availability
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.created_at,
|
||||
ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
|
||||
COUNT(*) OVER (PARTITION BY opd.id) AS queue_size
|
||||
ROW_NUMBER() OVER (PARTITION BY rdt.tags ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
|
||||
COUNT(*) OVER (PARTITION BY rdt.tags) AS queue_size
|
||||
FROM
|
||||
pending_jobs pj
|
||||
INNER JOIN online_provisioner_daemons opd
|
||||
ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set
|
||||
INNER JOIN
|
||||
relevant_daemon_tags rdt
|
||||
ON
|
||||
provisioner_tagset_contains(rdt.tags, pj.tags)
|
||||
),
|
||||
final_jobs AS (
|
||||
-- Step 4: Compute best queue position and max queue size per job
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
-- name: InsertChatFile :one
|
||||
INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data)
|
||||
VALUES (@owner_id::uuid, @organization_id::uuid, @name::text, @mimetype::text, @data::bytea)
|
||||
RETURNING id, owner_id, organization_id, created_at, name, mimetype;
|
||||
|
||||
-- name: GetChatFileByID :one
|
||||
SELECT * FROM chat_files WHERE id = @id::uuid;
|
||||
|
||||
-- name: GetChatFilesByIDs :many
|
||||
SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]);
|
||||
@@ -54,9 +54,8 @@ WITH latest_compressed_summary AS (
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND role = 'system'
|
||||
AND visibility IN ('model', 'both')
|
||||
AND compressed = TRUE
|
||||
AND visibility = 'model'
|
||||
ORDER BY
|
||||
created_at DESC,
|
||||
id DESC
|
||||
|
||||
@@ -79,7 +79,7 @@ WHERE
|
||||
WITH filtered_provisioner_jobs AS (
|
||||
-- Step 1: Filter provisioner_jobs
|
||||
SELECT
|
||||
id, created_at
|
||||
id, created_at, tags
|
||||
FROM
|
||||
provisioner_jobs
|
||||
WHERE
|
||||
@@ -94,21 +94,32 @@ pending_jobs AS (
|
||||
WHERE
|
||||
job_status = 'pending'
|
||||
),
|
||||
online_provisioner_daemons AS (
|
||||
SELECT id, tags FROM provisioner_daemons pd
|
||||
WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - (@stale_interval_ms::bigint || ' ms')::interval)
|
||||
unique_daemon_tags AS (
|
||||
SELECT DISTINCT tags FROM provisioner_daemons pd
|
||||
WHERE pd.last_seen_at IS NOT NULL
|
||||
AND pd.last_seen_at >= (NOW() - (@stale_interval_ms::bigint || ' ms')::interval)
|
||||
),
|
||||
relevant_daemon_tags AS (
|
||||
SELECT udt.tags
|
||||
FROM unique_daemon_tags udt
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM filtered_provisioner_jobs fpj
|
||||
WHERE provisioner_tagset_contains(udt.tags, fpj.tags)
|
||||
)
|
||||
),
|
||||
ranked_jobs AS (
|
||||
-- Step 3: Rank only pending jobs based on provisioner availability
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.created_at,
|
||||
ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
|
||||
COUNT(*) OVER (PARTITION BY opd.id) AS queue_size
|
||||
ROW_NUMBER() OVER (PARTITION BY rdt.tags ORDER BY pj.initiator_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid ASC, pj.created_at ASC) AS queue_position,
|
||||
COUNT(*) OVER (PARTITION BY rdt.tags) AS queue_size
|
||||
FROM
|
||||
pending_jobs pj
|
||||
INNER JOIN online_provisioner_daemons opd
|
||||
ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set
|
||||
INNER JOIN
|
||||
relevant_daemon_tags rdt
|
||||
ON
|
||||
provisioner_tagset_contains(rdt.tags, pj.tags)
|
||||
),
|
||||
final_jobs AS (
|
||||
-- Step 4: Compute best queue position and max queue size per job
|
||||
|
||||
@@ -15,6 +15,7 @@ const (
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
|
||||
+391
-194
@@ -30,7 +30,57 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type apiKeyContextKey struct{}
|
||||
type (
|
||||
apiKeyContextKey struct{}
|
||||
apiKeyPrecheckedContextKey struct{}
|
||||
)
|
||||
|
||||
// ValidateAPIKeyConfig holds the settings needed for API key
|
||||
// validation at the top of the request lifecycle. Unlike
|
||||
// ExtractAPIKeyConfig it omits route-specific fields
|
||||
// (RedirectToLogin, Optional, ActivateDormantUser, etc.).
|
||||
type ValidateAPIKeyConfig struct {
|
||||
DB database.Store
|
||||
OAuth2Configs *OAuth2Configs
|
||||
DisableSessionExpiryRefresh bool
|
||||
// SessionTokenFunc overrides how the API token is extracted
|
||||
// from the request. Nil uses the default (cookie/header).
|
||||
SessionTokenFunc func(*http.Request) string
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
// ValidateAPIKeyResult is the outcome of successful validation.
|
||||
type ValidateAPIKeyResult struct {
|
||||
Key database.APIKey
|
||||
Subject rbac.Subject
|
||||
UserStatus database.UserStatus
|
||||
}
|
||||
|
||||
// ValidateAPIKeyError represents a validation failure with enough
|
||||
// context for downstream middlewares to decide how to respond.
|
||||
type ValidateAPIKeyError struct {
|
||||
Code int
|
||||
Response codersdk.Response
|
||||
// Hard is true for server errors and active failures (5xx,
|
||||
// OAuth refresh failures) that must be surfaced even on
|
||||
// optional-auth routes. Soft errors (missing/expired token)
|
||||
// may be swallowed on optional routes.
|
||||
Hard bool
|
||||
}
|
||||
|
||||
func (e *ValidateAPIKeyError) Error() string {
|
||||
return e.Response.Message
|
||||
}
|
||||
|
||||
// APIKeyPrechecked stores the result of top-level API key
|
||||
// validation performed by PrecheckAPIKey. It distinguishes
|
||||
// two states:
|
||||
// - Validation failed (including no token): Result == nil && Err != nil
|
||||
// - Validation passed: Result != nil && Err == nil
|
||||
type APIKeyPrechecked struct {
|
||||
Result *ValidateAPIKeyResult
|
||||
Err *ValidateAPIKeyError
|
||||
}
|
||||
|
||||
// APIKeyOptional may return an API key from the ExtractAPIKey handler.
|
||||
func APIKeyOptional(r *http.Request) (database.APIKey, bool) {
|
||||
@@ -149,6 +199,298 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// PrecheckAPIKey extracts and fully validates the API key on every
|
||||
// request (if present) and stores the result in context. It never
|
||||
// writes error responses and always calls next.
|
||||
//
|
||||
// The rate limiter reads the stored result to key by user ID and
|
||||
// check the Owner bypass header. Downstream ExtractAPIKeyMW reads
|
||||
// it to avoid redundant DB lookups and validation.
|
||||
func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Already prechecked (shouldn't happen, but guard).
|
||||
if _, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
|
||||
next.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
result, valErr := ValidateAPIKey(ctx, cfg, r)
|
||||
|
||||
prechecked := APIKeyPrechecked{
|
||||
Result: result,
|
||||
Err: valErr,
|
||||
}
|
||||
ctx = context.WithValue(ctx, apiKeyPrecheckedContextKey{}, prechecked)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAPIKey extracts and validates the API key from the
|
||||
// request. It performs all security-critical checks:
|
||||
// - Token extraction and parsing
|
||||
// - Database lookup + secret hash validation
|
||||
// - Expiry check
|
||||
// - OIDC/OAuth token refresh (if applicable)
|
||||
// - API key LastUsed / ExpiresAt DB updates
|
||||
// - User role lookup (UserRBACSubject)
|
||||
//
|
||||
// It does NOT:
|
||||
// - Write HTTP error responses
|
||||
// - Activate dormant users (route-specific)
|
||||
// - Redirect to login (route-specific)
|
||||
// - Check OAuth2 audience (route-specific, depends on AccessURL)
|
||||
// - Set PostAuth headers (route-specific)
|
||||
// - Check user active status (route-specific, depends on dormant activation)
|
||||
//
|
||||
// Returns (result, nil) on success or (nil, error) on failure.
|
||||
func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) {
|
||||
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if !ok {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: resp,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the API key ID for all requests that have a valid key
|
||||
// format and secret, regardless of whether subsequent validation
|
||||
// (expiry, user status, etc.) succeeds.
|
||||
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
|
||||
rl.WithFields(slog.F("api_key_id", key.ID))
|
||||
}
|
||||
|
||||
now := dbtime.Now()
|
||||
if key.ExpiresAt.Before(now) {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh OIDC/GitHub tokens if applicable.
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "You must re-authenticate with the login provider.",
|
||||
},
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: "A database error occurred",
|
||||
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
// Check if the OAuth token is expired.
|
||||
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
|
||||
if cfg.OAuth2Configs.IsZero() {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
var friendlyName string
|
||||
var oauthConfig promoauth.OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = cfg.OAuth2Configs.Github
|
||||
friendlyName = "GitHub"
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = cfg.OAuth2Configs.OIDC
|
||||
friendlyName = "OpenID Connect"
|
||||
default:
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
if oauthConfig == nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Soft error: session expired naturally with no
|
||||
// refresh token. Optional-auth routes treat this as
|
||||
// unauthenticated.
|
||||
if link.OAuthRefreshToken == "" {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// We have a refresh token, so let's try it.
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
// Hard error: we actively tried to refresh and the
|
||||
// provider rejected it — surface even on optional-auth
|
||||
// routes.
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
|
||||
friendlyName),
|
||||
Detail: err.Error(),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
//nolint:gocritic // system needs to update user link
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
Claims: link.Claims,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update LastUsed and session expiry.
|
||||
changed := false
|
||||
if now.Sub(key.LastUsed) > time.Hour {
|
||||
key.LastUsed = now
|
||||
remoteIP := net.ParseIP(r.RemoteAddr)
|
||||
if remoteIP == nil {
|
||||
remoteIP = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
bitlen := len(remoteIP) * 8
|
||||
key.IPAddress = pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: remoteIP,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if !cfg.DisableSessionExpiryRefresh {
|
||||
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
|
||||
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
|
||||
key.ExpiresAt = now.Add(apiKeyLifetime)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
//nolint:gocritic // System needs to update API Key LastUsed
|
||||
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocritic // system needs to update user last seen at
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
|
||||
ID: key.UserID,
|
||||
LastSeenAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch user roles.
|
||||
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &ValidateAPIKeyResult{
|
||||
Key: *key,
|
||||
Subject: actor,
|
||||
UserStatus: userStatus,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
|
||||
tokenFunc := APITokenFromRequest
|
||||
if sessionTokenFunc != nil {
|
||||
@@ -240,29 +582,60 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if !ok {
|
||||
return optionalWrite(http.StatusUnauthorized, resp)
|
||||
// --- Consume prechecked result if available ---
|
||||
// Skip prechecked data when cfg has a custom SessionTokenFunc,
|
||||
// because the precheck used the default token extraction and may
|
||||
// have validated a different token (e.g. workspace app token
|
||||
// issuance in workspaceapps/db.go).
|
||||
var key *database.APIKey
|
||||
var actor rbac.Subject
|
||||
var userStatus database.UserStatus
|
||||
var skipValidation bool
|
||||
|
||||
if cfg.SessionTokenFunc == nil {
|
||||
if pc, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
|
||||
if pc.Err != nil {
|
||||
// Validation failed at the top level (includes
|
||||
// "no token provided").
|
||||
if pc.Err.Hard {
|
||||
return write(pc.Err.Code, pc.Err.Response)
|
||||
}
|
||||
return optionalWrite(pc.Err.Code, pc.Err.Response)
|
||||
}
|
||||
// Valid — use prechecked data, skip to route-specific logic.
|
||||
key = &pc.Result.Key
|
||||
actor = pc.Result.Subject
|
||||
userStatus = pc.Result.UserStatus
|
||||
skipValidation = true
|
||||
}
|
||||
}
|
||||
|
||||
// Log the API key ID for all requests that have a valid key format and secret,
|
||||
// regardless of whether subsequent validation (expiry, user status, etc.) succeeds.
|
||||
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
|
||||
rl.WithFields(slog.F("api_key_id", key.ID))
|
||||
if !skipValidation {
|
||||
// Full validation path (no prechecked result or custom token func).
|
||||
result, valErr := ValidateAPIKey(ctx, ValidateAPIKeyConfig{
|
||||
DB: cfg.DB,
|
||||
OAuth2Configs: cfg.OAuth2Configs,
|
||||
DisableSessionExpiryRefresh: cfg.DisableSessionExpiryRefresh,
|
||||
SessionTokenFunc: cfg.SessionTokenFunc,
|
||||
Logger: cfg.Logger,
|
||||
}, r)
|
||||
if valErr != nil {
|
||||
if valErr.Hard {
|
||||
return write(valErr.Code, valErr.Response)
|
||||
}
|
||||
return optionalWrite(valErr.Code, valErr.Response)
|
||||
}
|
||||
key = &result.Key
|
||||
actor = result.Subject
|
||||
userStatus = result.UserStatus
|
||||
}
|
||||
|
||||
now := dbtime.Now()
|
||||
if key.ExpiresAt.Before(now) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
})
|
||||
}
|
||||
// --- Route-specific logic (always runs) ---
|
||||
|
||||
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
|
||||
// Validate OAuth2 provider app token audience (RFC 8707) if applicable.
|
||||
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
|
||||
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
|
||||
// Log the detailed error for debugging but don't expose it to the client
|
||||
// Log the detailed error for debugging but don't expose it to the client.
|
||||
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
|
||||
return optionalWrite(http.StatusForbidden, codersdk.Response{
|
||||
Message: "Token audience validation failed",
|
||||
@@ -270,183 +643,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
}
|
||||
}
|
||||
|
||||
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
|
||||
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
|
||||
// refreshing the OIDC token.
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
var err error
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "You must re-authenticate with the login provider.",
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred",
|
||||
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
|
||||
})
|
||||
}
|
||||
// Check if the OAuth token is expired
|
||||
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
|
||||
if cfg.OAuth2Configs.IsZero() {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
var friendlyName string
|
||||
var oauthConfig promoauth.OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = cfg.OAuth2Configs.Github
|
||||
friendlyName = "GitHub"
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = cfg.OAuth2Configs.OIDC
|
||||
friendlyName = "OpenID Connect"
|
||||
default:
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
// It's possible for cfg.OAuth2Configs to be non-nil, but still
|
||||
// missing this type. For example, if a user logged in with GitHub,
|
||||
// but the administrator later removed GitHub and replaced it with
|
||||
// OIDC.
|
||||
if oauthConfig == nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
if link.OAuthRefreshToken == "" {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
|
||||
})
|
||||
}
|
||||
// We have a refresh token, so let's try it
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
return write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
|
||||
friendlyName),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
//nolint:gocritic // system needs to update user link
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
Claims: link.Claims,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tracks if the API key has properties updated
|
||||
changed := false
|
||||
|
||||
// Only update LastUsed once an hour to prevent database spam.
|
||||
if now.Sub(key.LastUsed) > time.Hour {
|
||||
key.LastUsed = now
|
||||
remoteIP := net.ParseIP(r.RemoteAddr)
|
||||
if remoteIP == nil {
|
||||
remoteIP = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
bitlen := len(remoteIP) * 8
|
||||
key.IPAddress = pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: remoteIP,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
// Only update the ExpiresAt once an hour to prevent database spam.
|
||||
// We extend the ExpiresAt to reduce re-authentication.
|
||||
if !cfg.DisableSessionExpiryRefresh {
|
||||
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
|
||||
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
|
||||
key.ExpiresAt = now.Add(apiKeyLifetime)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
//nolint:gocritic // System needs to update API Key LastUsed
|
||||
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
|
||||
// We only want to update this occasionally to reduce DB write
|
||||
// load. We update alongside the UserLink and APIKey since it's
|
||||
// easier on the DB to colocate writes.
|
||||
//nolint:gocritic // system needs to update user last seen at
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
|
||||
ID: key.UserID,
|
||||
LastSeenAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// If the key is valid, we also fetch the user roles and status.
|
||||
// The roles are used for RBAC authorize checks, and the status
|
||||
// is to block 'suspended' users from accessing the platform.
|
||||
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
|
||||
if err != nil {
|
||||
return write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
|
||||
})
|
||||
}
|
||||
|
||||
// Dormant activation (config-dependent).
|
||||
if userStatus == database.UserStatusDormant && cfg.ActivateDormantUser != nil {
|
||||
id, _ := uuid.Parse(actor.ID)
|
||||
user, err := cfg.ActivateDormantUser(ctx, database.User{
|
||||
|
||||
+36
-15
@@ -32,35 +32,56 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler
|
||||
count,
|
||||
window,
|
||||
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
|
||||
// Prioritize by user, but fallback to IP.
|
||||
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
|
||||
if !ok {
|
||||
// Identify the caller. We check two sources:
|
||||
//
|
||||
// 1. apiKeyPrecheckedContextKey — set by PrecheckAPIKey
|
||||
// at the root of the router. Only fully validated
|
||||
// keys are used.
|
||||
// 2. apiKeyContextKey — set by ExtractAPIKeyMW if it
|
||||
// has already run (e.g. unit tests, workspace-app
|
||||
// routes that don't go through PrecheckAPIKey).
|
||||
//
|
||||
// If neither is present, fall back to IP.
|
||||
var userID string
|
||||
var subject *rbac.Subject
|
||||
|
||||
if pc, ok := r.Context().Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok && pc.Result != nil {
|
||||
userID = pc.Result.Key.UserID.String()
|
||||
subject = &pc.Result.Subject
|
||||
} else if ak, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey); ok {
|
||||
userID = ak.UserID.String()
|
||||
if auth, ok := UserAuthorizationOptional(r.Context()); ok {
|
||||
subject = &auth
|
||||
}
|
||||
} else {
|
||||
return httprate.KeyByIP(r)
|
||||
}
|
||||
|
||||
if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok {
|
||||
// No bypass attempt, just ratelimit.
|
||||
return apiKey.UserID.String(), nil
|
||||
// No bypass attempt, just rate limit by user.
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// Allow Owner to bypass rate limiting for load tests
|
||||
// and automation.
|
||||
auth := UserAuthorization(r.Context())
|
||||
|
||||
// We avoid using rbac.Authorizer since rego is CPU-intensive
|
||||
// and undermines the DoS-prevention goal of the rate limiter.
|
||||
for _, role := range auth.SafeRoleNames() {
|
||||
// and automation. We avoid using rbac.Authorizer since
|
||||
// rego is CPU-intensive and undermines the
|
||||
// DoS-prevention goal of the rate limiter.
|
||||
if subject == nil {
|
||||
// Can't verify roles — rate limit normally.
|
||||
return userID, nil
|
||||
}
|
||||
for _, role := range subject.SafeRoleNames() {
|
||||
if role == rbac.RoleOwner() {
|
||||
// HACK: use a random key each time to
|
||||
// de facto disable rate limiting. The
|
||||
// `httprate` package has no
|
||||
// support for selectively changing the limit
|
||||
// for particular keys.
|
||||
// httprate package has no support for
|
||||
// selectively changing the limit for
|
||||
// particular keys.
|
||||
return cryptorand.String(16)
|
||||
}
|
||||
}
|
||||
|
||||
return apiKey.UserID.String(), xerrors.Errorf(
|
||||
return userID, xerrors.Errorf(
|
||||
"%q provided but user is not %v",
|
||||
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
|
||||
)
|
||||
|
||||
@@ -34,4 +34,9 @@ type ChatStreamNotifyMessage struct {
|
||||
|
||||
// QueueUpdate is set when the queued messages change.
|
||||
QueueUpdate bool `json:"queue_update,omitempty"`
|
||||
|
||||
// FullRefresh signals that subscribers should re-fetch all
|
||||
// messages from the beginning (e.g. after an edit that
|
||||
// truncates message history).
|
||||
FullRefresh bool `json:"full_refresh,omitempty"`
|
||||
}
|
||||
|
||||
@@ -297,6 +297,15 @@ func NewStrictCachingAuthorizer(registry prometheus.Registerer) Authorizer {
|
||||
return Cacher(auth)
|
||||
}
|
||||
|
||||
// NewStrictAuthorizer is for testing only. It skips the caching layer,
|
||||
// which is useful when every authorize call is unique (0% cache hit
|
||||
// rate) and the cache overhead dominates.
|
||||
func NewStrictAuthorizer(registry prometheus.Registerer) Authorizer {
|
||||
auth := NewAuthorizer(registry)
|
||||
auth.strict = true
|
||||
return auth
|
||||
}
|
||||
|
||||
func NewAuthorizer(registry prometheus.Registerer) *RegoAuthorizer {
|
||||
queryOnce.Do(func() {
|
||||
var err error
|
||||
|
||||
+15
-15
@@ -156,7 +156,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
|
||||
crud := []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete}
|
||||
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
auth := rbac.NewStrictAuthorizer(prometheus.NewRegistry())
|
||||
|
||||
// currentUser is anything that references "me", "mine", or "my".
|
||||
currentUser := uuid.New()
|
||||
@@ -173,24 +173,24 @@ func TestRolePermissions(t *testing.T) {
|
||||
apiKeyID := uuid.New()
|
||||
|
||||
// Subjects to user
|
||||
memberMe := authSubject{Name: "member_me", Actor: rbac.Subject{ID: currentUser.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember()}}}
|
||||
memberMe := authSubject{Name: "member_me", Actor: rbac.Subject{ID: currentUser.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember()}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
|
||||
owner := authSubject{Name: "owner", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleOwner()}}}
|
||||
templateAdmin := authSubject{Name: "template-admin", Actor: rbac.Subject{ID: templateAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleTemplateAdmin()}}}
|
||||
userAdmin := authSubject{Name: "user-admin", Actor: rbac.Subject{ID: userAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleUserAdmin()}}}
|
||||
auditor := authSubject{Name: "auditor", Actor: rbac.Subject{ID: auditorID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleAuditor()}}}
|
||||
owner := authSubject{Name: "owner", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleOwner()}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
templateAdmin := authSubject{Name: "template-admin", Actor: rbac.Subject{ID: templateAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleTemplateAdmin()}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
userAdmin := authSubject{Name: "user-admin", Actor: rbac.Subject{ID: userAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleUserAdmin()}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
auditor := authSubject{Name: "auditor", Actor: rbac.Subject{ID: auditorID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleAuditor()}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
|
||||
orgAdmin := authSubject{Name: "org_admin", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAdmin(orgID)}}}
|
||||
orgAuditor := authSubject{Name: "org_auditor", Actor: rbac.Subject{ID: auditorID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAuditor(orgID)}}}
|
||||
orgUserAdmin := authSubject{Name: "org_user_admin", Actor: rbac.Subject{ID: templateAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgUserAdmin(orgID)}}}
|
||||
orgTemplateAdmin := authSubject{Name: "org_template_admin", Actor: rbac.Subject{ID: userAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgTemplateAdmin(orgID)}}}
|
||||
orgAdminBanWorkspace := authSubject{Name: "org_admin_workspace_ban", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAdmin(orgID), rbac.ScopedRoleOrgWorkspaceCreationBan(orgID)}}}
|
||||
orgAdmin := authSubject{Name: "org_admin", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAdmin(orgID)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
orgAuditor := authSubject{Name: "org_auditor", Actor: rbac.Subject{ID: auditorID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAuditor(orgID)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
orgUserAdmin := authSubject{Name: "org_user_admin", Actor: rbac.Subject{ID: templateAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgUserAdmin(orgID)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
orgTemplateAdmin := authSubject{Name: "org_template_admin", Actor: rbac.Subject{ID: userAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgTemplateAdmin(orgID)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
orgAdminBanWorkspace := authSubject{Name: "org_admin_workspace_ban", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAdmin(orgID), rbac.ScopedRoleOrgWorkspaceCreationBan(orgID)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
setOrgNotMe := authSubjectSet{orgAdmin, orgAuditor, orgUserAdmin, orgTemplateAdmin}
|
||||
|
||||
otherOrgAdmin := authSubject{Name: "org_admin_other", Actor: rbac.Subject{ID: uuid.NewString(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAdmin(otherOrg)}}}
|
||||
otherOrgAuditor := authSubject{Name: "org_auditor_other", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAuditor(otherOrg)}}}
|
||||
otherOrgUserAdmin := authSubject{Name: "org_user_admin_other", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgUserAdmin(otherOrg)}}}
|
||||
otherOrgTemplateAdmin := authSubject{Name: "org_template_admin_other", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgTemplateAdmin(otherOrg)}}}
|
||||
otherOrgAdmin := authSubject{Name: "org_admin_other", Actor: rbac.Subject{ID: uuid.NewString(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAdmin(otherOrg)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
otherOrgAuditor := authSubject{Name: "org_auditor_other", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAuditor(otherOrg)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
otherOrgUserAdmin := authSubject{Name: "org_user_admin_other", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgUserAdmin(otherOrg)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
otherOrgTemplateAdmin := authSubject{Name: "org_template_admin_other", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgTemplateAdmin(otherOrg)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
setOtherOrg := authSubjectSet{otherOrgAdmin, otherOrgAuditor, otherOrgUserAdmin, otherOrgTemplateAdmin}
|
||||
|
||||
// requiredSubjects are required to be asserted in each test case. This is
|
||||
|
||||
@@ -2170,7 +2170,7 @@ func (api *API) tailnetRPCConn(rw http.ResponseWriter, r *http.Request) {
|
||||
userID := apiKey.UserID.String()
|
||||
|
||||
// Store connection telemetry event
|
||||
now := time.Now()
|
||||
now := dbtime.Now()
|
||||
connectionTelemetryEvent := telemetry.UserTailnetConnection{
|
||||
ConnectedAt: now,
|
||||
DisconnectedAt: nil,
|
||||
@@ -2187,7 +2187,7 @@ func (api *API) tailnetRPCConn(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
defer func() {
|
||||
// Update telemetry event with disconnection time
|
||||
disconnectTime := time.Now()
|
||||
disconnectTime := dbtime.Now()
|
||||
connectionTelemetryEvent.DisconnectedAt = &disconnectTime
|
||||
api.Telemetry.Report(&telemetry.Snapshot{
|
||||
UserTailnetConnections: []telemetry.UserTailnetConnection{connectionTelemetryEvent},
|
||||
|
||||
@@ -3021,7 +3021,7 @@ func TestUserTailnetTelemetry(t *testing.T) {
|
||||
q.Set("version", "2.0")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
predialTime := time.Now()
|
||||
predialTime := dbtime.Now()
|
||||
|
||||
//nolint:bodyclose // websocket package closes this for you
|
||||
wsConn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
|
||||
|
||||
+70
-9
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -71,12 +72,13 @@ type ChatMessageUsage struct {
|
||||
type ChatMessagePartType string
|
||||
|
||||
const (
|
||||
ChatMessagePartTypeText ChatMessagePartType = "text"
|
||||
ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning"
|
||||
ChatMessagePartTypeToolCall ChatMessagePartType = "tool-call"
|
||||
ChatMessagePartTypeToolResult ChatMessagePartType = "tool-result"
|
||||
ChatMessagePartTypeSource ChatMessagePartType = "source"
|
||||
ChatMessagePartTypeFile ChatMessagePartType = "file"
|
||||
ChatMessagePartTypeText ChatMessagePartType = "text"
|
||||
ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning"
|
||||
ChatMessagePartTypeToolCall ChatMessagePartType = "tool-call"
|
||||
ChatMessagePartTypeToolResult ChatMessagePartType = "tool-result"
|
||||
ChatMessagePartTypeSource ChatMessagePartType = "source"
|
||||
ChatMessagePartTypeFile ChatMessagePartType = "file"
|
||||
ChatMessagePartTypeFileReference ChatMessagePartType = "file-reference"
|
||||
)
|
||||
|
||||
// ChatMessagePart is a structured chunk of a chat message.
|
||||
@@ -96,19 +98,37 @@ type ChatMessagePart struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
FileID uuid.NullUUID `json:"file_id,omitempty" format:"uuid"`
|
||||
// The following fields are only set when Type is
|
||||
// ChatInputPartTypeFileReference.
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
StartLine int `json:"start_line,omitempty"`
|
||||
EndLine int `json:"end_line,omitempty"`
|
||||
// The code content from the diff that was commented on.
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
// ChatInputPartType represents an input part type for user chat input.
|
||||
type ChatInputPartType string
|
||||
|
||||
const (
|
||||
ChatInputPartTypeText ChatInputPartType = "text"
|
||||
ChatInputPartTypeText ChatInputPartType = "text"
|
||||
ChatInputPartTypeFile ChatInputPartType = "file"
|
||||
ChatInputPartTypeFileReference ChatInputPartType = "file-reference"
|
||||
)
|
||||
|
||||
// ChatInputPart is a single user input part for creating a chat.
|
||||
type ChatInputPart struct {
|
||||
Type ChatInputPartType `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Type ChatInputPartType `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
FileID uuid.UUID `json:"file_id,omitempty" format:"uuid"`
|
||||
// The following fields are only set when Type is
|
||||
// ChatInputPartTypeFileReference.
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
StartLine int `json:"start_line,omitempty"`
|
||||
EndLine int `json:"end_line,omitempty"`
|
||||
// The code content from the diff that was commented on.
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
// CreateChatRequest is the request to create a new chat.
|
||||
@@ -141,6 +161,11 @@ type CreateChatMessageResponse struct {
|
||||
Queued bool `json:"queued"`
|
||||
}
|
||||
|
||||
// UploadChatFileResponse is the response from uploading a chat file.
|
||||
type UploadChatFileResponse struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
}
|
||||
|
||||
// ChatWithMessages is a chat along with its messages.
|
||||
type ChatWithMessages struct {
|
||||
Chat Chat `json:"chat"`
|
||||
@@ -938,6 +963,42 @@ func (c *Client) GetChatDiffContents(ctx context.Context, chatID uuid.UUID) (Cha
|
||||
return diff, json.NewDecoder(res.Body).Decode(&diff)
|
||||
}
|
||||
|
||||
// UploadChatFile uploads a file for use in chat messages.
|
||||
func (c *Client) UploadChatFile(ctx context.Context, organizationID uuid.UUID, contentType string, filename string, rd io.Reader) (UploadChatFileResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/files?organization=%s", organizationID), rd, func(r *http.Request) {
|
||||
r.Header.Set("Content-Type", contentType)
|
||||
if filename != "" {
|
||||
r.Header.Set("Content-Disposition", mime.FormatMediaType("attachment", map[string]string{"filename": filename}))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return UploadChatFileResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return UploadChatFileResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp UploadChatFileResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// GetChatFile retrieves a previously uploaded chat file by ID.
|
||||
func (c *Client) GetChatFile(ctx context.Context, fileID uuid.UUID) ([]byte, string, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/files/%s", fileID), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, "", ReadBodyAsError(res)
|
||||
}
|
||||
data, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return data, res.Header.Get("Content-Type"), nil
|
||||
}
|
||||
|
||||
func formatChatStreamResponseError(response Response) string {
|
||||
message := strings.TrimSpace(response.Message)
|
||||
detail := strings.TrimSpace(response.Detail)
|
||||
|
||||
@@ -28,7 +28,7 @@ UI.
|
||||
| Agent | Module | Min version | Support | Tracking | Session data paths | Min storage |
|
||||
|-----------------|----------------------------------------------------------------------------------|-------------|---------------|--------------------------------------------------------------|------------------------------------------------------|---------------------------|
|
||||
| Claude Code | [claude-code](https://registry.coder.com/modules/coder/claude-code) | >= 4.8.0 | Full | - | `~/.claude/` | 100 MB (can grow to GB) |
|
||||
| Codex | [codex](https://registry.coder.com/modules/coder-labs/codex) | - | Partial | [registry#740](https://github.com/coder/registry/issues/740) | `~/.codex/`, `~/.codex-module/` | 100 MB |
|
||||
| Codex | [codex](https://registry.coder.com/modules/coder-labs/codex) | >= 4.2.0 | Full | - | `~/.codex/`, `~/.codex-module/` | 100 MB |
|
||||
| Copilot | [copilot](https://registry.coder.com/modules/coder-labs/copilot) | - | Partial | [registry#741](https://github.com/coder/registry/issues/741) | `~/.copilot/` | 50 MB |
|
||||
| OpenCode | [opencode](https://registry.coder.com/modules/coder-labs/opencode) | - | Partial | [registry#742](https://github.com/coder/registry/issues/742) | `~/.local/share/opencode/`, `~/.config/opencode/` | 50 MB |
|
||||
| Auggie | [auggie](https://registry.coder.com/modules/coder-labs/auggie) | - | Planned | [registry#743](https://github.com/coder/registry/issues/743) | `~/.augment/` | 50 MB |
|
||||
|
||||
@@ -132,6 +132,18 @@ are queued and delivered when the agent completes its current step, so there is
|
||||
no need to wait for a response before providing additional context or changing
|
||||
direction.
|
||||
|
||||
### Image attachments
|
||||
|
||||
Users can attach images to chat messages by pasting from the clipboard, dragging
|
||||
files into the input area, or using the attachment button. Supported formats are
|
||||
PNG, JPEG, GIF, and WebP up to 10 MB per file. Images are sent to the model as
|
||||
multimodal content alongside the text prompt.
|
||||
|
||||
This is useful for sharing screenshots of errors, UI mockups, terminal output,
|
||||
or other visual context that helps the agent understand the task. Messages can
|
||||
contain images alone or combined with text. Image attachments require a model
|
||||
that supports vision input.
|
||||
|
||||
## Security benefits of the control plane architecture
|
||||
|
||||
Running the agent loop in the control plane rather than inside the developer
|
||||
|
||||
Generated
+78
@@ -1,5 +1,83 @@
|
||||
# Chats
|
||||
|
||||
## Upload a chat file
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X POST http://coder-server:8080/api/v2/chats/files?organization=497f6eca-6276-4993-bfeb-53cbbbba6f08 \
|
||||
-H 'Accept: application/json' \
|
||||
-H 'Content-Type: string' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`POST /chats/files`
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|----------------|--------|--------------|----------|-----------------------------------------------------------------------------------|
|
||||
| `Content-Type` | header | string | true | Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp) |
|
||||
| `organization` | query | string(uuid) | true | Organization ID |
|
||||
|
||||
### Example responses
|
||||
|
||||
> 201 Response
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"
|
||||
}
|
||||
```
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|----------------------------------------------------------------------------|--------------------------|------------------------------------------------------------------------------|
|
||||
| 201 | [Created](https://tools.ietf.org/html/rfc7231#section-6.3.2) | Created | [codersdk.UploadChatFileResponse](schemas.md#codersdkuploadchatfileresponse) |
|
||||
| 400 | [Bad Request](https://tools.ietf.org/html/rfc7231#section-6.5.1) | Bad Request | [codersdk.Response](schemas.md#codersdkresponse) |
|
||||
| 401 | [Unauthorized](https://tools.ietf.org/html/rfc7235#section-3.1) | Unauthorized | [codersdk.Response](schemas.md#codersdkresponse) |
|
||||
| 413 | [Payload Too Large](https://tools.ietf.org/html/rfc7231#section-6.5.11) | Request Entity Too Large | [codersdk.Response](schemas.md#codersdkresponse) |
|
||||
| 500 | [Internal Server Error](https://tools.ietf.org/html/rfc7231#section-6.6.1) | Internal Server Error | [codersdk.Response](schemas.md#codersdkresponse) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Get a chat file
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X GET http://coder-server:8080/api/v2/chats/files/{file} \
|
||||
-H 'Accept: */*' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`GET /chats/files/{file}`
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|--------|------|--------------|----------|-------------|
|
||||
| `file` | path | string(uuid) | true | File ID |
|
||||
|
||||
### Example responses
|
||||
|
||||
> 400 Response
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|----------------------------------------------------------------------------|-----------------------|--------------------------------------------------|
|
||||
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | |
|
||||
| 400 | [Bad Request](https://tools.ietf.org/html/rfc7231#section-6.5.1) | Bad Request | [codersdk.Response](schemas.md#codersdkresponse) |
|
||||
| 401 | [Unauthorized](https://tools.ietf.org/html/rfc7235#section-3.1) | Unauthorized | [codersdk.Response](schemas.md#codersdkresponse) |
|
||||
| 404 | [Not Found](https://tools.ietf.org/html/rfc7231#section-6.5.4) | Not Found | [codersdk.Response](schemas.md#codersdkresponse) |
|
||||
| 500 | [Internal Server Error](https://tools.ietf.org/html/rfc7231#section-6.6.1) | Internal Server Error | [codersdk.Response](schemas.md#codersdkresponse) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Archive a chat
|
||||
|
||||
### Code samples
|
||||
|
||||
Generated
+14
@@ -9847,6 +9847,20 @@ If the schedule is empty, the user will be updated to use the default schedule.|
|
||||
|----------|---------|----------|--------------|-------------|
|
||||
| `ttl_ms` | integer | false | | |
|
||||
|
||||
## codersdk.UploadChatFileResponse
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------|--------|----------|--------------|-------------|
|
||||
| `id` | string | false | | |
|
||||
|
||||
## codersdk.UploadResponse
|
||||
|
||||
```json
|
||||
|
||||
Generated
+3
-2
@@ -12,11 +12,12 @@ coder task send [flags] <task> [<input> | --stdin]
|
||||
## Description
|
||||
|
||||
```console
|
||||
- Send direct input to a task.:
|
||||
Send input to a task. If the task is paused, it will be automatically resumed before input is sent. If the task is initializing, it will wait for the task to become ready.
|
||||
- Send direct input to a task:
|
||||
|
||||
$ coder task send task1 "Please also add unit tests"
|
||||
|
||||
- Send input from stdin to a task.:
|
||||
- Send input from stdin to a task:
|
||||
|
||||
$ echo "Please also add unit tests" | coder task send task1 --stdin
|
||||
```
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# 1.93.1
|
||||
FROM rust:slim@sha256:c0a38f5662afdb298898da1d70b909af4bda4e0acff2dc52aea6360a9b9c6956 AS rust-utils
|
||||
FROM rust:slim@sha256:d6782f2b326a10eaf593eb90cafc34a03a287b4a25fe4d0c693c90304b06f6d7 AS rust-utils
|
||||
# Install rust helper programs
|
||||
ENV CARGO_INSTALL_ROOT=/tmp/
|
||||
# Use more reliable mirrors for Debian packages
|
||||
|
||||
@@ -397,7 +397,7 @@ module "code-server" {
|
||||
module "vscode-web" {
|
||||
count = contains(jsondecode(data.coder_parameter.ide_choices.value), "vscode-web") ? data.coder_workspace.me.start_count : 0
|
||||
source = "dev.registry.coder.com/coder/vscode-web/coder"
|
||||
version = "1.4.3"
|
||||
version = "1.5.0"
|
||||
agent_id = coder_agent.dev.id
|
||||
folder = local.repo_dir
|
||||
extensions = ["github.copilot"]
|
||||
@@ -873,7 +873,7 @@ resource "coder_script" "boundary_config_setup" {
|
||||
module "claude-code" {
|
||||
count = data.coder_task.me.enabled ? data.coder_workspace.me.start_count : 0
|
||||
source = "dev.registry.coder.com/coder/claude-code/coder"
|
||||
version = "4.7.5"
|
||||
version = "4.8.0"
|
||||
enable_boundary = true
|
||||
agent_id = coder_agent.dev.id
|
||||
workdir = local.repo_dir
|
||||
|
||||
@@ -46,6 +46,7 @@ func Test_ProxyServer_Headers(t *testing.T) {
|
||||
"--primary-access-url", srv.URL,
|
||||
"--proxy-session-token", "test-token",
|
||||
"--access-url", "http://localhost:8080",
|
||||
"--http-address", ":0",
|
||||
"--header", fmt.Sprintf("%s=%s", headerName1, headerVal1),
|
||||
"--header-command", fmt.Sprintf("printf %s=%s", headerName2, headerVal2),
|
||||
)
|
||||
@@ -97,7 +98,7 @@ func TestWorkspaceProxy_Server_PrometheusEnabled(t *testing.T) {
|
||||
"--primary-access-url", srv.URL,
|
||||
"--proxy-session-token", "test-token",
|
||||
"--access-url", "http://foobar:3001",
|
||||
"--http-address", fmt.Sprintf("127.0.0.1:%d", testutil.RandomPort(t)),
|
||||
"--http-address", ":0",
|
||||
"--prometheus-enable",
|
||||
"--prometheus-address", fmt.Sprintf("127.0.0.1:%d", prometheusPort),
|
||||
)
|
||||
|
||||
@@ -111,7 +111,7 @@ func (c MultiReplicaSubscribeConfig) clock() quartz.Clock {
|
||||
func NewMultiReplicaSubscribeFn(
|
||||
cfg MultiReplicaSubscribeConfig,
|
||||
) osschatd.SubscribeFn {
|
||||
return func(ctx context.Context, params osschatd.SubscribeFnParams) (<-chan codersdk.ChatStreamEvent, func()) {
|
||||
return func(ctx context.Context, params osschatd.SubscribeFnParams) <-chan codersdk.ChatStreamEvent {
|
||||
chatID := params.ChatID
|
||||
requestHeader := params.RequestHeader
|
||||
logger := params.Logger
|
||||
@@ -149,18 +149,13 @@ func NewMultiReplicaSubscribeFn(
|
||||
|
||||
// Merge all event sources.
|
||||
mergedEvents := make(chan codersdk.ChatStreamEvent, 128)
|
||||
var allCancels []func()
|
||||
if relayCancel != nil {
|
||||
allCancels = append(allCancels, relayCancel)
|
||||
}
|
||||
|
||||
// Channel for async relay establishment.
|
||||
type relayResult struct {
|
||||
parts <-chan codersdk.ChatStreamEvent
|
||||
cancel func()
|
||||
workerID uuid.UUID // the worker this dial targeted
|
||||
}
|
||||
relayReadyCh := make(chan relayResult, 1)
|
||||
relayReadyCh := make(chan relayResult, 4)
|
||||
|
||||
// Per-dial context so in-flight dials can be canceled when
|
||||
// a new dial is initiated or the relay is closed.
|
||||
@@ -182,15 +177,18 @@ func NewMultiReplicaSubscribeFn(
|
||||
dialCancel()
|
||||
dialCancel = nil
|
||||
}
|
||||
// Drain any buffered relay result from a canceled
|
||||
// dial.
|
||||
select {
|
||||
case result := <-relayReadyCh:
|
||||
if result.cancel != nil {
|
||||
result.cancel()
|
||||
// Drain all buffered relay results from canceled dials.
|
||||
for {
|
||||
select {
|
||||
case result := <-relayReadyCh:
|
||||
if result.cancel != nil {
|
||||
result.cancel()
|
||||
}
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
default:
|
||||
}
|
||||
drained:
|
||||
expectedWorkerID = uuid.Nil
|
||||
if relayCancel != nil {
|
||||
relayCancel()
|
||||
@@ -403,19 +401,11 @@ func NewMultiReplicaSubscribeFn(
|
||||
}
|
||||
}()
|
||||
|
||||
// The cancel function tears down the relay state
|
||||
// indirectly: the merge goroutine owns all relay state
|
||||
// (reconnectTimer, relayCancel, dialCancel, etc.) and
|
||||
// cleans it up via its defer closeRelay() when ctx is
|
||||
// canceled.
|
||||
cancel := func() {
|
||||
for _, cancelFn := range allCancels {
|
||||
if cancelFn != nil {
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
return mergedEvents, cancel
|
||||
// Cleanup is driven by ctx cancellation: the merge
|
||||
// goroutine owns all relay state (reconnectTimer,
|
||||
// relayCancel, dialCancel, etc.) and tears it down
|
||||
// via defer closeRelay() when ctx is done.
|
||||
return mergedEvents
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ data "coder_task" "me" {}
|
||||
module "claude-code" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
source = "registry.coder.com/coder/claude-code/coder"
|
||||
version = "4.7.5"
|
||||
version = "4.8.0"
|
||||
agent_id = coder_agent.main.id
|
||||
workdir = "/home/coder/projects"
|
||||
order = 999
|
||||
|
||||
@@ -99,7 +99,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/charmbracelet/bubbles v1.0.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/glamour v0.10.0
|
||||
github.com/charmbracelet/glamour v1.0.0
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
|
||||
github.com/chromedp/cdproto v0.0.0-20250724212937-08a3db8b4327
|
||||
github.com/chromedp/chromedp v0.14.1
|
||||
@@ -194,11 +194,11 @@ require (
|
||||
github.com/zclconf/go-cty-yaml v1.2.0
|
||||
go.mozilla.org/pkcs7 v0.9.0
|
||||
go.nhat.io/otelsql v0.16.0
|
||||
go.opentelemetry.io/otel v1.40.0
|
||||
go.opentelemetry.io/otel v1.42.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
|
||||
go.opentelemetry.io/otel/sdk v1.40.0
|
||||
go.opentelemetry.io/otel/trace v1.40.0
|
||||
go.opentelemetry.io/otel/sdk v1.42.0
|
||||
go.opentelemetry.io/otel/trace v1.42.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
go.uber.org/goleak v1.3.1-0.20240429205332-517bace7cc29
|
||||
go.uber.org/mock v0.6.0
|
||||
@@ -206,10 +206,10 @@ require (
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
|
||||
golang.org/x/mod v0.33.0
|
||||
golang.org/x/net v0.50.0
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/sys v0.41.0
|
||||
golang.org/x/net v0.51.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.42.0
|
||||
golang.org/x/term v0.40.0
|
||||
golang.org/x/text v0.34.0
|
||||
golang.org/x/tools v0.42.0
|
||||
@@ -443,8 +443,8 @@ require (
|
||||
go.opentelemetry.io/collector/pdata/pprofile v0.121.0 // indirect
|
||||
go.opentelemetry.io/collector/semconv v0.123.0 // indirect
|
||||
go.opentelemetry.io/contrib v1.19.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0
|
||||
go.opentelemetry.io/otel/metric v1.40.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0
|
||||
go.opentelemetry.io/otel/metric v1.42.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.uber.org/zap v1.27.0 // indirect
|
||||
@@ -491,7 +491,6 @@ require (
|
||||
github.com/go-git/go-git/v5 v5.17.0
|
||||
github.com/mark3labs/mcp-go v0.38.0
|
||||
github.com/openai/openai-go/v3 v3.15.0
|
||||
github.com/sergi/go-diff v1.4.0
|
||||
gonum.org/v1/gonum v0.17.0
|
||||
)
|
||||
|
||||
@@ -541,10 +540,8 @@ require (
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
|
||||
github.com/cyphar/filepath-securejoin v0.5.1 // indirect
|
||||
github.com/daixiang0/gci v0.13.7 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect
|
||||
github.com/esiqveland/notify v0.13.3 // indirect
|
||||
@@ -565,7 +562,6 @@ require (
|
||||
github.com/kaptinlin/jsonpointer v0.4.10 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.6.10 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.10 // indirect
|
||||
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c // indirect
|
||||
github.com/mattn/go-shellwords v1.0.12 // indirect
|
||||
@@ -576,7 +572,6 @@ require (
|
||||
github.com/openai/openai-go v1.12.0 // indirect
|
||||
github.com/openai/openai-go/v2 v2.7.1 // indirect
|
||||
github.com/package-url/packageurl-go v0.1.3 // indirect
|
||||
github.com/pjbgf/sha1cd v0.3.2 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/rhysd/actionlint v1.7.10 // indirect
|
||||
@@ -584,7 +579,6 @@ require (
|
||||
github.com/samber/lo v1.51.0 // indirect
|
||||
github.com/sergeymakinen/go-bmp v1.0.0 // indirect
|
||||
github.com/sergeymakinen/go-ico v1.0.0-beta.0 // indirect
|
||||
github.com/skeema/knownhosts v1.3.1 // indirect
|
||||
github.com/sony/gobreaker/v2 v2.3.0 // indirect
|
||||
github.com/spf13/cobra v1.10.2 // indirect
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect
|
||||
@@ -594,14 +588,13 @@ require (
|
||||
github.com/urfave/cli/v2 v2.27.5 // indirect
|
||||
github.com/vektah/gqlparser/v2 v2.5.28 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
|
||||
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.40.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect
|
||||
|
||||
@@ -99,7 +99,6 @@ github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6Xge
|
||||
github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4=
|
||||
github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
|
||||
github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84=
|
||||
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw=
|
||||
@@ -157,8 +156,6 @@ github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloD
|
||||
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
|
||||
github.com/armon/go-radix v1.0.1-0.20221118154546-54df44f2176c h1:651/eoCRnQ7YtSjAnSzRucrJz+3iGEFt+ysraELS81M=
|
||||
github.com/armon/go-radix v1.0.1-0.20221118154546-54df44f2176c/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696 h1:7hAl/81gNUjmSCqJYKe1aTIVY4myjapaSALdCko19tI=
|
||||
github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696/go.mod h1:acJQ8t0ohCGuMN3O+Pv0V0hgMxNYDlvdk+VTfyZmbYo=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
@@ -274,8 +271,8 @@ github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5f
|
||||
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/glamour v0.10.0 h1:MtZvfwsYCx8jEPFJm3rIBFIMZUfUJ765oX8V6kXldcY=
|
||||
github.com/charmbracelet/glamour v0.10.0/go.mod h1:f+uf+I/ChNmqo087elLnVdCiVgjSKWuXa/l6NU2ndYk=
|
||||
github.com/charmbracelet/glamour v1.0.0 h1:AWMLOVFHTsysl4WV8T8QgkQ0s/ZNZo7CiE4WKhk8l08=
|
||||
github.com/charmbracelet/glamour v1.0.0/go.mod h1:DSdohgOBkMr2ZQNhw4LZxSGpx3SvpeujNoXrQyH2hxo=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
@@ -514,8 +511,6 @@ github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66D
|
||||
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic=
|
||||
github.com/go-git/go-billy/v5 v5.8.0 h1:I8hjc3LbBlXTtVuFNJuwYuMiHvQJDq1AT6u4DwDzZG0=
|
||||
github.com/go-git/go-billy/v5 v5.8.0/go.mod h1:RpvI/rw4Vr5QA+Z60c6d6LXH0rYJo0uD5SqfmrrheCY=
|
||||
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4=
|
||||
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
|
||||
github.com/go-git/go-git/v5 v5.17.0 h1:AbyI4xf+7DsjINHMu35quAh4wJygKBKBuXVjV/pxesM=
|
||||
github.com/go-git/go-git/v5 v5.17.0/go.mod h1:f82C4YiLx+Lhi8eHxltLeGC5uBTXSFa6PC5WW9o4SjI=
|
||||
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
|
||||
@@ -934,8 +929,6 @@ github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0 h1:jrYnow5+hy3WRDC
|
||||
github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0/go.mod h1:b52bVQRRPObe+yyBl0TxNfhesL0nedD4Cht0/zx55Ew=
|
||||
github.com/olekukonko/tablewriter v1.1.3 h1:VSHhghXxrP0JHl+0NnKid7WoEmd9/urKRJLysb70nnA=
|
||||
github.com/olekukonko/tablewriter v1.1.3/go.mod h1:9VU0knjhmMkXjnMKrZ3+L2JhhtsQ/L38BbL3CRNE8tM=
|
||||
github.com/onsi/gomega v1.35.1 h1:Cwbd75ZBPxFSuZ6T+rN/WCb/gOc6YgFBXLlZLhC7Ds4=
|
||||
github.com/onsi/gomega v1.35.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog=
|
||||
github.com/open-policy-agent/opa v1.6.0 h1:/S/cnNQJ2MUMNzizHPbisTWBHowmLkPrugY5jjkPlRQ=
|
||||
github.com/open-policy-agent/opa v1.6.0/go.mod h1:zFmw4P+W62+CWGYRDDswfVYSCnPo6oYaktQnfIaRFC4=
|
||||
github.com/open-telemetry/opentelemetry-collector-contrib/pkg/sampling v0.120.1 h1:lK/3zr73guK9apbXTcnDnYrC0YCQ25V3CIULYz3k2xU=
|
||||
@@ -1275,11 +1268,11 @@ go.opentelemetry.io/contrib/detectors/gcp v1.40.0 h1:Awaf8gmW99tZTOWqkLCOl6aw1/r
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0/go.mod h1:99OY9ZCqyLkzJLTh5XhECpLRSxcZl+ZDKBEO+jMBFR4=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 h1:XmiuHzgJt067+a6kwyAzkhXooYVv3/TOw9cM2VfJgUM=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0/go.mod h1:KDgtbWKTQs4bM+VPUr6WlL9m/WXcmkCcBlIzqxPGzmI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
|
||||
go.opentelemetry.io/otel v1.3.0/go.mod h1:PWIKzi6JCp7sM0k9yZ43VX+T345uNbAkDKwHVjb2PTs=
|
||||
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
|
||||
go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
|
||||
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
|
||||
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I=
|
||||
@@ -1290,16 +1283,16 @@ go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0 h1:5gn2urDL/FBnK8
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0/go.mod h1:0fBG6ZJxhqByfFZDwSwpZGzJU671HkwpWaNe2t4VUPI=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0 h1:SNhVp/9q4Go/XHBkQ1/d5u9P/U+L1yaGPoi0x+mStaI=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0/go.mod h1:tx8OOlGH6R4kLV67YaYO44GFXloEjGPZuMjEkaaqIp4=
|
||||
go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
|
||||
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
|
||||
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
|
||||
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
|
||||
go.opentelemetry.io/otel/sdk v1.3.0/go.mod h1:rIo4suHNhQwBIPg9axF8V9CA72Wz2mKF1teNrup8yzs=
|
||||
go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
|
||||
go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
|
||||
go.opentelemetry.io/otel/trace v1.3.0/go.mod h1:c/VDhno8888bvQYmbYLqe41/Ldmr/KKunbvWM4/fEjk=
|
||||
go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
|
||||
go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
|
||||
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
|
||||
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
|
||||
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
@@ -1329,7 +1322,6 @@ golang.org/x/crypto v0.0.0-20200117160349-530e935923ad/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
@@ -1358,7 +1350,6 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
@@ -1367,10 +1358,10 @@ golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
|
||||
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
|
||||
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
|
||||
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -1382,8 +1373,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -1398,7 +1389,6 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -1409,7 +1399,6 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -1423,8 +1412,8 @@ golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 h1:bTLqdHv7xrGlFbvf5/TXNxy/iUwwdkjhqQTJDjW7aj0=
|
||||
golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4/go.mod h1:g5NllXBEermZrmR51cJDQxmJUHUOfRAaNyWBM+R+548=
|
||||
@@ -1443,7 +1432,6 @@ golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
@@ -1507,7 +1495,6 @@ gopkg.in/DataDog/dd-trace-go.v1 v1.74.0 h1:wScziU1ff6Bnyr8MEyxATPSLJdnLxKz3p6RsA
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.74.0/go.mod h1:ReNBsNfnsjVC7GsCe80zRcykL/n+nxvsNrg3NbjuleM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/ini.v1 v1.67.1 h1:tVBILHy0R6e4wkYOn3XmiITt/hEVH4TFMYvAX2Ytz6k=
|
||||
|
||||
Vendored
+2
-1
@@ -1,5 +1,6 @@
|
||||
/// <reference types="next" />
|
||||
/// <reference types="next/image-types/global" />
|
||||
/// <reference path="./.next/types/routes.d.ts" />
|
||||
|
||||
// NOTE: This file should not be edited
|
||||
// see https://nextjs.org/docs/basic-features/typescript for more information.
|
||||
// see https://nextjs.org/docs/pages/api-reference/config/typescript for more information.
|
||||
|
||||
Executable
+13
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env bash
|
||||
# Shield this worktree against shared config hooksPath poisoning.
|
||||
# Worktree-scoped config overrides the shared .git/config, so even if
|
||||
# another worktree runs `git config core.hooksPath /dev/null`, this
|
||||
# worktree continues to use the correct hooks.
|
||||
#
|
||||
# This hook runs on `git worktree add` and `git checkout`/`git switch`.
|
||||
# Only needed in linked worktrees where shared config can be poisoned
|
||||
# by another worktree. Skipped in the main checkout to avoid errors
|
||||
# when extensions.worktreeConfig is not set (e.g. fresh clones).
|
||||
if [[ "$(git rev-parse --git-dir)" != "$(git rev-parse --git-common-dir)" ]]; then
|
||||
git config --worktree core.hooksPath scripts/githooks
|
||||
fi
|
||||
@@ -16,4 +16,8 @@ set -euo pipefail
|
||||
cd "$(git rev-parse --show-toplevel)"
|
||||
unset GIT_DIR
|
||||
|
||||
# In linked worktrees, set worktree-scoped hooksPath to override shared config.
|
||||
if [[ "$(git rev-parse --git-dir)" != "$(git rev-parse --git-common-dir)" ]]; then
|
||||
git config --worktree core.hooksPath scripts/githooks
|
||||
fi
|
||||
exec make pre-commit
|
||||
|
||||
@@ -19,4 +19,8 @@ set -euo pipefail
|
||||
cd "$(git rev-parse --show-toplevel)"
|
||||
unset GIT_DIR
|
||||
|
||||
# In linked worktrees, set worktree-scoped hooksPath to override shared config.
|
||||
if [[ "$(git rev-parse --git-dir)" != "$(git rev-parse --git-common-dir)" ]]; then
|
||||
git config --worktree core.hooksPath scripts/githooks
|
||||
fi
|
||||
exec make pre-push
|
||||
|
||||
Executable
+41
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env bash
|
||||
# timed-shell.sh wraps bash with per-target wall-clock timing.
|
||||
#
|
||||
# Recipe invocation: timed-shell.sh <target> -ceu <recipe>
|
||||
# $(shell ...) calls: timed-shell.sh -c <command>
|
||||
#
|
||||
# Enable via Makefile:
|
||||
# SHELL := $(CURDIR)/scripts/lib/timed-shell.sh
|
||||
# .SHELLFLAGS = $@ -ceu
|
||||
#
|
||||
# $(shell ...) uses SHELL but passes -c directly, not .SHELLFLAGS.
|
||||
# Detect this and delegate to bash without timing output.
|
||||
if [[ $1 == -* ]]; then
|
||||
exec bash "$@"
|
||||
fi
|
||||
|
||||
set -eu
|
||||
|
||||
target=$1
|
||||
shift
|
||||
|
||||
bold=$(tput bold 2>/dev/null) || true
|
||||
green=$(tput setaf 2 2>/dev/null) || true
|
||||
red=$(tput setaf 1 2>/dev/null) || true
|
||||
reset=$(tput sgr0 2>/dev/null) || true
|
||||
|
||||
start=$(date +%s)
|
||||
echo "${bold}==> ${target}${reset}"
|
||||
|
||||
set +e
|
||||
bash "$@"
|
||||
rc=$?
|
||||
set -e
|
||||
|
||||
elapsed=$(($(date +%s) - start))
|
||||
if ((rc == 0)); then
|
||||
echo "${bold}${green}==> ${target} completed in ${elapsed}s${reset}"
|
||||
else
|
||||
echo "${bold}${red}==> ${target} FAILED after ${elapsed}s${reset}" >&2
|
||||
exit $rc
|
||||
fi
|
||||
@@ -43,7 +43,9 @@ func main() {
|
||||
"version": r.Header.Get("X-Telemetry-Version"),
|
||||
"data": json.RawMessage(body),
|
||||
}
|
||||
_ = enc.Encode(output)
|
||||
if err := enc.Encode(output); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "Error encoding telemetry output: %v\n", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
|
||||
+12
-2
@@ -211,8 +211,18 @@ export const verifyParameters = async (
|
||||
case "bool":
|
||||
{
|
||||
const parameterField = parameterLabel.locator("input");
|
||||
const value = await parameterField.isChecked();
|
||||
expect(value.toString()).toEqual(buildParameter.value);
|
||||
// Dynamic parameters can hydrate after initial render
|
||||
// and reset checkbox state. Use auto-retrying assertions
|
||||
// instead of a one-shot isChecked() snapshot.
|
||||
if (buildParameter.value === "true") {
|
||||
await expect(parameterField).toBeChecked({
|
||||
timeout: 15_000,
|
||||
});
|
||||
} else {
|
||||
await expect(parameterField).not.toBeChecked({
|
||||
timeout: 15_000,
|
||||
});
|
||||
}
|
||||
}
|
||||
break;
|
||||
case "string":
|
||||
|
||||
@@ -66,6 +66,9 @@ export default defineConfig({
|
||||
},
|
||||
webServer: {
|
||||
url: `http://localhost:${coderPort}/api/v2/deployment/config`,
|
||||
// The default timeout is 60s, but `go run` compilation with the
|
||||
// embed tag can take longer on CI.
|
||||
timeout: 120_000,
|
||||
command: [
|
||||
`go run -tags embed ${path.join(__dirname, "../../enterprise/cmd/coder")}`,
|
||||
"server",
|
||||
|
||||
@@ -29,6 +29,8 @@
|
||||
<meta property="logo-url" content="{{ .LogoURL }}" />
|
||||
<meta property="tasks-tab-visible" content="{{ .TasksTabVisible }}" />
|
||||
<meta property="agents-tab-visible" content="{{ .AgentsTabVisible }}" />
|
||||
<meta property="permissions" content="{{ .Permissions }}" />
|
||||
<meta property="organizations" content="{{ .Organizations }}" />
|
||||
<link
|
||||
rel="alternate icon"
|
||||
type="image/png"
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
{
|
||||
"viewAllUsers": {
|
||||
"object": { "resource_type": "user" },
|
||||
"action": "read"
|
||||
},
|
||||
"updateUsers": {
|
||||
"object": { "resource_type": "user" },
|
||||
"action": "update"
|
||||
},
|
||||
"createUser": {
|
||||
"object": { "resource_type": "user" },
|
||||
"action": "create"
|
||||
},
|
||||
"createTemplates": {
|
||||
"object": { "resource_type": "template", "any_org": true },
|
||||
"action": "create"
|
||||
},
|
||||
"updateTemplates": {
|
||||
"object": { "resource_type": "template" },
|
||||
"action": "update"
|
||||
},
|
||||
"deleteTemplates": {
|
||||
"object": { "resource_type": "template" },
|
||||
"action": "delete"
|
||||
},
|
||||
"viewDeploymentConfig": {
|
||||
"object": { "resource_type": "deployment_config" },
|
||||
"action": "read"
|
||||
},
|
||||
"editDeploymentConfig": {
|
||||
"object": { "resource_type": "deployment_config" },
|
||||
"action": "update"
|
||||
},
|
||||
"viewDeploymentStats": {
|
||||
"object": { "resource_type": "deployment_stats" },
|
||||
"action": "read"
|
||||
},
|
||||
"readWorkspaceProxies": {
|
||||
"object": { "resource_type": "workspace_proxy" },
|
||||
"action": "read"
|
||||
},
|
||||
"editWorkspaceProxies": {
|
||||
"object": { "resource_type": "workspace_proxy" },
|
||||
"action": "create"
|
||||
},
|
||||
"createOrganization": {
|
||||
"object": { "resource_type": "organization" },
|
||||
"action": "create"
|
||||
},
|
||||
"viewAnyGroup": {
|
||||
"object": { "resource_type": "group" },
|
||||
"action": "read"
|
||||
},
|
||||
"createGroup": {
|
||||
"object": { "resource_type": "group" },
|
||||
"action": "create"
|
||||
},
|
||||
"viewAllLicenses": {
|
||||
"object": { "resource_type": "license" },
|
||||
"action": "read"
|
||||
},
|
||||
"viewNotificationTemplate": {
|
||||
"object": { "resource_type": "notification_template" },
|
||||
"action": "read"
|
||||
},
|
||||
"viewOrganizationIDPSyncSettings": {
|
||||
"object": { "resource_type": "idpsync_settings" },
|
||||
"action": "read"
|
||||
},
|
||||
"viewAnyMembers": {
|
||||
"object": { "resource_type": "organization_member", "any_org": true },
|
||||
"action": "read"
|
||||
},
|
||||
"editAnyGroups": {
|
||||
"object": { "resource_type": "group", "any_org": true },
|
||||
"action": "update"
|
||||
},
|
||||
"assignAnyRoles": {
|
||||
"object": { "resource_type": "assign_org_role", "any_org": true },
|
||||
"action": "assign"
|
||||
},
|
||||
"viewAnyIdpSyncSettings": {
|
||||
"object": { "resource_type": "idpsync_settings", "any_org": true },
|
||||
"action": "read"
|
||||
},
|
||||
"editAnySettings": {
|
||||
"object": { "resource_type": "organization", "any_org": true },
|
||||
"action": "update"
|
||||
},
|
||||
"viewAnyAuditLog": {
|
||||
"object": { "resource_type": "audit_log", "any_org": true },
|
||||
"action": "read"
|
||||
},
|
||||
"viewAnyConnectionLog": {
|
||||
"object": { "resource_type": "connection_log", "any_org": true },
|
||||
"action": "read"
|
||||
},
|
||||
"viewDebugInfo": {
|
||||
"object": { "resource_type": "debug_info" },
|
||||
"action": "read"
|
||||
},
|
||||
"viewAnyAIBridgeInterception": {
|
||||
"object": { "resource_type": "aibridge_interception", "any_org": true },
|
||||
"action": "read"
|
||||
},
|
||||
"createOAuth2App": {
|
||||
"object": { "resource_type": "oauth2_app" },
|
||||
"action": "create"
|
||||
},
|
||||
"editOAuth2App": {
|
||||
"object": { "resource_type": "oauth2_app" },
|
||||
"action": "update"
|
||||
},
|
||||
"deleteOAuth2App": {
|
||||
"object": { "resource_type": "oauth2_app" },
|
||||
"action": "delete"
|
||||
},
|
||||
"viewOAuth2AppSecrets": {
|
||||
"object": { "resource_type": "oauth2_app_secret" },
|
||||
"action": "read"
|
||||
}
|
||||
}
|
||||
+153
-87
@@ -36,7 +36,10 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/entitlements"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -70,6 +73,7 @@ func init() {
|
||||
type Options struct {
|
||||
CacheDir string
|
||||
Database database.Store
|
||||
Authorizer rbac.Authorizer
|
||||
SiteFS fs.FS
|
||||
OAuth2Configs *httpmw.OAuth2Configs
|
||||
DocsURL string
|
||||
@@ -264,6 +268,8 @@ type htmlState struct {
|
||||
|
||||
TasksTabVisible string
|
||||
AgentsTabVisible string
|
||||
Permissions string
|
||||
Organizations string
|
||||
}
|
||||
|
||||
type csrfState struct {
|
||||
@@ -394,6 +400,7 @@ func (h *Handler) renderHTMLWithState(r *http.Request, filePath string, state ht
|
||||
var themePreference string
|
||||
var terminalFont string
|
||||
orgIDs := []uuid.UUID{}
|
||||
var userOrgs []database.Organization
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
user, err = h.opts.Database.GetUserByID(ctx, apiKey.UserID)
|
||||
@@ -428,100 +435,159 @@ func (h *Handler) renderHTMLWithState(r *http.Request, filePath string, state ht
|
||||
orgIDs = memberIDs[0].OrganizationIDs
|
||||
return err
|
||||
})
|
||||
eg.Go(func() error {
|
||||
orgs, err := h.opts.Database.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
|
||||
UserID: apiKey.UserID,
|
||||
})
|
||||
if err == nil {
|
||||
userOrgs = orgs
|
||||
}
|
||||
// Don't fail the entire group if we can't fetch orgs.
|
||||
return nil
|
||||
})
|
||||
err := eg.Wait()
|
||||
if err == nil {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
user, err := json.Marshal(db2sdk.User(user, orgIDs))
|
||||
if err == nil {
|
||||
state.User = html.EscapeString(string(user))
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
userAppearance, err := json.Marshal(codersdk.UserAppearanceSettings{
|
||||
ThemePreference: themePreference,
|
||||
TerminalFont: codersdk.TerminalFontName(terminalFont),
|
||||
})
|
||||
if err == nil {
|
||||
state.UserAppearance = html.EscapeString(string(userAppearance))
|
||||
}
|
||||
}()
|
||||
|
||||
if h.Entitlements != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
state.Entitlements = html.EscapeString(string(h.Entitlements.AsJSON()))
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cfg, err := af.Fetch(ctx)
|
||||
if err == nil {
|
||||
appr, err := json.Marshal(cfg)
|
||||
if err == nil {
|
||||
state.Appearance = html.EscapeString(string(appr))
|
||||
state.ApplicationName = applicationNameOrDefault(cfg)
|
||||
state.LogoURL = cfg.LogoURL
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if h.RegionsFetcher != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
regions, err := h.RegionsFetcher(ctx)
|
||||
if err == nil {
|
||||
regions, err := json.Marshal(regions)
|
||||
if err == nil {
|
||||
state.Regions = html.EscapeString(string(regions))
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
experiments := h.Experiments.Load()
|
||||
if experiments != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
experiments, err := json.Marshal(experiments)
|
||||
if err == nil {
|
||||
state.Experiments = html.EscapeString(string(experiments))
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tasksTabVisible, err := json.Marshal(!h.opts.HideAITasks)
|
||||
if err == nil {
|
||||
state.TasksTabVisible = html.EscapeString(string(tasksTabVisible))
|
||||
}
|
||||
}()
|
||||
wg.Go(func() {
|
||||
agentsTabVisible := false
|
||||
if experiments != nil {
|
||||
agentsTabVisible = experiments.Enabled(codersdk.ExperimentAgents)
|
||||
}
|
||||
data, err := json.Marshal(agentsTabVisible)
|
||||
if err == nil {
|
||||
state.AgentsTabVisible = html.EscapeString(string(data))
|
||||
}
|
||||
})
|
||||
wg.Wait()
|
||||
h.populateHTMLState(ctx, &state, af, actor, user, orgIDs, userOrgs, themePreference, terminalFont)
|
||||
}
|
||||
|
||||
return execTmpl(tmpl, state)
|
||||
}
|
||||
|
||||
// populateHTMLState runs concurrent goroutines to populate all
|
||||
// authenticated user metadata in the HTML state. This is extracted
|
||||
// from renderHTMLWithState to reduce nesting complexity.
|
||||
func (h *Handler) populateHTMLState(
|
||||
ctx context.Context,
|
||||
state *htmlState,
|
||||
af appearance.Fetcher,
|
||||
actor *rbac.Subject,
|
||||
user database.User,
|
||||
orgIDs []uuid.UUID,
|
||||
userOrgs []database.Organization,
|
||||
themePreference string,
|
||||
terminalFont string,
|
||||
) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Go(func() {
|
||||
data, err := json.Marshal(db2sdk.User(user, orgIDs))
|
||||
if err == nil {
|
||||
state.User = html.EscapeString(string(data))
|
||||
}
|
||||
})
|
||||
wg.Go(func() {
|
||||
data, err := json.Marshal(codersdk.UserAppearanceSettings{
|
||||
ThemePreference: themePreference,
|
||||
TerminalFont: codersdk.TerminalFontName(terminalFont),
|
||||
})
|
||||
if err == nil {
|
||||
state.UserAppearance = html.EscapeString(string(data))
|
||||
}
|
||||
})
|
||||
if h.Entitlements != nil {
|
||||
wg.Go(func() {
|
||||
state.Entitlements = html.EscapeString(string(h.Entitlements.AsJSON()))
|
||||
})
|
||||
}
|
||||
wg.Go(func() {
|
||||
cfg, err := af.Fetch(ctx)
|
||||
if err == nil {
|
||||
appr, err := json.Marshal(cfg)
|
||||
if err == nil {
|
||||
state.Appearance = html.EscapeString(string(appr))
|
||||
state.ApplicationName = applicationNameOrDefault(cfg)
|
||||
state.LogoURL = cfg.LogoURL
|
||||
}
|
||||
}
|
||||
})
|
||||
if h.RegionsFetcher != nil {
|
||||
wg.Go(func() {
|
||||
regions, err := h.RegionsFetcher(ctx)
|
||||
if err == nil {
|
||||
data, err := json.Marshal(regions)
|
||||
if err == nil {
|
||||
state.Regions = html.EscapeString(string(data))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
experiments := h.Experiments.Load()
|
||||
if experiments != nil {
|
||||
wg.Go(func() {
|
||||
data, err := json.Marshal(experiments)
|
||||
if err == nil {
|
||||
state.Experiments = html.EscapeString(string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Go(func() {
|
||||
data, err := json.Marshal(!h.opts.HideAITasks)
|
||||
if err == nil {
|
||||
state.TasksTabVisible = html.EscapeString(string(data))
|
||||
}
|
||||
})
|
||||
wg.Go(func() {
|
||||
agentsTabVisible := false
|
||||
if experiments != nil {
|
||||
agentsTabVisible = experiments.Enabled(codersdk.ExperimentAgents)
|
||||
}
|
||||
data, err := json.Marshal(agentsTabVisible)
|
||||
if err == nil {
|
||||
state.AgentsTabVisible = html.EscapeString(string(data))
|
||||
}
|
||||
})
|
||||
wg.Go(func() {
|
||||
sdkOrgs := slice.List(userOrgs, db2sdk.Organization)
|
||||
data, err := json.Marshal(sdkOrgs)
|
||||
if err == nil {
|
||||
state.Organizations = html.EscapeString(string(data))
|
||||
}
|
||||
})
|
||||
if h.opts.Authorizer != nil {
|
||||
wg.Go(func() {
|
||||
state.Permissions = h.renderPermissions(ctx, *actor)
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// permissionChecks is the single source of truth for site-wide
|
||||
// permission checks, shared with the TypeScript frontend via
|
||||
// permissions.json.
|
||||
//
|
||||
//go:embed permissions.json
|
||||
var permissionChecksJSON []byte
|
||||
|
||||
var permissionChecks map[string]codersdk.AuthorizationCheck
|
||||
|
||||
func init() {
|
||||
if err := json.Unmarshal(permissionChecksJSON, &permissionChecks); err != nil {
|
||||
panic("failed to parse permissions.json: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// renderPermissions checks all the site-wide permissions for the
|
||||
// given actor and returns an HTML-escaped JSON string suitable for
|
||||
// embedding in a meta tag.
|
||||
func (h *Handler) renderPermissions(ctx context.Context, actor rbac.Subject) string {
|
||||
response := make(codersdk.AuthorizationResponse)
|
||||
for k, v := range permissionChecks {
|
||||
obj := rbac.Object{
|
||||
ID: v.Object.ResourceID,
|
||||
Owner: v.Object.OwnerID,
|
||||
OrgID: v.Object.OrganizationID,
|
||||
AnyOrgOwner: v.Object.AnyOrgOwner,
|
||||
Type: string(v.Object.ResourceType),
|
||||
}
|
||||
err := h.opts.Authorizer.Authorize(ctx, actor, policy.Action(v.Action), obj)
|
||||
response[k] = err == nil
|
||||
}
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return html.EscapeString(string(data))
|
||||
}
|
||||
|
||||
// noopResponseWriter is a response writer that does nothing.
|
||||
type noopResponseWriter struct{}
|
||||
|
||||
|
||||
@@ -2296,6 +2296,23 @@ class ApiMethods {
|
||||
return response.data;
|
||||
};
|
||||
|
||||
uploadChatFile = async (
|
||||
file: File,
|
||||
organizationId: string,
|
||||
): Promise<TypesGen.UploadChatFileResponse> => {
|
||||
const response = await this.axios.post(
|
||||
`/api/experimental/chats/files?organization=${organizationId}`,
|
||||
file,
|
||||
{
|
||||
headers: {
|
||||
"Content-Type": file.type || "application/octet-stream",
|
||||
"Content-Disposition": `attachment; filename="${file.name}"`,
|
||||
},
|
||||
},
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
getTemplateVersionLogs = async (
|
||||
versionId: string,
|
||||
): Promise<TypesGen.ProvisionerJobLog[]> => {
|
||||
|
||||
@@ -3,17 +3,29 @@ import type {
|
||||
AuthorizationRequest,
|
||||
AuthorizationResponse,
|
||||
} from "api/typesGenerated";
|
||||
import type { MetadataState, MetadataValue } from "hooks/useEmbeddedMetadata";
|
||||
import { disabledRefetchOptions } from "./util";
|
||||
|
||||
const AUTHORIZATION_KEY = "authorization";
|
||||
|
||||
export const getAuthorizationKey = (req: AuthorizationRequest) =>
|
||||
[AUTHORIZATION_KEY, req] as const;
|
||||
|
||||
export const checkAuthorization = <TResponse extends AuthorizationResponse>(
|
||||
export function checkAuthorization<TResponse extends AuthorizationResponse>(
|
||||
req: AuthorizationRequest,
|
||||
) => {
|
||||
return {
|
||||
metadata?: MetadataState<TResponse & MetadataValue>,
|
||||
) {
|
||||
const base = {
|
||||
queryKey: getAuthorizationKey(req),
|
||||
queryFn: () => API.checkAuthorization<TResponse>(req),
|
||||
};
|
||||
};
|
||||
|
||||
if (metadata?.available) {
|
||||
return {
|
||||
...base,
|
||||
initialData: metadata.value as TResponse,
|
||||
...disabledRefetchOptions,
|
||||
};
|
||||
}
|
||||
return base;
|
||||
}
|
||||
|
||||
@@ -6,11 +6,13 @@ import {
|
||||
import type {
|
||||
CreateOrganizationRequest,
|
||||
GroupSyncSettings,
|
||||
Organization,
|
||||
PaginatedMembersRequest,
|
||||
PaginatedMembersResponse,
|
||||
RoleSyncSettings,
|
||||
UpdateOrganizationRequest,
|
||||
} from "api/typesGenerated";
|
||||
import type { MetadataState } from "hooks/useEmbeddedMetadata";
|
||||
import type { UsePaginatedQueryOptions } from "hooks/usePaginatedQuery";
|
||||
import {
|
||||
type OrganizationPermissionName,
|
||||
@@ -24,6 +26,7 @@ import {
|
||||
} from "modules/permissions/workspaces";
|
||||
import type { QueryClient, UseQueryOptions } from "react-query";
|
||||
import { meKey } from "./users";
|
||||
import { cachedQuery } from "./util";
|
||||
|
||||
export const createOrganization = (queryClient: QueryClient) => {
|
||||
return {
|
||||
@@ -160,11 +163,14 @@ export const updateOrganizationMemberRoles = (
|
||||
|
||||
export const organizationsKey = ["organizations"] as const;
|
||||
|
||||
export const organizations = () => {
|
||||
return {
|
||||
const notAvailable = { available: false, value: undefined } as const;
|
||||
|
||||
export const organizations = (metadata?: MetadataState<Organization[]>) => {
|
||||
return cachedQuery({
|
||||
metadata: metadata ?? notAvailable,
|
||||
queryKey: organizationsKey,
|
||||
queryFn: () => API.getOrganizations(),
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
export const getProvisionerDaemonsKey = (
|
||||
|
||||
Generated
+40
-2
@@ -1120,12 +1120,28 @@ export interface ChatGitChange {
|
||||
export interface ChatInputPart {
|
||||
readonly type: ChatInputPartType;
|
||||
readonly text?: string;
|
||||
readonly file_id?: string;
|
||||
/**
|
||||
* The following fields are only set when Type is
|
||||
* ChatInputPartTypeFileReference.
|
||||
*/
|
||||
readonly file_name?: string;
|
||||
readonly start_line?: number;
|
||||
readonly end_line?: number;
|
||||
/**
|
||||
* The code content from the diff that was commented on.
|
||||
*/
|
||||
readonly content?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatInputPartType = "text";
|
||||
export type ChatInputPartType = "file" | "file-reference" | "text";
|
||||
|
||||
export const ChatInputPartTypes: ChatInputPartType[] = ["text"];
|
||||
export const ChatInputPartTypes: ChatInputPartType[] = [
|
||||
"file",
|
||||
"file-reference",
|
||||
"text",
|
||||
];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
@@ -1161,11 +1177,24 @@ export interface ChatMessagePart {
|
||||
readonly title?: string;
|
||||
readonly media_type?: string;
|
||||
readonly data?: string;
|
||||
readonly file_id?: string;
|
||||
/**
|
||||
* The following fields are only set when Type is
|
||||
* ChatInputPartTypeFileReference.
|
||||
*/
|
||||
readonly file_name?: string;
|
||||
readonly start_line?: number;
|
||||
readonly end_line?: number;
|
||||
/**
|
||||
* The code content from the diff that was commented on.
|
||||
*/
|
||||
readonly content?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatMessagePartType =
|
||||
| "file"
|
||||
| "file-reference"
|
||||
| "reasoning"
|
||||
| "source"
|
||||
| "text"
|
||||
@@ -1174,6 +1203,7 @@ export type ChatMessagePartType =
|
||||
|
||||
export const ChatMessagePartTypes: ChatMessagePartType[] = [
|
||||
"file",
|
||||
"file-reference",
|
||||
"reasoning",
|
||||
"source",
|
||||
"text",
|
||||
@@ -6556,6 +6586,14 @@ export interface UpdateWorkspaceTTLRequest {
|
||||
readonly ttl_ms: number | null;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UploadChatFileResponse is the response from uploading a chat file.
|
||||
*/
|
||||
export interface UploadChatFileResponse {
|
||||
readonly id: string;
|
||||
}
|
||||
|
||||
// From codersdk/files.go
|
||||
/**
|
||||
* UploadResponse contains the hash to reference the uploaded file.
|
||||
|
||||
@@ -32,6 +32,10 @@ import {
|
||||
useRef,
|
||||
} from "react";
|
||||
import { cn } from "utils/cn";
|
||||
import {
|
||||
$createFileReferenceNode,
|
||||
FileReferenceNode,
|
||||
} from "./FileReferenceNode";
|
||||
|
||||
// Blocks Cmd+B/I/U and element formatting shortcuts so the editor
|
||||
// stays plain-text only.
|
||||
@@ -57,8 +61,11 @@ const DisableFormattingPlugin: FC = memo(function DisableFormattingPlugin() {
|
||||
});
|
||||
|
||||
// Intercepts paste events and inserts clipboard content as plain text,
|
||||
// stripping any rich-text formatting.
|
||||
const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() {
|
||||
// stripping any rich-text formatting. Image files are forwarded to
|
||||
// the parent via the onFilePaste callback instead of being inserted.
|
||||
const PasteSanitizationPlugin: FC<{
|
||||
onFilePaste?: (file: File) => void;
|
||||
}> = memo(function PasteSanitizationPlugin({ onFilePaste }) {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
|
||||
useEffect(() => {
|
||||
@@ -69,6 +76,22 @@ const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() {
|
||||
const clipboardData = event.clipboardData;
|
||||
if (!clipboardData) return false;
|
||||
|
||||
// Check for image files in the clipboard (e.g. pasted
|
||||
// screenshots). Forward them to the parent via callback
|
||||
// instead of inserting text.
|
||||
if (onFilePaste && clipboardData.files.length > 0) {
|
||||
const images = Array.from(clipboardData.files).filter((f) =>
|
||||
f.type.startsWith("image/"),
|
||||
);
|
||||
if (images.length > 0) {
|
||||
event.preventDefault();
|
||||
for (const file of images) {
|
||||
onFilePaste(file);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
const text = clipboardData.getData("text/plain");
|
||||
if (!text) return false;
|
||||
|
||||
@@ -106,7 +129,7 @@ const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() {
|
||||
},
|
||||
COMMAND_PRIORITY_HIGH,
|
||||
);
|
||||
}, [editor]);
|
||||
}, [editor, onFilePaste]);
|
||||
|
||||
return null;
|
||||
});
|
||||
@@ -141,7 +164,7 @@ const EnterKeyPlugin: FC<{ onEnter?: () => void }> = memo(
|
||||
// Fires the onChange callback with the editor's plain-text content
|
||||
// on every update.
|
||||
const ContentChangePlugin: FC<{
|
||||
onChange?: (content: string) => void;
|
||||
onChange?: (content: string, hasFileReferences: boolean) => void;
|
||||
}> = memo(function ContentChangePlugin({ onChange }) {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
|
||||
@@ -152,7 +175,18 @@ const ContentChangePlugin: FC<{
|
||||
editorState.read(() => {
|
||||
const root = $getRoot();
|
||||
const content = root.getTextContent();
|
||||
onChange(content);
|
||||
let hasRefs = false;
|
||||
for (const child of root.getChildren()) {
|
||||
if (child.getType() !== "paragraph") continue;
|
||||
for (const node of (child as ParagraphNode).getChildren()) {
|
||||
if (node instanceof FileReferenceNode) {
|
||||
hasRefs = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (hasRefs) break;
|
||||
}
|
||||
onChange(content, hasRefs);
|
||||
});
|
||||
});
|
||||
}, [editor, onChange]);
|
||||
@@ -203,20 +237,53 @@ const InsertTextPlugin: FC<{
|
||||
return null;
|
||||
});
|
||||
|
||||
/**
|
||||
* Structured data for a file reference extracted from the editor.
|
||||
*/
|
||||
interface FileReferenceData {
|
||||
readonly fileName: string;
|
||||
readonly startLine: number;
|
||||
readonly endLine: number;
|
||||
readonly content: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* A content part extracted from the Lexical editor in document order.
|
||||
* Either a text segment or a file-reference chip.
|
||||
*/
|
||||
type EditorContentPart =
|
||||
| { readonly type: "text"; readonly text: string }
|
||||
| {
|
||||
readonly type: "file-reference";
|
||||
readonly reference: FileReferenceData;
|
||||
};
|
||||
|
||||
export interface ChatMessageInputRef {
|
||||
insertText: (text: string) => void;
|
||||
clear: () => void;
|
||||
focus: () => void;
|
||||
getValue: () => string;
|
||||
/**
|
||||
* Insert a file reference chip in a single Lexical update
|
||||
* (atomic for undo/redo).
|
||||
*/
|
||||
addFileReference: (ref: FileReferenceData) => void;
|
||||
/**
|
||||
* Walk the Lexical tree in document order and return interleaved
|
||||
* text / file-reference parts. Adjacent text nodes within the same
|
||||
* paragraph are merged, and paragraphs are separated by newlines.
|
||||
*/
|
||||
getContentParts: () => EditorContentPart[];
|
||||
}
|
||||
|
||||
interface ChatMessageInputProps
|
||||
extends Omit<React.ComponentProps<"div">, "onChange" | "role" | "ref"> {
|
||||
placeholder?: string;
|
||||
initialValue?: string;
|
||||
onChange?: (content: string) => void;
|
||||
onChange?: (content: string, hasFileReferences: boolean) => void;
|
||||
rows?: number;
|
||||
onEnter?: () => void;
|
||||
onFilePaste?: (file: File) => void;
|
||||
disabled?: boolean;
|
||||
autoFocus?: boolean;
|
||||
"aria-label"?: string;
|
||||
@@ -245,6 +312,7 @@ const ChatMessageInput = memo(
|
||||
onChange,
|
||||
rows,
|
||||
onEnter,
|
||||
onFilePaste,
|
||||
disabled,
|
||||
autoFocus,
|
||||
"aria-label": ariaLabel,
|
||||
@@ -258,7 +326,7 @@ const ChatMessageInput = memo(
|
||||
paragraph: "m-0",
|
||||
},
|
||||
onError: (error: Error) => console.error("Lexical error:", error),
|
||||
nodes: [],
|
||||
nodes: [FileReferenceNode],
|
||||
editable: !disabled,
|
||||
}),
|
||||
[disabled],
|
||||
@@ -277,8 +345,8 @@ const ChatMessageInput = memo(
|
||||
}, []);
|
||||
|
||||
const handleContentChange = useCallback(
|
||||
(content: string) => {
|
||||
onChange?.(content);
|
||||
(content: string, hasFileReferences: boolean) => {
|
||||
onChange?.(content, hasFileReferences);
|
||||
},
|
||||
[onChange],
|
||||
);
|
||||
@@ -358,6 +426,74 @@ const ChatMessageInput = memo(
|
||||
});
|
||||
return content;
|
||||
},
|
||||
addFileReference: (ref: FileReferenceData) => {
|
||||
const editor = editorRef.current;
|
||||
if (!editor) return;
|
||||
|
||||
editor.update(() => {
|
||||
const root = $getRoot();
|
||||
let paragraph = root.getFirstChild();
|
||||
if (!paragraph || paragraph.getType() !== "paragraph") {
|
||||
paragraph = $createParagraphNode();
|
||||
root.append(paragraph);
|
||||
}
|
||||
const chipNode = $createFileReferenceNode(
|
||||
ref.fileName,
|
||||
ref.startLine,
|
||||
ref.endLine,
|
||||
ref.content,
|
||||
);
|
||||
(paragraph as ParagraphNode).append(chipNode);
|
||||
chipNode.selectNext();
|
||||
});
|
||||
},
|
||||
getContentParts: () => {
|
||||
const editor = editorRef.current;
|
||||
if (!editor) return [];
|
||||
const parts: EditorContentPart[] = [];
|
||||
editor.getEditorState().read(() => {
|
||||
const paragraphs = $getRoot().getChildren();
|
||||
for (let i = 0; i < paragraphs.length; i++) {
|
||||
const para = paragraphs[i];
|
||||
if (para.getType() !== "paragraph") continue;
|
||||
// Separate paragraphs with a newline in the
|
||||
// preceding text part, just like getTextContent().
|
||||
if (i > 0) {
|
||||
const last = parts[parts.length - 1];
|
||||
if (last?.type === "text") {
|
||||
(last as { text: string }).text += "\n";
|
||||
} else {
|
||||
parts.push({ type: "text", text: "\n" });
|
||||
}
|
||||
}
|
||||
for (const node of (para as ParagraphNode).getChildren()) {
|
||||
if (node instanceof FileReferenceNode) {
|
||||
parts.push({
|
||||
type: "file-reference",
|
||||
reference: {
|
||||
fileName: node.__fileName,
|
||||
startLine: node.__startLine,
|
||||
endLine: node.__endLine,
|
||||
content: node.__content,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// Text node (or any other inline) —
|
||||
// merge into the last text part.
|
||||
const t = node.getTextContent();
|
||||
if (!t) continue;
|
||||
const last = parts[parts.length - 1];
|
||||
if (last?.type === "text") {
|
||||
(last as { text: string }).text += t;
|
||||
} else {
|
||||
parts.push({ type: "text", text: t });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return parts;
|
||||
},
|
||||
}),
|
||||
[],
|
||||
);
|
||||
@@ -376,7 +512,7 @@ const ChatMessageInput = memo(
|
||||
<RichTextPlugin
|
||||
contentEditable={
|
||||
<ContentEditable
|
||||
className="outline-none w-full whitespace-pre-wrap overflow-y-auto max-h-[50vh] [scrollbar-width:thin] [scrollbar-color:hsl(var(--surface-quaternary))_transparent] [&_p]:leading-normal [&_p:first-child]:mt-0 [&_p:last-child]:mb-0"
|
||||
className="outline-none w-full whitespace-pre-wrap overflow-y-auto max-h-[50vh] [scrollbar-width:thin] [scrollbar-color:hsl(var(--surface-quaternary))_transparent] [&_p]:leading-normal [&_p:first-child]:mt-0 [&_p:last-child]:mb-0 py-px"
|
||||
data-testid="chat-message-input"
|
||||
style={{ minHeight: "inherit" }}
|
||||
aria-label={ariaLabel}
|
||||
@@ -392,7 +528,7 @@ const ChatMessageInput = memo(
|
||||
/>
|
||||
<HistoryPlugin />
|
||||
<DisableFormattingPlugin />
|
||||
<PasteSanitizationPlugin />
|
||||
<PasteSanitizationPlugin onFilePaste={onFilePaste} />
|
||||
<EnterKeyPlugin onEnter={disabled ? undefined : onEnter} />
|
||||
<ContentChangePlugin onChange={handleContentChange} />
|
||||
<ValueSyncPlugin initialValue={initialValue} />
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
import { useLexicalNodeSelection } from "@lexical/react/useLexicalNodeSelection";
|
||||
import { FileIcon } from "components/FileIcon/FileIcon";
|
||||
import {
|
||||
$getNodeByKey,
|
||||
DecoratorNode,
|
||||
type EditorConfig,
|
||||
type LexicalEditor,
|
||||
type NodeKey,
|
||||
type SerializedLexicalNode,
|
||||
type Spread,
|
||||
} from "lexical";
|
||||
import { XIcon } from "lucide-react";
|
||||
import { type FC, memo, type ReactNode } from "react";
|
||||
import { cn } from "utils/cn";
|
||||
|
||||
type SerializedFileReferenceNode = Spread<
|
||||
{
|
||||
fileName: string;
|
||||
startLine: number;
|
||||
endLine: number;
|
||||
content: string;
|
||||
},
|
||||
SerializedLexicalNode
|
||||
>;
|
||||
|
||||
function FileReferenceChip({
|
||||
fileName,
|
||||
startLine,
|
||||
endLine,
|
||||
isSelected,
|
||||
onRemove,
|
||||
onClick,
|
||||
}: {
|
||||
fileName: string;
|
||||
startLine: number;
|
||||
endLine: number;
|
||||
isSelected?: boolean;
|
||||
onRemove: () => void;
|
||||
onClick?: () => void;
|
||||
}) {
|
||||
const shortFile = fileName.split("/").pop() || fileName;
|
||||
const lineLabel =
|
||||
startLine === endLine ? `L${startLine}` : `L${startLine}–${endLine}`;
|
||||
|
||||
return (
|
||||
<span
|
||||
className={cn(
|
||||
"inline-flex h-6 max-w-[300px] cursor-pointer select-none items-center gap-1.5 rounded-md border border-border-default bg-surface-secondary px-1.5 align-middle text-xs text-content-primary shadow-sm transition-colors",
|
||||
isSelected &&
|
||||
"border-content-link bg-content-link/10 ring-1 ring-content-link/40",
|
||||
)}
|
||||
contentEditable={false}
|
||||
title={`${fileName}:${lineLabel}`}
|
||||
onClick={onClick}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") {
|
||||
e.preventDefault();
|
||||
onClick?.();
|
||||
}
|
||||
}}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
>
|
||||
<FileIcon fileName={shortFile} className="shrink-0" />
|
||||
<span className="shrink-0 text-content-secondary">
|
||||
{shortFile}
|
||||
<span className="text-content-link">:{lineLabel}</span>
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
className="ml-auto inline-flex size-4 shrink-0 items-center justify-center rounded border-0 bg-transparent p-0 text-content-secondary transition-colors hover:text-content-primary cursor-pointer"
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
onRemove();
|
||||
}}
|
||||
aria-label="Remove reference"
|
||||
tabIndex={-1}
|
||||
>
|
||||
<XIcon className="size-2" />
|
||||
</button>
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export class FileReferenceNode extends DecoratorNode<ReactNode> {
|
||||
__fileName: string;
|
||||
__startLine: number;
|
||||
__endLine: number;
|
||||
__content: string;
|
||||
|
||||
static getType(): string {
|
||||
return "file-reference";
|
||||
}
|
||||
|
||||
static clone(node: FileReferenceNode): FileReferenceNode {
|
||||
return new FileReferenceNode(
|
||||
node.__fileName,
|
||||
node.__startLine,
|
||||
node.__endLine,
|
||||
node.__content,
|
||||
node.__key,
|
||||
);
|
||||
}
|
||||
|
||||
constructor(
|
||||
fileName: string,
|
||||
startLine: number,
|
||||
endLine: number,
|
||||
content: string,
|
||||
key?: NodeKey,
|
||||
) {
|
||||
super(key);
|
||||
this.__fileName = fileName;
|
||||
this.__startLine = startLine;
|
||||
this.__endLine = endLine;
|
||||
this.__content = content;
|
||||
}
|
||||
|
||||
createDOM(_config: EditorConfig): HTMLElement {
|
||||
const span = document.createElement("span");
|
||||
span.style.display = "inline";
|
||||
span.style.userSelect = "none";
|
||||
return span;
|
||||
}
|
||||
|
||||
updateDOM(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
exportJSON(): SerializedFileReferenceNode {
|
||||
return {
|
||||
type: "file-reference",
|
||||
version: 1,
|
||||
fileName: this.__fileName,
|
||||
startLine: this.__startLine,
|
||||
endLine: this.__endLine,
|
||||
content: this.__content,
|
||||
};
|
||||
}
|
||||
|
||||
static importJSON(json: SerializedFileReferenceNode): FileReferenceNode {
|
||||
return new FileReferenceNode(
|
||||
json.fileName,
|
||||
json.startLine,
|
||||
json.endLine,
|
||||
json.content,
|
||||
);
|
||||
}
|
||||
|
||||
getTextContent(): string {
|
||||
return "";
|
||||
}
|
||||
|
||||
isInline(): boolean {
|
||||
return true;
|
||||
}
|
||||
|
||||
decorate(_editor: LexicalEditor): ReactNode {
|
||||
return (
|
||||
<FileReferenceChipWrapper
|
||||
editor={_editor}
|
||||
nodeKey={this.__key}
|
||||
fileName={this.__fileName}
|
||||
startLine={this.__startLine}
|
||||
endLine={this.__endLine}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const FileReferenceChipWrapper: FC<{
|
||||
editor: LexicalEditor;
|
||||
nodeKey: NodeKey;
|
||||
fileName: string;
|
||||
startLine: number;
|
||||
endLine: number;
|
||||
}> = memo(({ editor, nodeKey, fileName, startLine, endLine }) => {
|
||||
const [isSelected] = useLexicalNodeSelection(nodeKey);
|
||||
|
||||
const handleRemove = () => {
|
||||
editor.update(() => {
|
||||
const node = $getNodeByKey(nodeKey);
|
||||
if (node instanceof FileReferenceNode) {
|
||||
node.remove();
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const handleClick = () => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("file-reference-click", {
|
||||
detail: { fileName, startLine, endLine },
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<FileReferenceChip
|
||||
fileName={fileName}
|
||||
startLine={startLine}
|
||||
endLine={endLine}
|
||||
isSelected={isSelected}
|
||||
onRemove={handleRemove}
|
||||
onClick={handleClick}
|
||||
/>
|
||||
);
|
||||
});
|
||||
FileReferenceChipWrapper.displayName = "FileReferenceChipWrapper";
|
||||
|
||||
export function $createFileReferenceNode(
|
||||
fileName: string,
|
||||
startLine: number,
|
||||
endLine: number,
|
||||
content: string,
|
||||
): FileReferenceNode {
|
||||
return new FileReferenceNode(fileName, startLine, endLine, content);
|
||||
}
|
||||
@@ -50,7 +50,10 @@ export const AuthProvider: FC<PropsWithChildren> = ({ children }) => {
|
||||
const hasFirstUserQuery = useQuery(hasFirstUser(userMetadataState));
|
||||
|
||||
const permissionsQuery = useQuery({
|
||||
...checkAuthorization({ checks: permissionChecks }),
|
||||
...checkAuthorization<Permissions>(
|
||||
{ checks: permissionChecks },
|
||||
metadata.permissions,
|
||||
),
|
||||
enabled: userQuery.data !== undefined,
|
||||
});
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import {
|
||||
MockBuildInfo,
|
||||
MockEntitlements,
|
||||
MockExperiments,
|
||||
MockOrganization,
|
||||
MockPermissions,
|
||||
MockTasksTabVisible,
|
||||
MockUserAppearanceSettings,
|
||||
MockUserOwner,
|
||||
@@ -45,6 +47,8 @@ const mockDataForTags = {
|
||||
regions: MockRegions,
|
||||
"tasks-tab-visible": MockTasksTabVisible,
|
||||
"agents-tab-visible": MockAgentsTabVisible,
|
||||
permissions: MockPermissions,
|
||||
organizations: [MockOrganization],
|
||||
} as const satisfies Record<MetadataKey, MetadataValue>;
|
||||
|
||||
const emptyMetadata: RuntimeHtmlMetadata = {
|
||||
@@ -84,6 +88,14 @@ const emptyMetadata: RuntimeHtmlMetadata = {
|
||||
available: false,
|
||||
value: undefined,
|
||||
},
|
||||
permissions: {
|
||||
available: false,
|
||||
value: undefined,
|
||||
},
|
||||
organizations: {
|
||||
available: false,
|
||||
value: undefined,
|
||||
},
|
||||
};
|
||||
|
||||
const populatedMetadata: RuntimeHtmlMetadata = {
|
||||
@@ -123,6 +135,14 @@ const populatedMetadata: RuntimeHtmlMetadata = {
|
||||
available: true,
|
||||
value: MockAgentsTabVisible,
|
||||
},
|
||||
permissions: {
|
||||
available: true,
|
||||
value: MockPermissions,
|
||||
},
|
||||
organizations: {
|
||||
available: true,
|
||||
value: [MockOrganization],
|
||||
},
|
||||
};
|
||||
|
||||
function seedInitialMetadata(metadataKey: string): () => void {
|
||||
|
||||
@@ -3,10 +3,12 @@ import type {
|
||||
BuildInfoResponse,
|
||||
Entitlements,
|
||||
Experiment,
|
||||
Organization,
|
||||
Region,
|
||||
User,
|
||||
UserAppearanceSettings,
|
||||
} from "api/typesGenerated";
|
||||
import type { Permissions } from "modules/permissions";
|
||||
import { useMemo, useSyncExternalStore } from "react";
|
||||
export const DEFAULT_METADATA_KEY = "property";
|
||||
|
||||
@@ -31,6 +33,8 @@ type AvailableMetadata = Readonly<{
|
||||
"build-info": BuildInfoResponse;
|
||||
"tasks-tab-visible": boolean;
|
||||
"agents-tab-visible": boolean;
|
||||
permissions: Permissions;
|
||||
organizations: Organization[];
|
||||
}>;
|
||||
|
||||
export type MetadataKey = keyof AvailableMetadata;
|
||||
@@ -94,6 +98,8 @@ export class MetadataManager implements MetadataManagerApi {
|
||||
regions: this.registerRegionValue(),
|
||||
"tasks-tab-visible": this.registerValue<boolean>("tasks-tab-visible"),
|
||||
"agents-tab-visible": this.registerValue<boolean>("agents-tab-visible"),
|
||||
permissions: this.registerValue<Permissions>("permissions"),
|
||||
organizations: this.registerValue<Organization[]>("organizations"),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user