diff --git a/comfy_cli/cmdline.py b/comfy_cli/cmdline.py index 7dafc5e..6949240 100644 --- a/comfy_cli/cmdline.py +++ b/comfy_cli/cmdline.py @@ -9,8 +9,7 @@ from rich import print from rich.console import Console from typing_extensions import Annotated, List - -from comfy_cli import constants, env_checker, logging, tracking, ui, utils +from comfy_cli import constants, env_checker, logging, tracking, utils, ui from comfy_cli.command import custom_nodes from comfy_cli.command import install as install_inner from comfy_cli.command import run as run_inner @@ -334,7 +333,7 @@ def update( "comfy", help="[all|comfy]", autocompletion=utils.create_choice_completer(["all", "comfy"]), - ) + ), ): if target not in ["all", "comfy"]: typer.echo( diff --git a/comfy_cli/registry/api.py b/comfy_cli/registry/api.py index c8d72e0..943e011 100644 --- a/comfy_cli/registry/api.py +++ b/comfy_cli/registry/api.py @@ -5,7 +5,13 @@ import requests # Reduced global imports from comfy_cli.registry -from comfy_cli.registry.types import Node, NodeVersion, PublishNodeVersionResponse +from comfy_cli.registry.types import ( + Node, + NodeVersion, + PublishNodeVersionResponse, + PyProjectConfig, + License, +) class RegistryAPI: @@ -17,11 +23,13 @@ def determine_base_url(self): if env == "dev": return "http://localhost:8080" elif env == "staging": - return "https://staging.comfyregistry.org" + return "https://stagingapi.comfy.org" else: return "https://api.comfy.org" - def publish_node_version(self, node_config, token) -> PublishNodeVersionResponse: + def publish_node_version( + self, node_config: PyProjectConfig, token + ) -> PublishNodeVersionResponse: """ Publishes a new version of a node. @@ -33,7 +41,6 @@ def publish_node_version(self, node_config, token) -> PublishNodeVersionResponse PublishNodeVersionResponse: The response object from the API server. """ # Local import to prevent circular dependency - if not node_config.tool_comfy.publisher_id: raise Exception( "Publisher ID is required in pyproject.toml to publish a node version" @@ -43,17 +50,15 @@ def publish_node_version(self, node_config, token) -> PublishNodeVersionResponse raise Exception( "Project name is required in pyproject.toml to publish a node version" ) - - url = f"{self.base_url}/publishers/{node_config.tool_comfy.publisher_id}/nodes/{node_config.project.name}/versions" - headers = {"Content-Type": "application/json"} - body = { + license_json = serialize_license(node_config.project.license) + request_body = { "personal_access_token": token, "node": { "id": node_config.project.name, "description": node_config.project.description, "icon": node_config.tool_comfy.icon, "name": node_config.tool_comfy.display_name, - "license": node_config.project.license, + "license": license_json, "repository": node_config.project.urls.repository, }, "node_version": { @@ -61,6 +66,10 @@ def publish_node_version(self, node_config, token) -> PublishNodeVersionResponse "dependencies": node_config.project.dependencies, }, } + print(request_body) + url = f"{self.base_url}/publishers/{node_config.tool_comfy.publisher_id}/nodes/{node_config.project.name}/versions" + headers = {"Content-Type": "application/json"} + body = request_body response = requests.post(url, headers=headers, data=json.dumps(body)) @@ -177,3 +186,11 @@ def map_node_to_node_class(api_node_data): else None ), ) + + +def serialize_license(license: License) -> str: + if license.file: + return json.dumps({"file": license.file}) + if license.text: + return json.dumps({"text": license.text}) + return "{}" diff --git a/comfy_cli/registry/config_parser.py b/comfy_cli/registry/config_parser.py index 81dc3b1..d61522b 100644 --- a/comfy_cli/registry/config_parser.py +++ b/comfy_cli/registry/config_parser.py @@ -1,6 +1,7 @@ import os import subprocess from typing import Optional +import typer import tomlkit import tomlkit.exceptions @@ -11,6 +12,7 @@ Model, ProjectConfig, PyProjectConfig, + License, URLs, ) @@ -140,13 +142,35 @@ def extract_node_configuration( urls_data = project_data.get("urls", {}) comfy_data = data.get("tool", {}).get("comfy", {}) + license_data = project_data.get("license", {}) + if isinstance(license_data, str): + license = License(text=license_data) + typer.echo( + 'Warning: License should be in one of these two formats: license = {file = "LICENSE"} OR license = {text = "MIT License"}. Please check the documentation: https://docs.comfy.org/registry/specifications.' + ) + elif isinstance(license_data, dict): + if "file" in license_data or "text" in license_data: + license = License( + file=license_data.get("file", ""), text=license_data.get("text", "") + ) + else: + typer.echo( + 'Warning: License should be in one of these two formats: license = {file = "LICENSE"} OR license = {text = "MIT License"}. Please check the documentation: https://docs.comfy.org/registry/specifications.' + ) + license = License() + else: + license = License() + typer.echo( + 'Warning: License should be in one of these two formats: license = {file = "LICENSE"} OR license = {text = "MIT License"}. Please check the documentation: https://docs.comfy.org/registry/specifications.' + ) + project = ProjectConfig( name=project_data.get("name", ""), description=project_data.get("description", ""), version=project_data.get("version", ""), - requires_python=project_data.get("requires-pyton", ""), + requires_python=project_data.get("requires-python", ""), dependencies=project_data.get("dependencies", []), - license=project_data.get("license", ""), + license=license, urls=URLs( homepage=urls_data.get("Homepage", ""), documentation=urls_data.get("Documentation", ""), diff --git a/comfy_cli/registry/types.py b/comfy_cli/registry/types.py index 1b355bd..f33c06d 100644 --- a/comfy_cli/registry/types.py +++ b/comfy_cli/registry/types.py @@ -52,7 +52,10 @@ class ComfyConfig: icon: str = "" models: List[Model] = field(default_factory=list) - +@dataclass +class License: + file: str = "" + text: str = "" @dataclass class ProjectConfig: name: str = "" @@ -60,7 +63,7 @@ class ProjectConfig: version: str = "1.0.0" requires_python: str = ">= 3.9" dependencies: List[str] = field(default_factory=list) - license: str = "" + license: License = field(default_factory=License) urls: URLs = field(default_factory=URLs) diff --git a/tests/comfy_cli/registry/test_api.py b/tests/comfy_cli/registry/test_api.py index adbf3d3..5cede7c 100644 --- a/tests/comfy_cli/registry/test_api.py +++ b/tests/comfy_cli/registry/test_api.py @@ -3,7 +3,7 @@ from comfy_cli.registry import PyProjectConfig from comfy_cli.registry.api import RegistryAPI -from comfy_cli.registry.types import ComfyConfig, ProjectConfig, URLs +from comfy_cli.registry.types import ComfyConfig, ProjectConfig, URLs, License class TestRegistryAPI(unittest.TestCase): @@ -16,7 +16,7 @@ def setUp(self): version="0.1.0", requires_python=">= 3.9", dependencies=["dep1", "dep2"], - license="MIT", + license=License(file="LICENSE"), urls=URLs(repository="https://github.com/test/test_node"), ), tool_comfy=ComfyConfig( diff --git a/tests/comfy_cli/registry/test_config_parser.py b/tests/comfy_cli/registry/test_config_parser.py new file mode 100644 index 0000000..4b2fb30 --- /dev/null +++ b/tests/comfy_cli/registry/test_config_parser.py @@ -0,0 +1,125 @@ +from unittest.mock import patch, mock_open +import pytest +from comfy_cli.registry.config_parser import extract_node_configuration +from comfy_cli.registry.types import ( + PyProjectConfig, + ProjectConfig, + License, + URLs, + ComfyConfig, + Model, +) + + +@pytest.fixture +def mock_toml_data(): + return { + "project": { + "name": "test-project", + "description": "A test project", + "version": "1.0.0", + "requires-python": ">=3.7", + "dependencies": ["requests"], + "license": {"file": "LICENSE"}, + "urls": { + "Homepage": "https://example.com", + "Documentation": "https://docs.example.com", + "Repository": "https://github.com/example/test-project", + "Issues": "https://github.com/example/test-project/issues", + }, + }, + "tool": { + "comfy": { + "PublisherId": "test-publisher", + "DisplayName": "Test Project", + "Icon": "icon.png", + "Models": [ + { + "location": "model1.bin", + "model_url": "https://example.com/model1", + }, + { + "location": "model2.bin", + "model_url": "https://example.com/model2", + }, + ], + } + }, + } + + +def test_extract_node_configuration_success(mock_toml_data): + with patch("os.path.isfile", return_value=True), patch( + "builtins.open", mock_open() + ), patch("tomlkit.load", return_value=mock_toml_data): + result = extract_node_configuration("fake_path.toml") + + assert isinstance(result, PyProjectConfig) + assert result.project.name == "test-project" + assert result.project.description == "A test project" + assert result.project.version == "1.0.0" + assert result.project.requires_python == ">=3.7" + assert result.project.dependencies == ["requests"] + assert result.project.license == License(file="LICENSE") + assert result.project.urls == URLs( + homepage="https://example.com", + documentation="https://docs.example.com", + repository="https://github.com/example/test-project", + issues="https://github.com/example/test-project/issues", + ) + assert result.tool_comfy.publisher_id == "test-publisher" + assert result.tool_comfy.display_name == "Test Project" + assert result.tool_comfy.icon == "icon.png" + assert len(result.tool_comfy.models) == 2 + assert result.tool_comfy.models[0] == Model( + location="model1.bin", model_url="https://example.com/model1" + ) + + +def test_extract_node_configuration_license_text(): + mock_data = { + "project": { + "license": "MIT License", + }, + } + with patch("os.path.isfile", return_value=True), patch( + "builtins.open", mock_open() + ), patch("tomlkit.load", return_value=mock_data): + result = extract_node_configuration("fake_path.toml") + assert result is not None, "Expected PyProjectConfig, got None" + assert isinstance(result, PyProjectConfig) + assert result.project.license == License(text="MIT License") + + +def test_extract_node_configuration_license_text_dict(): + mock_data = { + "project": { + "license": { + "text": "MIT License\n\nCopyright (c) 2023 Example Corp\n\nPermission is hereby granted..." + }, + }, + } + with patch("os.path.isfile", return_value=True), patch( + "builtins.open", mock_open() + ), patch("tomlkit.load", return_value=mock_data): + result = extract_node_configuration("fake_path.toml") + + assert result is not None, "Expected PyProjectConfig, got None" + assert isinstance(result, PyProjectConfig) + assert result.project.license == License( + text="MIT License\n\nCopyright (c) 2023 Example Corp\n\nPermission is hereby granted..." + ) + + +def test_extract_license_incorrect_format(): + mock_data = { + "project": {"license": "MIT"}, + } + with patch("os.path.isfile", return_value=True), patch( + "builtins.open", mock_open() + ), patch("tomlkit.load", return_value=mock_data): + result = extract_node_configuration("fake_path.toml") + + assert result is not None, "Expected PyProjectConfig, got None" + assert isinstance(result, PyProjectConfig) + assert result.project.license == License(text="MIT")