Skip to content

Commit

Permalink
feat: improve podman cli execution
Browse files Browse the repository at this point in the history
Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 committed Sep 30, 2024
1 parent 4d420cb commit 963bad8
Show file tree
Hide file tree
Showing 11 changed files with 336 additions and 173 deletions.
22 changes: 4 additions & 18 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import os from 'node:os';
import fs, { type Stats, type PathLike } from 'node:fs';
import path from 'node:path';
import { ModelsManager } from './modelsManager';
import { env, process as coreProcess } from '@podman-desktop/api';
import { env } from '@podman-desktop/api';
import type { RunResult, TelemetryLogger, Webview, ContainerProviderConnection } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
Expand All @@ -33,7 +33,6 @@ import type { GGUFParseOutput } from '@huggingface/gguf';
import { gguf } from '@huggingface/gguf';
import type { PodmanConnection } from './podmanConnection';
import { VMType } from '@shared/src/models/IPodman';
import { getPodmanMachineName } from '../utils/podman';
import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry';
import { Uploader } from '../utils/uploader';

Expand All @@ -47,7 +46,6 @@ const mocks = vi.hoisted(() => {
getTargetMock: vi.fn(),
getDownloaderCompleter: vi.fn(),
isCompletionEventMock: vi.fn(),
getPodmanCliMock: vi.fn(),
};
});

Expand All @@ -59,11 +57,6 @@ vi.mock('@huggingface/gguf', () => ({
gguf: vi.fn(),
}));

vi.mock('../utils/podman', () => ({
getPodmanCli: mocks.getPodmanCliMock,
getPodmanMachineName: vi.fn(),
}));

