#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os

__all__ = ["SparkFiles"]

from typing import cast, ClassVar, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for type annotations to avoid a circular import at runtime.
    from pyspark import SparkContext


class SparkFiles:
    """
    Resolves paths to files added through :meth:`SparkContext.addFile`.

    SparkFiles contains only classmethods; users should not create SparkFiles
    instances.
    """

    # Root directory for added files when running on a worker; populated by the
    # worker bootstrap code.
    _root_directory: ClassVar[Optional[str]] = None
    # True when this process is an executor worker rather than the driver.
    _is_running_on_worker: ClassVar[bool] = False
    # Active SparkContext on the driver; used to query the JVM for the root dir.
    _sc: ClassVar[Optional["SparkContext"]] = None

    def __init__(self) -> None:
        # This class is a namespace of classmethods; instantiation is an error.
        raise NotImplementedError("Do not construct SparkFiles objects")

    @classmethod
    def get(cls, filename: str) -> str:
        """
        Get the absolute path of a file added through :meth:`SparkContext.addFile`
        or :meth:`SparkContext.addPyFile`.

        .. versionadded:: 0.7.0

        Parameters
        ----------
        filename : str
            file that are added to resources

        Returns
        -------
        str
            the absolute path of the file

        See Also
        --------
        :meth:`SparkFiles.getRootDirectory`
        :meth:`SparkContext.addFile`
        :meth:`SparkContext.addPyFile`
        :meth:`SparkContext.listFiles`

        Examples
        --------
        >>> import os
        >>> import tempfile
        >>> from pyspark import SparkFiles
        >>> with tempfile.TemporaryDirectory() as d:
        ...     path1 = os.path.join(d, "test.txt")
        ...     with open(path1, "w") as f:
        ...         _ = f.write("100")
        ...
        ...     sc.addFile(path1)
        ...     file_list1 = sorted(sc.listFiles)
        ...
        ...     def func1(iterator):
        ...         path = SparkFiles.get("test.txt")
        ...         assert path.startswith(SparkFiles.getRootDirectory())
        ...         return [path]
        ...
        ...     path_list1 = sc.parallelize([1, 2, 3, 4]).mapPartitions(func1).collect()
        ...
        ...     path2 = os.path.join(d, "test.py")
        ...     with open(path2, "w") as f:
        ...         _ = f.write("import pyspark")
        ...
        ...     # py files
        ...     sc.addPyFile(path2)
        ...     file_list2 = sorted(sc.listFiles)
        ...
        ...     def func2(iterator):
        ...         path = SparkFiles.get("test.py")
        ...         assert path.startswith(SparkFiles.getRootDirectory())
        ...         return [path]
        ...
        ...     path_list2 = sc.parallelize([1, 2, 3, 4]).mapPartitions(func2).collect()

        >>> file_list1
        ['file:/.../test.txt']
        >>> set(path_list1)
        {'.../test.txt'}
        >>> file_list2
        ['file:/.../test.py', 'file:/.../test.txt']
        >>> set(path_list2)
        {'.../test.py'}
        """
        # Added files always live directly under the session's root directory.
        path = os.path.join(SparkFiles.getRootDirectory(), filename)
        return os.path.abspath(path)

    @classmethod
    def getRootDirectory(cls) -> str:
        """
        Get the root directory that contains files added through
        :meth:`SparkContext.addFile` or :meth:`SparkContext.addPyFile`.

        .. versionadded:: 0.7.0

        Returns
        -------
        str
            the root directory that contains files added to resources

        See Also
        --------
        :meth:`SparkFiles.get`
        :meth:`SparkContext.addFile`
        :meth:`SparkContext.addPyFile`

        Examples
        --------
        >>> from pyspark.files import SparkFiles
        >>> SparkFiles.getRootDirectory()  # doctest: +SKIP
        '.../spark-a904728e-08d3-400c-a872-cfd82fd6dcd2/userFiles-648cf6d6-bb2c-4f53-82bd-e658aba0c5de'
        """
        if cls._is_running_on_worker:
            # On a worker the bootstrap sets _root_directory before user code runs.
            return cast(str, cls._root_directory)
        else:
            # On the driver, ask the JVM-side SparkFiles for the directory.
            # This will have to change if we support multiple SparkContexts:
            assert cls._sc is not None
            assert cls._sc._jvm is not None
            return cls._sc._jvm.org.apache.spark.SparkFiles.getRootDirectory()


def _test() -> None:
    # Run this module's doctests against a local two-core SparkContext.
    import doctest
    import sys

    from pyspark import SparkContext

    globs = globals().copy()
    globs["sc"] = SparkContext("local[2]", "files tests")
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs["sc"].stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()