Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 98 additions & 7 deletions lib/hypervisor/qemu/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,67 @@ func buildQMPArgs(socketPath string) []string {
}
}

type startedProcess struct {
pid int
socketPath string
waitDone chan error
waitConsumed bool
waitErr error
}

func startManagedProcess(cmd *exec.Cmd, socketPath string) (*startedProcess, error) {
if err := cmd.Start(); err != nil {
return nil, err
}

proc := &startedProcess{
pid: cmd.Process.Pid,
socketPath: socketPath,
waitDone: make(chan error, 1),
}
go func() {
err := cmd.Wait()
_ = os.Remove(socketPath)
proc.waitDone <- err
}()

return proc, nil
}

func (p *startedProcess) checkExited() (error, bool) {
if p.waitConsumed {
return p.waitErr, true
}

select {
case err := <-p.waitDone:
p.waitConsumed = true
p.waitErr = err
return err, true
default:
return nil, false
}
}

func (p *startedProcess) wait() error {
if err, exited := p.checkExited(); exited {
return err
}

err := <-p.waitDone
p.waitConsumed = true
p.waitErr = err
return err
}

func (p *startedProcess) cleanup() {
if _, exited := p.checkExited(); !exited {
_ = syscall.Kill(p.pid, syscall.SIGKILL)
_ = p.wait()
}
_ = os.Remove(p.socketPath)
}

// startQEMUProcess handles the common QEMU process startup logic.
// Returns the PID, hypervisor client, and a cleanup function.
// The cleanup function must be called on error; call cleanup.Release() on success.
Expand Down Expand Up @@ -190,23 +251,22 @@ func (s *Starter) startQEMUProcess(ctx context.Context, p *paths.Paths, version
cmd.Stderr = vmmLogFile

processStartTime := time.Now()
if err := cmd.Start(); err != nil {
proc, err := startManagedProcess(cmd, socketPath)
if err != nil {
processSpan.RecordError(err)
processSpan.SetStatus(codes.Error, err.Error())
return 0, nil, nil, fmt.Errorf("start qemu: %w", err)
}

pid := cmd.Process.Pid
pid := proc.pid
log.DebugContext(processCtx, "QEMU process started", "pid", pid, "duration_ms", time.Since(processStartTime).Milliseconds())

// Setup cleanup to kill the process if subsequent steps fail
cu := cleanup.Make(func() {
syscall.Kill(pid, syscall.SIGKILL)
})
// Setup cleanup to kill, reap, and remove the socket if subsequent steps fail.
cu := cleanup.Make(proc.cleanup)

// Wait for socket to be ready
socketWaitStart := time.Now()
if err := waitForSocket(socketPath, socketWaitTimeout); err != nil {
if err := waitForSocketOrExit(socketPath, socketWaitTimeout, proc); err != nil {
processSpan.RecordError(err)
processSpan.SetStatus(codes.Error, err.Error())
cu.Clean()
Expand All @@ -219,6 +279,14 @@ func (s *Starter) startQEMUProcess(ctx context.Context, p *paths.Paths, version
var hv *QEMU
clientDeadline := time.Now().Add(clientCreateTimeout)
for {
if waitErr, exited := proc.checkExited(); exited {
err = fmt.Errorf("qemu exited early: %w", waitErr)
processSpan.RecordError(err)
processSpan.SetStatus(codes.Error, err.Error())
cu.Clean()
return 0, nil, nil, appendVMMLog(err, logsDir)
}

hv, err = New(socketPath)
if err == nil {
break
Expand Down Expand Up @@ -481,3 +549,26 @@ func waitForSocket(socketPath string, timeout time.Duration) error {
}
return fmt.Errorf("timeout waiting for socket")
}

Comment thread
sjmiller609 marked this conversation as resolved.
func waitForSocketOrExit(socketPath string, timeout time.Duration, proc *startedProcess) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("unix", socketPath, socketDialTimeout)
if err == nil {
conn.Close()
return nil
}

if waitErr, exited := proc.checkExited(); exited {
return fmt.Errorf("qemu exited early: %w", waitErr)
}

time.Sleep(socketPollInterval)
}

if waitErr, exited := proc.checkExited(); exited {
return fmt.Errorf("qemu exited early: %w", waitErr)
}

return fmt.Errorf("timeout waiting for socket")
}
40 changes: 40 additions & 0 deletions lib/hypervisor/qemu/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package qemu

import (
"errors"
"os"
"os/exec"
"path/filepath"
"regexp"
"testing"
"time"

"github.com/kernel/hypeman/lib/paths"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -176,3 +179,40 @@ func TestShouldRetrySameConfig(t *testing.T) {
})
}
}

func TestStartManagedProcessCleanupRemovesSocketAndReapsExitedProcess(t *testing.T) {
socketPath := filepath.Join(t.TempDir(), "qemu.sock")
require.NoError(t, os.WriteFile(socketPath, []byte("stale"), 0600))

cmd := exec.Command("sh", "-c", "exit 0")
proc, err := startManagedProcess(cmd, socketPath)
require.NoError(t, err)

require.Eventually(t, func() bool {
_, exited := proc.checkExited()
return exited
}, time.Second, 10*time.Millisecond)

proc.cleanup()

require.NoFileExists(t, socketPath)
require.NotNil(t, cmd.ProcessState)
assert.True(t, cmd.ProcessState.Exited())
}

func TestWaitForSocketOrExitReturnsEarlyWhenProcessDies(t *testing.T) {
socketPath := filepath.Join(t.TempDir(), "qemu.sock")

cmd := exec.Command("sh", "-c", "exit 7")
proc, err := startManagedProcess(cmd, socketPath)
require.NoError(t, err)

start := time.Now()
err = waitForSocketOrExit(socketPath, time.Second, proc)
require.Error(t, err)
assert.ErrorContains(t, err, "qemu exited early")
assert.Less(t, time.Since(start), 500*time.Millisecond)
require.NoFileExists(t, socketPath)
require.NotNil(t, cmd.ProcessState)
assert.True(t, cmd.ProcessState.Exited())
}
Loading