Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/datacustomcode/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,17 @@ def scan_file_for_imports(file_path: str) -> Set[str]:
tree = ast.parse(code)
visitor = ImportVisitor()
visitor.visit(tree)
return visitor.imports

# Filter out local modules
file_dir = os.path.dirname(file_path)
filtered_imports = set()
for package in visitor.imports:
# Check if a .py file exists in the same directory
local_module_path = os.path.join(file_dir, f"{package}.py")
if not os.path.exists(local_module_path):
filtered_imports.add(package)

return filtered_imports


def write_requirements_file(file_path: str) -> str:
Expand Down
62 changes: 62 additions & 0 deletions tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,3 +927,65 @@ def test_excluded_packages(self):
assert "pyspark" not in imports
finally:
os.unlink(temp_path)

def test_local_module_exclusion(self):
"""Test that local modules (files in the same directory) are excluded."""
# Create a temporary directory with multiple Python files
temp_dir = tempfile.mkdtemp()

try:
# Create a local module file
utility_path = os.path.join(temp_dir, "utility.py")
with open(utility_path, "w") as f:
f.write(
textwrap.dedent(
"""
def helper_function():
return "helper"
"""
)
)

# Create another local module
helpers_path = os.path.join(temp_dir, "helpers.py")
with open(helpers_path, "w") as f:
f.write(
textwrap.dedent(
"""
def another_helper():
return "another"
"""
)
)

# Test script imports both local modules and external packages
main_content = textwrap.dedent(
"""
from utility import helper_function
from helpers import another_helper
import pandas as pd
import numpy as np
"""
)
main_path = os.path.join(temp_dir, "main.py")
with open(main_path, "w") as f:
f.write(main_content)

# Scan for imports
imports = scan_file_for_imports(main_path)

# External packages should be included
assert "pandas" in imports
assert "numpy" in imports

# Local modules should be excluded
assert "utility" not in imports
assert "helpers" not in imports

finally:
# Clean up
for file in ["utility.py", "helpers.py", "main.py"]:
file_path = os.path.join(temp_dir, file)
if os.path.exists(file_path):
os.unlink(file_path)
os.rmdir(temp_dir)
Loading