vi.mock('@podman-desktop/api', () => {
return {
Disposable: {
Expand All @@ -72,9 +65,6 @@ vi.mock('@podman-desktop/api', () => {
env: {
isWindows: false,
},
process: {
exec: vi.fn(),
},
fs: {
createFileSystemWatcher: (): unknown => ({
onDidCreate: vi.fn(),
Expand Down Expand Up @@ -102,6 +92,7 @@ vi.mock('../utils/downloader', () => ({

const podmanConnectionMock = {
getContainerProviderConnections: vi.fn(),
executeSSH: vi.fn(),
} as unknown as PodmanConnection;

const cancellationTokenRegistryMock = {
Expand Down Expand Up @@ -598,8 +589,7 @@ describe('deleting models', () => {
});

test('deleting on windows should check for all connections', async () => {
vi.mocked(coreProcess.exec).mockResolvedValue({} as RunResult);
mocks.getPodmanCliMock.mockReturnValue('dummyCli');
vi.mocked(podmanConnectionMock.executeSSH).mockResolvedValue({} as RunResult);
vi.mocked(env).isWindows = true;
const connections: ContainerProviderConnection[] = [
{
Expand All @@ -622,7 +612,6 @@ describe('deleting models', () => {
},
];
vi.mocked(podmanConnectionMock.getContainerProviderConnections).mockReturnValue(connections);
vi.mocked(getPodmanMachineName).mockReturnValue('machine-2');

const rmSpy = vi.spyOn(fs.promises, 'rm');
rmSpy.mockResolvedValue(undefined);
Expand Down Expand Up @@ -659,10 +648,7 @@ describe('deleting models', () => {

expect(podmanConnectionMock.getContainerProviderConnections).toHaveBeenCalledOnce();

expect(coreProcess.exec).toHaveBeenCalledWith('dummyCli', [
'machine',
'ssh',
'machine-2',
expect(podmanConnectionMock.executeSSH).toHaveBeenCalledWith(connections[1], [
'rm',
'-f',
'/home/user/ai-lab/models/dummyFile',
Expand Down
32 changes: 23 additions & 9 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ import type { Task } from '@shared/src/models/ITask';
import type { BaseEvent } from '../models/baseEvent';
import { isCompletionEvent, isProgressEvent } from '../models/baseEvent';
import { Uploader } from '../utils/uploader';
import { deleteRemoteModel, getLocalModelFile, isModelUploaded } from '../utils/modelsUtils';
import { getPodmanMachineName } from '../utils/podman';
import { getLocalModelFile, getRemoteModelFile } from '../utils/modelsUtils';
import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry';
import { getHash, hasValidSha } from '../utils/sha';
import type { GGUFParseOutput } from '@huggingface/gguf';
Expand Down Expand Up @@ -231,17 +230,32 @@ export class ModelsManager implements Disposable {
for (const connection of connections) {
// ignore non-wsl machines
if (connection.vmType !== VMType.WSL) continue;
// Get the corresponding machine name
const machineName = getPodmanMachineName(connection);

// check if model already loaded on the podman machine
const existsRemote = await isModelUploaded(machineName, modelInfo);
if (!existsRemote) return;
// check if remote model is
try {
await this.podmanConnection.executeSSH(connection, ['stat', getRemoteModelFile(modelInfo)]);
} catch (err: unknown) {
console.warn(err);
continue;
}

await deleteRemoteModel(machineName, modelInfo);
await this.deleteRemoteModelByConnection(connection, modelInfo);
}
}

/**
* Delete a model given a {@link ContainerProviderConnection}
* @param connection
* @param modelInfo
* @protected
*/
protected async deleteRemoteModelByConnection(
connection: ContainerProviderConnection,
modelInfo: ModelInfo,
): Promise<void> {
await this.podmanConnection.executeSSH(connection, ['rm', '-f', getRemoteModelFile(modelInfo)]);
}

/**
* This method will resolve when the provided model will be downloaded.
*
Expand Down Expand Up @@ -439,7 +453,7 @@ export class ModelsManager implements Disposable {
connection: connection.name,
});

const uploader = new Uploader(connection, model);
const uploader = new Uploader(this.podmanConnection, connection, model);
uploader.onEvent(event => this.onDownloadUploadEvent(event, 'upload'), this);

// perform download
Expand Down
155 changes: 154 additions & 1 deletion packages/backend/src/managers/podmanConnection.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { beforeEach, describe, expect, test, vi } from 'vitest';
import { PodmanConnection } from './podmanConnection';
import type {
ContainerProviderConnection,
Extension,
ProviderConnectionStatus,
ProviderContainerConnection,
ProviderEvent,
Expand All @@ -29,10 +30,11 @@ import type {
UpdateContainerConnectionEvent,
Webview,
} from '@podman-desktop/api';
import { containerEngine, process, provider, EventEmitter, env } from '@podman-desktop/api';
import { containerEngine, extensions, process, provider, EventEmitter, env } from '@podman-desktop/api';
import { VMType } from '@shared/src/models/IPodman';
import { Messages } from '@shared/Messages';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { getPodmanCli, getPodmanMachineName } from '../utils/podman';

const webviewMock = {
postMessage: vi.fn(),
Expand All @@ -51,6 +53,9 @@ vi.mock('@podman-desktop/api', async () => {
process: {
exec: vi.fn(),
},
extensions: {
getExtension: vi.fn(),
},
containerEngine: {
listInfos: vi.fn(),
},
Expand All @@ -64,6 +69,7 @@ vi.mock('@podman-desktop/api', async () => {
vi.mock('../utils/podman', () => {
return {
getPodmanCli: vi.fn(),
getPodmanMachineName: vi.fn(),
MIN_CPUS_VALUE: 4,
};
});
Expand All @@ -73,6 +79,8 @@ beforeEach(() => {

vi.mocked(webviewMock.postMessage).mockResolvedValue(true);
vi.mocked(provider.getContainerConnections).mockReturnValue([]);
vi.mocked(getPodmanCli).mockReturnValue('podman-executable');
vi.mocked(getPodmanMachineName).mockImplementation(connection => connection.name);

const listeners: ((value: unknown) => void)[] = [];

Expand All @@ -86,6 +94,151 @@ beforeEach(() => {
} as unknown as EventEmitter<unknown>);
});

const providerContainerConnectionMock: ProviderContainerConnection = {
connection: {
type: 'podman',
status: () => 'started',
name: 'Podman Machine',
endpoint: {
socketPath: './socket-path',
},
},
providerId: 'podman',
};

describe('execute', () => {
test('execute should get the podman extension from api', async () => {
vi.mocked(extensions.getExtension).mockReturnValue(undefined);
const manager = new PodmanConnection(webviewMock);
await manager.execute(providerContainerConnectionMock.connection, ['ls']);
expect(extensions.getExtension).toHaveBeenCalledWith('podman-desktop.podman');
});

test('execute should call getPodmanCli if extension not available', async () => {
vi.mocked(extensions.getExtension).mockReturnValue(undefined);
const manager = new PodmanConnection(webviewMock);
await manager.execute(providerContainerConnectionMock.connection, ['ls']);

expect(getPodmanCli).toHaveBeenCalledOnce();
expect(process.exec).toHaveBeenCalledWith('podman-executable', ['ls'], undefined);
});

test('options should be propagated to process execution when provided', async () => {
vi.mocked(extensions.getExtension).mockReturnValue(undefined);
const manager = new PodmanConnection(webviewMock);
await manager.execute(providerContainerConnectionMock.connection, ['ls'], {
isAdmin: true,
});

expect(getPodmanCli).toHaveBeenCalledOnce();
expect(process.exec).toHaveBeenCalledWith('podman-executable', ['ls'], {
isAdmin: true,
});
});

test('execute should use extension exec if available', async () => {
vi.mocked(provider.getContainerConnections).mockReturnValue([providerContainerConnectionMock]);
const podmanAPI = {
exec: vi.fn(),
};
vi.mocked(extensions.getExtension).mockReturnValue({ exports: podmanAPI } as unknown as Extension<unknown>);
const manager = new PodmanConnection(webviewMock);
await manager.execute(providerContainerConnectionMock.connection, ['ls']);

expect(getPodmanCli).not.toHaveBeenCalledOnce();
expect(podmanAPI.exec).toHaveBeenCalledWith(['ls'], {
connection: providerContainerConnectionMock,
});
});

test('an error should be throw if the provided container connection do not exists', async () => {
vi.mocked(provider.getContainerConnections).mockReturnValue([]);
const podmanAPI = {
exec: vi.fn(),
};
vi.mocked(extensions.getExtension).mockReturnValue({ exports: podmanAPI } as unknown as Extension<unknown>);
const manager = new PodmanConnection(webviewMock);

await expect(async () => {
await manager.execute(providerContainerConnectionMock.connection, ['ls'], {
isAdmin: true,
});
}).rejects.toThrowError('cannot find podman provider with connection name Podman Machine');
});

test('execute should propagate options to extension exec if available', async () => {
vi.mocked(provider.getContainerConnections).mockReturnValue([providerContainerConnectionMock]);
const podmanAPI = {
exec: vi.fn(),
};
vi.mocked(extensions.getExtension).mockReturnValue({ exports: podmanAPI } as unknown as Extension<unknown>);
const manager = new PodmanConnection(webviewMock);
await manager.execute(providerContainerConnectionMock.connection, ['ls'], {
isAdmin: true,
});

expect(getPodmanCli).not.toHaveBeenCalledOnce();
expect(podmanAPI.exec).toHaveBeenCalledWith(['ls'], {
isAdmin: true,
connection: providerContainerConnectionMock,
});
});
});

describe('executeSSH', () => {
test('executeSSH should call getPodmanCli if extension not available', async () => {
vi.mocked(extensions.getExtension).mockReturnValue(undefined);
const manager = new PodmanConnection(webviewMock);
await manager.executeSSH(providerContainerConnectionMock.connection, ['ls']);

expect(getPodmanCli).toHaveBeenCalledOnce();
expect(process.exec).toHaveBeenCalledWith(
'podman-executable',
['machine', 'ssh', providerContainerConnectionMock.connection.name, 'ls'],
undefined,
);
});

test('executeSSH should use extension exec if available', async () => {
vi.mocked(provider.getContainerConnections).mockReturnValue([providerContainerConnectionMock]);
const podmanAPI = {
exec: vi.fn(),
};
vi.mocked(extensions.getExtension).mockReturnValue({ exports: podmanAPI } as unknown as Extension<unknown>);
const manager = new PodmanConnection(webviewMock);
await manager.executeSSH(providerContainerConnectionMock.connection, ['ls']);

expect(getPodmanCli).not.toHaveBeenCalledOnce();
expect(podmanAPI.exec).toHaveBeenCalledWith(
['machine', 'ssh', providerContainerConnectionMock.connection.name, 'ls'],
{
connection: providerContainerConnectionMock,
},
);
});

test('executeSSH should propagate options to extension exec if available', async () => {
vi.mocked(provider.getContainerConnections).mockReturnValue([providerContainerConnectionMock]);
const podmanAPI = {
exec: vi.fn(),
};
vi.mocked(extensions.getExtension).mockReturnValue({ exports: podmanAPI } as unknown as Extension<unknown>);
const manager = new PodmanConnection(webviewMock);
await manager.executeSSH(providerContainerConnectionMock.connection, ['ls'], {
isAdmin: true,
});

expect(getPodmanCli).not.toHaveBeenCalledOnce();
expect(podmanAPI.exec).toHaveBeenCalledWith(
['machine', 'ssh', providerContainerConnectionMock.connection.name, 'ls'],
{
isAdmin: true,
connection: providerContainerConnectionMock,
},
);
});
});

describe('podman connection initialization', () => {
test('init should notify publisher', () => {
const manager = new PodmanConnection(webviewMock);
Expand Down
Loading

0 comments on commit 963bad8

Please sign in to comment.