diff --git a/src/datacustomcode/scan.py b/src/datacustomcode/scan.py index 5e50c5d..afcddea 100644 --- a/src/datacustomcode/scan.py +++ b/src/datacustomcode/scan.py @@ -264,7 +264,17 @@ def scan_file_for_imports(file_path: str) -> Set[str]: tree = ast.parse(code) visitor = ImportVisitor() visitor.visit(tree) - return visitor.imports + + # Filter out local modules + file_dir = os.path.dirname(file_path) + filtered_imports = set() + for package in visitor.imports: + # Check if a .py file exists in the same directory + local_module_path = os.path.join(file_dir, f"{package}.py") + if not os.path.exists(local_module_path): + filtered_imports.add(package) + + return filtered_imports def write_requirements_file(file_path: str) -> str: diff --git a/tests/test_scan.py b/tests/test_scan.py index 2acbc25..9b908b2 100644 --- a/tests/test_scan.py +++ b/tests/test_scan.py @@ -927,3 +927,65 @@ def test_excluded_packages(self): assert "pyspark" not in imports finally: os.unlink(temp_path) + + def test_local_module_exclusion(self): + """Test that local modules (files in the same directory) are excluded.""" + # Create a temporary directory with multiple Python files + temp_dir = tempfile.mkdtemp() + + try: + # Create a local module file + utility_path = os.path.join(temp_dir, "utility.py") + with open(utility_path, "w") as f: + f.write( + textwrap.dedent( + """ + def helper_function(): + return "helper" + """ + ) + ) + + # Create another local module + helpers_path = os.path.join(temp_dir, "helpers.py") + with open(helpers_path, "w") as f: + f.write( + textwrap.dedent( + """ + def another_helper(): + return "another" + """ + ) + ) + + # Test script imports both local modules and external packages + main_content = textwrap.dedent( + """ + from utility import helper_function + from helpers import another_helper + import pandas as pd + import numpy as np + """ + ) + main_path = os.path.join(temp_dir, "main.py") + with open(main_path, "w") as f: + f.write(main_content) + + # Scan for imports + imports = scan_file_for_imports(main_path) + + # External packages should be included + assert "pandas" in imports + assert "numpy" in imports + + # Local modules should be excluded + assert "utility" not in imports + assert "helpers" not in imports + + finally: + # Clean up + for file in ["utility.py", "helpers.py", "main.py"]: + file_path = os.path.join(temp_dir, file) + if os.path.exists(file_path): + os.unlink(file_path) + os.rmdir(temp_dir)