diff --git a/canopen/objectdictionary/__init__.py b/canopen/objectdictionary/__init__.py index 038b4c5d..39a55a32 100644 --- a/canopen/objectdictionary/__init__.py +++ b/canopen/objectdictionary/__init__.py @@ -5,7 +5,9 @@ from __future__ import annotations import logging +import os import struct +import sys from collections.abc import Collection, Iterator, Mapping, MutableMapping from typing import Optional, TextIO, Union @@ -19,8 +21,8 @@ def export_od( od: ObjectDictionary, - dest: Union[str, TextIO, None] = None, - doc_type: Optional[str] = None + dest: Union[str, os.PathLike, TextIO, None] = None, + doc_type: Optional[str] = None, ) -> None: """Export an object dictionary. @@ -47,18 +49,21 @@ def export_od( f"supported formats: {supported}" ) - opened_here = False + opened_here: Optional[TextIO] = None try: - if isinstance(dest, str): + if dest is None: + dest = sys.stdout + elif isinstance(dest, (str, os.PathLike)): if doc_type is None: + _, suffix = os.path.splitext(os.fspath(dest).lower()) for t in supported_doctypes: - if dest.endswith(f".{t}"): + if suffix == f".{t}": doc_type = t break else: doc_type = "eds" dest = open(dest, 'w') - opened_here = True + opened_here = dest if doc_type == "eds": from canopen.objectdictionary import eds @@ -67,13 +72,13 @@ def export_od( from canopen.objectdictionary import eds return eds.export_dcf(od, dest) finally: - # If dest is opened in this fn, it should be closed - if opened_here: - dest.close() + # If dest is opened in this function, it should be closed + if opened_here is not None: + opened_here.close() def import_od( - source: Union[str, TextIO, None], + source: Union[str, os.PathLike, TextIO, None], node_id: Optional[int] = None, ) -> ObjectDictionary: """Parse an EDS, DCF, or EPF file. @@ -90,16 +95,17 @@ def import_od( """ if source is None: return ObjectDictionary() - if hasattr(source, "read"): + filename = "" + if isinstance(source, (str, os.PathLike)): + # Path to file + filename = os.fspath(source) + elif hasattr(source, "read"): # File like object - filename = source.name + filename = getattr(source, "name", "") elif hasattr(source, "tag"): # XML tree, probably from an EPF file filename = "od.epf" - else: - # Path to file - filename = source - suffix = filename[filename.rfind("."):].lower() + _, suffix = os.path.splitext(filename.lower()) if suffix in (".eds", ".dcf"): from canopen.objectdictionary import eds return eds.import_eds(source, node_id) diff --git a/canopen/objectdictionary/eds.py b/canopen/objectdictionary/eds.py index 608024f3..8cbd21a5 100644 --- a/canopen/objectdictionary/eds.py +++ b/canopen/objectdictionary/eds.py @@ -4,7 +4,7 @@ import logging import re from configparser import NoOptionError, NoSectionError, RawConfigParser -from typing import Any, TYPE_CHECKING +from typing import Any, TextIO, TYPE_CHECKING from canopen.objectdictionary import ( ODArray, @@ -408,11 +408,11 @@ def copy_variable(eds, section, subindex, src_var): return var -def export_dcf(od, dest=None, fileInfo={}): +def export_dcf(od: ObjectDictionary, dest: TextIO, fileInfo={}): return export_eds(od, dest, fileInfo, True) -def export_eds(od, dest=None, file_info={}, device_commisioning=False): +def export_eds(od: ObjectDictionary, dest: TextIO, file_info={}, device_commisioning=False): def export_object(obj, eds): if isinstance(obj, ODVariable): return export_variable(obj, eds) @@ -596,8 +596,4 @@ def add_list(section, list): add_list("OptionalObjects", supported_optional_indices) add_list("ManufacturerObjects", supported_manufacturer_indices) - if not dest: - import sys - dest = sys.stdout - eds.write(dest, False) diff --git a/test/test_eds.py b/test/test_eds.py index 4241d1e8..d36e8500 100644 --- a/test/test_eds.py +++ b/test/test_eds.py @@ -1,7 +1,10 @@ +import contextlib import io import os +import pathlib import unittest from configparser import RawConfigParser +from unittest.mock import MagicMock, patch import canopen from canopen.objectdictionary.eds import _signed_int_from_hex, build_variable @@ -56,16 +59,28 @@ def setUp(self): def test_load_nonexisting_file(self): with self.assertRaises(IOError): canopen.import_od('/path/to/wrong_file.eds') + with self.assertRaises(IOError): + canopen.import_od(pathlib.Path('/path/to/wrong_file.eds')) def test_load_unsupported_format(self): with self.assertRaisesRegex(ValueError, "'py'"): canopen.import_od(__file__) + with self.assertRaisesRegex(ValueError, "''"): + canopen.import_od('') + with self.assertRaisesRegex(ValueError, "''"): + filelike_object = io.StringIO() # no .name attribute + self.addCleanup(filelike_object.close) + canopen.import_od(filelike_object) def test_load_file_object(self): with open(SAMPLE_EDS) as fp: od = canopen.import_od(fp) self.assertTrue(len(od) > 0) + def test_load_pathlib_path(self): + od = canopen.import_od(pathlib.Path(SAMPLE_EDS)) + self.assertTrue(len(od) > 0) + def test_load_implicit_nodeid(self): # sample.eds has a DeviceComissioning section with NodeID set to 0x10. od = canopen.import_od(SAMPLE_EDS) @@ -323,7 +338,6 @@ def test_custom_options_record(self): def test_roundtrip_custom_options(self): """custom_options survive an EDS export/import round-trip.""" - import io with io.StringIO() as dest: canopen.export_od(self.od, dest, 'eds') dest.name = 'mock.eds' @@ -334,7 +348,6 @@ def test_roundtrip_custom_options(self): def test_roundtrip_custom_options_not_duplicated_as_standard(self): """After round-trip the re-imported object must not contain standard keys.""" - import io with io.StringIO() as dest: canopen.export_od(self.od, dest, 'eds') dest.name = 'mock.eds' @@ -383,6 +396,33 @@ def test_export_eds_to_file_unknown_extension(self): buf.name = "mock.eds" self.verify_od(buf, "eds") + def test_export_eds_auto_close(self): + fd = io.StringIO() + self.addCleanup(fd.close) + canopen.export_od(self.od, fd) + # File object already passed in must NOT be closed + self.assertIs(fd.closed, False) + for path in ("mock.eds", pathlib.Path("mock.eds")): + with self.subTest(path=path): + fd = io.StringIO() + with patch("canopen.objectdictionary.open", return_value=fd): + canopen.export_od(self.od, path) + # File object opened at path must be closed before return + self.assertIs(fd.closed, True) + + def test_export_eds_auto_close_exception(self): + buf = io.StringIO() + self.addCleanup(buf.close) + fd = MagicMock(wraps=buf) + fd.write.side_effect = IOError("Simulated write failure") + with ( + patch("canopen.objectdictionary.open", return_value=fd), + self.assertRaises(IOError), + ): + canopen.export_od(self.od, "mock.eds") + # File object opened at path must be closed on inner exception + self.assertIs(buf.closed, True) + def test_export_eds_unknown_doctype(self): filelike_object = io.StringIO() self.addCleanup(filelike_object.close) @@ -408,7 +448,6 @@ def test_export_eds_to_filelike_object(self): self.verify_od(dest, doctype) def test_export_eds_to_stdout(self): - import contextlib with contextlib.redirect_stdout(io.StringIO()) as f: ret = canopen.export_od(self.od, None, "eds") self.assertIsNone(ret) @@ -420,7 +459,6 @@ def test_export_eds_to_stdout(self): buf.name = "mock.eds" self.verify_od(buf, "eds") - def verify_od(self, source, doctype): exported_od = canopen.import_od(source)